Compare commits

...

37 Commits

Author SHA1 Message Date
Georgi Gerganov 72f80499ee server : headers cleanup 2025-11-24 12:50:50 +02:00
Xuan Son Nguyen 625010d42d move enum stop_type to server-task 2025-11-18 17:28:21 +01:00
Xuan Son Nguyen ca993bad51 rm redundant includes 2025-11-18 15:01:17 +01:00
Xuan Son Nguyen 3b7946034c add server-queue 2025-11-18 14:24:46 +01:00
Xuan Son Nguyen e1a756e934 add server-task, server-common 2025-11-18 14:15:14 +01:00
Chenguang Li bc4064cfea CANN: fix acl_tensor_ptr usage in ASCEND_310P ROPE (#17347)
* cann: fix acl_tensor_ptr usage in ASCEND_310P ROPE implementation

Fix compilation errors in the ASCEND_310P-specific ROPE operation code
by adding .get() calls when passing acl_tensor_ptr smart pointers to
functions expecting raw aclTensor* pointers.

This fixes the code that was missed in the previous refactoring commit
(8981848) which changed ggml_cann_create_tensor() return type from
aclTensor* to acl_tensor_ptr.

* cann: format code
2025-11-18 16:41:52 +08:00
o7si 97cb3fd5ae fix: resolve undefined variable 'svr' compilation error (#17348) 2025-11-18 10:10:47 +02:00
jiahao su ffa277a54c CANN: Add openEuler-cann in build and release (#17192)
Update openEuler version

Remove variable ASCEND_SOC_TYPE

Modify the chip type

Fix case in zip filename

Change "device" to "chip_type"

Modify the value of chip_type
2025-11-18 16:08:55 +08:00
Jeff Bolz da95bf2a85 vulkan: support noncontig i32 copy (#17328) 2025-11-18 07:41:24 +01:00
Xuan-Son Nguyen 0de8878c96 server: split HTTP into its own interface (#17216)
* server: split HTTP into its own interface

* move server-http and httplib to its own file

* add the remaining endpoints

* fix exception/error handling

* renaming

* missing header

* fix missing windows header

* fix error responses from http layer

* fix slot save/restore handler

* fix case where only one stream chunk is returned

* add NOMINMAX

* do not call sink.write on empty data

* use safe_json_to_str for SSE

* clean up

* add some comments

* improve usage of next()

* bring back the "server is listening on" message

* more generic handler

* add req.headers

* move the chat template print to init()

* add req.path

* cont : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-17 22:05:44 +01:00
Ruben Ortlam 38e2c1b412 vulkan: add log RTE support to fix Nvidia CI (#17320)
* vulkan: add log RTE support to fix Nvidia CI

* actually use the rte shader
2025-11-17 14:37:49 -06:00
Adrien Gallouët cb44fc84e8 cmake : fix ARM feature verification (#17170)
* cmake : fix ARM feature verification

Use check_cxx_source_compiles to prevent conflicts with
the existing GGML_NATIVE detection code.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* cmake : unset __ARM_FEATURE when feature is disabled

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* cmake : fix scope, this is really a macro

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* arm_neon.h is useless

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-11-17 21:37:29 +01:00
Adrien Gallouët cb623de3fc ggml : add missing AVX512 feature checks (#17270)
_mm512_cvtepu8_epi16        requires  __AVX512BW__
_mm512_srli_epi16           requires  __AVX512BW__
__builtin_ia32_inserti32x8  requires  __AVX512DQ__

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-11-17 12:12:00 +01:00
Georgi Gerganov 7aaeedc098 metal : support I32 -> I32 copy (#17317) 2025-11-17 11:52:00 +02:00
Georgi Gerganov 3347e6d904 metal : faster argsort (#17315)
* metal : faster argsort

* cont : keep data in registers
2025-11-17 11:51:48 +02:00
Georgi Gerganov 1a139644a8 metal : add cumsum (#17305) 2025-11-17 11:51:13 +02:00
hipudding 2376b7758c CANN: Use smart pointers to manage ACL objects (#17238)
* CANN: Use smart pointers to manage ACL objects

Previously, ACL objects were managed via manual destruction, which
led to multiple memory-leak issues during runtime. This patch replaces
manual memory management with smart pointers so that ACL objects
are properly released and ownership is clearly defined.

Note that the ownership of an ACL object belongs to the function
that creates it. Other internal functions should operate on these ACL
objects using raw pointers to avoid unintended ownership transfers.

Additionally, since aclTensorList automatically frees its contained
aclTensor objects, any aclTensor added to a tensor list must release
ownership to avoid double free operations.

This PR also removes the asynchronous task submission mechanism.
Due to changes in recent CANN versions, tiling time has significantly
decreased. Even with a dual-thread submission model, the dispatch
overhead still falls on the critical path, making async submission
less beneficial. Moreover, aclGraph support provides a much better
path to reducing operator dispatch latency.

* CANN: resolve review comments
2025-11-17 08:43:59 +08:00
Pavels Zaicenkovs dbed61294a vulkan: add LOG operation support for F32 and F16 (#17183)
* vulkan: add LOG operation support for F32 and F16

Part of #14909.

* vulkan: Fix LOG operation types

* docs: Update operation support documentation for Vulkan LOG operation

* vulkan: fix log_f16 shader

* docs: restore missing LOG test cases and regenerate ops.md
2025-11-16 22:50:09 +01:00
Ruben Ortlam 80deff3648 vulkan: fix MMQ quantize_y condition (#17301) 2025-11-16 19:38:17 +01:00
Eve 8b1c339bd2 ci : revert #16249 (#17303)
* Delete .github/workflows/build-amd.yml

* Update build.yml
2025-11-16 19:09:17 +01:00
Georgi Gerganov 416e7c7f47 metal : remove obosolete asserts (#17295) 2025-11-16 09:50:26 +02:00
Georgi Gerganov 5b2093becc server : handle context overflow during decode (#17267)
* server : handle context overflow during decode

* server : minor refactor
2025-11-16 09:23:37 +02:00
lhez 52e5d421f1 opencl: fix rms_norm_mul (#17250)
* opencl: use subgrroup reduce for reduction in rms_norm_mul

* opencl: add comment about workgroup size
2025-11-15 17:40:14 -08:00
shaofeiqi 4db5641210 opencl: add kernel to handle mat mul in attention to improve encoding speed (#17181)
* Add mul_mm_f16_f32_kq_kqv kernel

* Add ggml_cl_mul_mat_kq_kqv_adreno func

* fix whitespace

* remove unused variable

* remove redundant

* refactor and clean up

* remove trailing whitespace
2025-11-15 17:33:10 -08:00
shani-f 72bd7321a7 sycl : unify unary kernels with a generic implementation and enable wide operator support (#17213)
* SYCL: add generic unary op implementation for multiple ops (ABS/SGN/…); unify non-contiguous access

* SYCL: update documentation and sycl.csv to reflect new unary op support

* update ops.md after syncing SYCL.csv changes

* Fix SYCL.csv merge conflict

* Update ops.md after fixing SYCL.csv conflicts

* Fix SYCL.csv tail after merge conflict and regenerate ops.md

* Fix line endings and final newline in SYCL.csv

* Remove TOPK_MOE entries from SYCL.csv as requested

* Update ops.md after removing TOPK_MOE from SYCL.csv

* Regenerated SYCL.csv and synced ops.md with upstream

* Update ops.md using create_ops_docs.py
2025-11-16 00:52:42 +01:00
Aleksander Grygier 22e1ce2f81 webui: Fix clickability around chat processing statistics UI (#17278)
* fix: Better pointer events handling in chat processing info elements

* chore: update webui build output
2025-11-15 22:41:41 +01:00
Pascal 1411d9275a webui: add OAI-Compat Harmony tool-call streaming visualization and persistence in chat UI (#16618)
* webui: add OAI-Compat Harmony tool-call live streaming visualization and persistence in chat UI

- Purely visual and diagnostic change, no effect on model context, prompt
  construction, or inference behavior

- Captured assistant tool call payloads during streaming and non-streaming
  completions, and persisted them in chat state and storage for downstream use

- Exposed parsed tool call labels beneath the assistant's model info line
  with graceful fallback when parsing fails

- Added tool call badges beneath assistant responses that expose JSON tooltips
  and copy their payloads when clicked, matching the existing model badge styling

- Added a user-facing setting to toggle tool call visibility to the Developer
  settings section directly under the model selector option

* webui: remove scroll listener causing unnecessary layout updates (model selector)

* Update tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* chore: npm run format & update webui build output

* chore: update webui build output

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2025-11-15 21:09:32 +01:00
Sigbjørn Skjæret 662192e1dc convert : remove unnecessary chat template patching (#17289) 2025-11-15 20:58:59 +01:00
Jeff Bolz 24dc769f1b vulkan: Fuse mul_mat_id+add_id+mul and mul_mat+add+add. (#17287)
These both show up in gpt-oss. Also, cleanup the mul_mat_vec fusion code a bit.
2025-11-15 19:54:23 +01:00
Ruben Ortlam 4dca015b7e vulkan: Replace 16-bit unpack8 calls to work around legacy Windows AMD driver bug (#17285) 2025-11-15 15:18:58 +01:00
Sigbjørn Skjæret 9a8860cf5d convert : use all parts in safetensors index (#17286) 2025-11-15 14:12:39 +01:00
Sigbjørn Skjæret 9d3ef4809f convert : set expert gating func in base class (#17279) 2025-11-15 14:06:24 +01:00
Ankur Verma c7b7db0445 mtmd-cli: Avoid logging to stdout for model loading messages in mtmd-cli (#17277) 2025-11-15 12:41:16 +01:00
Giuseppe Scrivano 1568d13c2c vulkan: implement ABS and NEG (#17245)
* docs: update Vulkan ops

* vulkan: add NEG op

* vulkan: add ABS op

---------

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
2025-11-15 12:00:29 +01:00
Jeff Bolz 439342ea0b vulkan: Use ggml_vk_tensor_subbuffer in mul_mat_vec(id) paths (#17244)
* vulkan: Use ggml_vk_tensor_subbuffer in mul_mat_vec(id) paths

* set allow_misalign
2025-11-15 11:56:15 +01:00
Jeff Bolz 234ae7d7bd vulkan: skip all-negative-inf blocks in FA (#17186) 2025-11-15 10:37:25 +01:00
Jeff Bolz 38eaf32af1 vulkan: change graph_compute to be async and enable get_tensor_async (#17158)
* vulkan: change graph_compute to be async and enable get_tensor_async

This allows some additional CPU/GPU overlap for large pp workloads. Also seems
to help a bit for token gen, maybe getting rid of a small bubble between
graph_compute and get_tensor.

Async set and copy functions seem to be very rarely used, so I didn't enable
them because I didn't have a good way to test them.

The async commands need to be ordered against each other, so put them all on
the compute queue. The non-async commands still use the transfer queue.

The fence for graph_compute/get_tensor_async is submitted and waited on in
ggml_vk_synchronize.

* fix thread safety errors

* teardown context cleanly

* Handle async read to non-pinned dst
2025-11-15 09:06:41 +01:00
69 changed files with 24595 additions and 13179 deletions
+5 -6
View File
@@ -3,7 +3,8 @@
# ==============================================================================
# Define the CANN base image for easier version updates later
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
ARG CHIP_TYPE=910b
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.3.rc1.alpha001-${CHIP_TYPE}-openeuler22.03-py3.11
# ==============================================================================
# BUILD STAGE
@@ -11,9 +12,6 @@ ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
# ==============================================================================
FROM ${CANN_BASE_IMAGE} AS build
# Define the Ascend chip model for compilation. Default is Ascend910B3
ARG ASCEND_SOC_TYPE=Ascend910B3
# -- Install build dependencies --
RUN yum install -y gcc g++ cmake make git libcurl-devel python3 python3-pip && \
yum clean all && \
@@ -36,13 +34,14 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH
# For brevity, only core variables are listed here. You can paste the original ENV list here.
# -- Build llama.cpp --
# Use the passed ASCEND_SOC_TYPE argument and add general build options
# Use the passed CHIP_TYPE argument and add general build options
ARG CHIP_TYPE
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh --force \
&& \
cmake -B build \
-DGGML_CANN=ON \
-DCMAKE_BUILD_TYPE=Release \
-DSOC_TYPE=${ASCEND_SOC_TYPE} \
-DSOC_TYPE=ascend${CHIP_TYPE} \
. && \
cmake --build build --config Release -j$(nproc)
-52
View File
@@ -1,52 +0,0 @@
name: CI (AMD)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-amd.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.comp'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
jobs:
ggml-ci-x64-amd-vulkan:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-rocm:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
amd-smi static
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
+32 -4
View File
@@ -1391,9 +1391,9 @@ jobs:
matrix:
arch: [x86, aarch64]
cann:
- '8.1.RC1.alpha001-910b-openeuler22.03-py3.10'
device:
- 'ascend910b3'
- '8.3.rc1.alpha001-910b-openeuler22.03-py3.11'
chip_type:
- '910b'
build:
- 'Release'
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
@@ -1414,7 +1414,7 @@ jobs:
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
-DGGML_CANN=on \
-DSOC_TYPE=${{ matrix.device }}
-DSOC_TYPE=ascend${{ matrix.chip_type }}
cmake --build build -j $(nproc)
# TODO: simplify the following workflows using a matrix
@@ -1599,6 +1599,34 @@ jobs:
run: |
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-vulkan:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-rocm:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
amd-smi static
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-mac-metal:
runs-on: [self-hosted, macOS, ARM64]
+47
View File
@@ -693,6 +693,52 @@ jobs:
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
name: llama-${{ steps.tag.outputs.name }}-xcframework
openEuler-cann:
strategy:
matrix:
arch: [x86, aarch64]
chip_type: ['910b', '310p']
build:
- 'Release'
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.3.rc1.alpha001-310p-openeuler22.03-py3.11' }}
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Dependencies
run: |
yum update -y
yum install -y git gcc gcc-c++ make cmake libcurl-devel
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Build
run: |
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
-DGGML_CANN=on \
-DSOC_TYPE=ascend${{ matrix.chip_type }}
cmake --build build -j $(nproc)
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
- name: Pack artifacts
run: |
cp LICENSE ./build/bin/
zip -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
name: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}
release:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
@@ -714,6 +760,7 @@ jobs:
- macOS-arm64
- macOS-x64
- ios-xcode-build
- openEuler-cann
steps:
- name: Clone
+12 -52
View File
@@ -189,10 +189,10 @@ class ModelBase:
return tensors
prefix = "model" if not self.is_mistral_format else "consolidated"
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
is_safetensors: bool = len(part_names) > 0
if not is_safetensors:
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
tensor_names_from_index: set[str] = set()
@@ -209,6 +209,7 @@ class ModelBase:
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
tensor_names_from_index.update(weight_map.keys())
part_names |= set(weight_map.values())
else:
weight_map = {}
else:
@@ -825,6 +826,15 @@ class TextModel(ModelBase):
self.gguf_writer.add_expert_group_used_count(n_group_used)
logger.info(f"gguf: expert groups used count = {n_group_used}")
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
if score_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
logger.info(f"gguf: expert score gating function = {score_func}")
if (head_dim := self.hparams.get("head_dim")) is not None:
self.gguf_writer.add_key_length(head_dim)
self.gguf_writer.add_value_length(head_dim)
@@ -2553,15 +2563,6 @@ class AfmoeModel(LlamaModel):
if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
self.gguf_writer.add_leading_dense_block_count(n_dense_layers)
# Expert Gating Function
score_func = self.hparams.get("score_func")
if score_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
elif score_func is not None:
raise ValueError(f"Unsupported score_function value: {score_func}")
# Route normalization and scaling
if (route_norm := self.hparams.get("route_norm")) is not None:
self.gguf_writer.add_expert_weights_norm(route_norm)
@@ -7182,13 +7183,6 @@ class DeepseekV2Model(TextModel):
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
if hparams["scoring_func"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif hparams["scoring_func"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
rope_scaling = self.hparams.get("rope_scaling") or {}
@@ -7294,12 +7288,6 @@ class MiniMaxM2Model(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hparams["scoring_func"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif self.hparams["scoring_func"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
@@ -7392,11 +7380,6 @@ class Dots1Model(Qwen2MoeModel):
self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
if self.hparams["scoring_func"] == "noaux_tc":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
else:
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -7857,12 +7840,6 @@ class Glm4MoeModel(TextModel):
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
# Patch broken chat template
if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
special_vocab.chat_template = special_vocab.chat_template.replace(
"""{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
"""{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
@@ -8717,13 +8694,6 @@ class BailingMoeV2Model(TextModel):
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
if hparams["score_function"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif hparams["score_function"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
@@ -9419,16 +9389,6 @@ class HunYuanModel(TextModel):
class SmolLM3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.SMOLLM3
def set_vocab(self):
super().set_vocab()
# remove unsupported array slicing in chat template
# ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
if tokenizer.chat_template is not None:
chat_template = tokenizer.chat_template.replace("[:]", "")
self.gguf_writer.add_chat_template(chat_template)
@ModelBase.register("GptOssForCausalLM")
class GptOssModel(TextModel):
+35 -36
View File
@@ -14,24 +14,24 @@ Legend:
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | | 🟡 | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | | 🟡 | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -40,8 +40,8 @@ Legend:
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | | 🟡 | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
@@ -50,40 +50,40 @@ Legend:
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | | 🟡 | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | | 🟡 | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | | 🟡 | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | | | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | | | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
@@ -93,29 +93,28 @@ Legend:
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ❌ | ❌ |
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | | ❌ |
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
+2348 -149
View File
File diff suppressed because it is too large Load Diff
+14536 -4360
View File
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+19 -10
View File
@@ -48,15 +48,14 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
default:
return ACL_DT_UNDEFINED;
}
return ACL_DT_UNDEFINED;
}
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne,
size_t * nb,
int64_t dims,
aclFormat format,
size_t offset) {
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne,
size_t * nb,
int64_t dims,
aclFormat format,
size_t offset) {
// If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
// added.
int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
@@ -87,10 +86,20 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
std::reverse(acl_ne, acl_ne + final_dims);
std::reverse(acl_stride, acl_stride + final_dims);
aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
elem_offset, format, &acl_storage_len, 1, tensor->data);
aclTensor * raw = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, elem_offset,
format, &acl_storage_len, 1, tensor->data);
return acl_tensor;
return acl_tensor_ptr(raw);
}
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size) {
aclIntArray * raw = aclCreateIntArray(value, size);
return acl_int_array_ptr(raw);
}
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType) {
aclScalar * raw = aclCreateScalar(value, dataType);
return acl_scalar_ptr(raw);
}
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {
+99 -19
View File
@@ -23,11 +23,12 @@
#ifndef CANN_ACL_TENSOR_H
#define CANN_ACL_TENSOR_H
#include <algorithm>
#include <cstring>
#include "common.h"
#include <aclnn/aclnn_base.h>
#include "common.h"
#include <algorithm>
#include <cstring>
/**
* @brief Maps a ggml_type to its corresponding aclDataType.
@@ -43,6 +44,20 @@
*/
aclDataType ggml_cann_type_mapping(ggml_type type);
// Deleter for acl objects.
template <typename T, aclError (*DestroyFunc)(const T *)> struct acl_deleter {
void operator()(T * ptr) const noexcept {
if (ptr) {
ACL_CHECK(DestroyFunc(ptr));
}
}
};
using acl_tensor_ptr = std::unique_ptr<aclTensor, acl_deleter<aclTensor, aclDestroyTensor>>;
using acl_int_array_ptr = std::unique_ptr<aclIntArray, acl_deleter<aclIntArray, aclDestroyIntArray>>;
using acl_scalar_ptr = std::unique_ptr<aclScalar, acl_deleter<aclScalar, aclDestroyScalar>>;
using acl_tensor_list_ptr = std::unique_ptr<aclTensorList, acl_deleter<aclTensorList, aclDestroyTensorList>>;
/**
* @brief Creates an ACL tensor from a ggml_tensor with optional shape.
*
@@ -62,12 +77,12 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
* @return Pointer to the created ACL tensor.
*/
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne = nullptr,
size_t * nb = nullptr,
int64_t dims = 0,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0);
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne = nullptr,
size_t * nb = nullptr,
int64_t dims = 0,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0);
/**
* @brief Template for creating an ACL tensor from provided parameters. typename TYPE
@@ -90,14 +105,14 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
* @return Pointer to the created ACL tensor.
*/
template <typename TYPE>
aclTensor * ggml_cann_create_tensor(void * data_ptr,
aclDataType dtype,
TYPE type_size,
int64_t * ne,
TYPE * nb,
int64_t dims,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0) {
acl_tensor_ptr ggml_cann_create_tensor(void * data_ptr,
aclDataType dtype,
TYPE type_size,
int64_t * ne,
TYPE * nb,
int64_t dims,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0) {
int64_t tmp_ne[GGML_MAX_DIMS * 2];
int64_t tmp_stride[GGML_MAX_DIMS * 2];
@@ -114,10 +129,75 @@ aclTensor * ggml_cann_create_tensor(void * data_ptr,
std::reverse(tmp_ne, tmp_ne + dims);
std::reverse(tmp_stride, tmp_stride + dims);
aclTensor * acl_tensor =
aclTensor * raw =
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
return acl_tensor;
return acl_tensor_ptr(raw);
}
/**
* @brief Create an ACL int array resource wrapped in a smart pointer.
*
* This function constructs an aclIntArray from the provided int64_t values
* and returns it as an acl_int_array_ptr (a std::unique_ptr with a custom
* deleter). The returned pointer owns the ACL resource and will automatically
* destroy it via aclDestroyIntArray().
*
* @param value Pointer to the int64_t elements.
* @param size Number of elements in value.
*
* @return A smart pointer managing the created ACL int array.
*/
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size);
/**
* @brief Create an ACL scalar resource wrapped in a smart pointer.
*
* This function constructs an aclScalar from the raw value pointer and ACL
* data type, then returns it as an acl_scalar_ptr (a std::unique_ptr with
* a custom deleter). The returned pointer owns the ACL scalar and will
* automatically destroy it via aclDestroyScalar().
*
* @param value Pointer to the raw scalar memory.
* @param dataType ACL data type of the scalar.
*
* @return A smart pointer managing the created ACL scalar.
*/
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType);
/**
* @brief Create an ACL tensor list from multiple tensor smart pointers.
*
* This function accepts a variadic list of acl_tensor_ptr (a unique_ptr with
* custom deleter) and produces an aclTensorList using aclCreateTensorList().
*
* The lifecycle management of the tensor objects changes as follows:
* - aclCreateTensorList() takes ownership of the tensors
* - Each input smart pointer releases ownership using release()
* - As a result, the tensors will NOT be destroyed by unique_ptr
* - Instead, they will be destroyed when aclDestroyTensorList() is called
*
* This ensures correct ownership transfer and prevents double-free situations.
*
* @param acl_tensor_ptr Variadic template parameter; each argument must be
* a unique_ptr-like type supporting get() and release().
*
* @param tensors Variadic list of acl_tensor_ptr objects. Ownership of
* each tensor is transferred away from these smart pointers.
*
* @return A smart pointer (acl_tensor_list_ptr) owning the created ACL tensor list.
*
* @note This implementation is C++11 compatible. The ownership-release process is
* executed using a pack expansion inside an initializer list.
*/
template <typename... acl_tensor_ptr> acl_tensor_list_ptr ggml_cann_create_tensor_list(acl_tensor_ptr &&... tensors) {
aclTensor * raw_tensors[] = { tensors.get()... };
aclTensorList * raw = aclCreateTensorList(raw_tensors, sizeof...(tensors));
// aclTensor will release by aclTensorList, so release ownership without
// destroying the tensor
int dummy[] = { (tensors.release(), 0)... };
GGML_UNUSED(dummy);
return acl_tensor_list_ptr(raw);
}
/**
File diff suppressed because it is too large Load Diff
+42 -197
View File
@@ -23,33 +23,35 @@
#ifndef CANN_ACLNN_OPS
#define CANN_ACLNN_OPS
#include <unordered_set>
#include <functional>
#include "acl_tensor.h"
#include "common.h"
#include <aclnnop/aclnn_abs.h>
#include <aclnnop/aclnn_neg.h>
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_arange.h>
#include <aclnnop/aclnn_argsort.h>
#include <aclnnop/aclnn_cat.h>
#include <aclnnop/aclnn_clamp.h>
#include <aclnnop/aclnn_cos.h>
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_gelu.h>
#include <aclnnop/aclnn_gelu_v2.h>
#include <aclnnop/aclnn_sigmoid.h>
#include <aclnnop/aclnn_hardsigmoid.h>
#include <aclnnop/aclnn_hardswish.h>
#include <aclnnop/aclnn_leaky_relu.h>
#include <aclnnop/aclnn_relu.h>
#include <aclnnop/aclnn_silu.h>
#include <aclnnop/aclnn_tanh.h>
#include <aclnnop/aclnn_sqrt.h>
#include <aclnnop/aclnn_sin.h>
#include <aclnnop/aclnn_cos.h>
#include <aclnnop/aclnn_log.h>
#include <aclnnop/aclnn_sign.h>
#include <aclnnop/aclnn_norm.h>
#include <aclnnop/aclnn_logsoftmax.h>
#include "acl_tensor.h"
#include "common.h"
#include <aclnnop/aclnn_neg.h>
#include <aclnnop/aclnn_norm.h>
#include <aclnnop/aclnn_relu.h>
#include <aclnnop/aclnn_sigmoid.h>
#include <aclnnop/aclnn_sign.h>
#include <aclnnop/aclnn_silu.h>
#include <aclnnop/aclnn_sin.h>
#include <aclnnop/aclnn_sqrt.h>
#include <aclnnop/aclnn_tanh.h>
#include <functional>
#include <unordered_set>
/**
* @brief Repeats a ggml tensor along each dimension to match the dimensions
@@ -688,12 +690,12 @@ void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor *
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
*/
void bcast_shape(ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
aclTensor ** acl_src0,
aclTensor ** acl_src1,
aclTensor ** acl_dst);
void bcast_shape(ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
acl_tensor_ptr & acl_src0,
acl_tensor_ptr & acl_src1,
acl_tensor_ptr & acl_dst);
/**
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
@@ -873,83 +875,6 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
(vec.emplace_back(make_acl_resource(args)), ...);
}
/**
* @brief Task class that wraps the execution of an aclnn function call.
*/
class aclnn_task : public cann_task {
public:
aclnn_task(aclnn_func_t aclnn_func,
void * workspace_addr,
uint64_t workspace_size,
aclOpExecutor * executor,
aclrtStream stream) :
aclnn_func_(aclnn_func),
workspace_addr_(workspace_addr),
workspace_size_(workspace_size),
executor_(executor),
stream_(stream) {}
virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
private:
aclnn_func_t aclnn_func_;
void * workspace_addr_;
uint64_t workspace_size_;
aclOpExecutor * executor_;
aclrtStream stream_;
};
/**
* @brief Task class that releases ACL resources after usage.
*/
class release_resource_task : public cann_task {
public:
release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
virtual void run_task() override { resource_.clear(); }
private:
std::vector<any_acl_resource> resource_;
};
/**
* @brief Task class for performing asynchronous memory copy operations.
*/
class async_memcpy_task : public cann_task {
public:
async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
dst_(dst),
src_(src),
size_(size),
kind_(kind),
stream_(stream) {}
virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
private:
void * dst_;
const void * src_;
size_t size_;
aclrtMemcpyKind kind_;
aclrtStream stream_;
};
/**
* @brief Task class for performing asynchronous memory set operations.
*/
class async_memset_task : public cann_task {
public:
async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
buffer_(buffer),
size_(size),
value_(value),
stream_(stream) {}
virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
private:
void * buffer_;
size_t size_;
int32_t value_;
aclrtStream stream_;
};
/**
* @brief Launches an asynchronous task using the memory allocator.
*
@@ -968,95 +893,20 @@ class async_memset_task : public cann_task {
* same stream are executed in queue order.
*/
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
do { \
uint64_t workspaceSize = 0; \
aclOpExecutor * executor; \
void * workspaceAddr = nullptr; \
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
if (workspaceSize > 0) { \
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
workspaceAddr = workspace_allocator.get(); \
} \
if (CTX.async_mode) { \
auto task = \
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
CTX.task_queue.submit_task(std::move(task)); \
} else { \
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
} \
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
do { \
uint64_t workspaceSize = 0; \
aclOpExecutor * executor; \
void * workspaceAddr = nullptr; \
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
if (workspaceSize > 0) { \
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
workspaceAddr = workspace_allocator.get(); \
} \
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
} while (0)
/**
* @brief Registers and releases multiple ACL resources, optionally deferring the release
* using a task.
*
* @tparam Args Types of the ACL resources.
* @param ctx Backend context which manages task submission and async mode.
* @param args Pointers to ACL resources to be released.
*/
template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
std::vector<any_acl_resource> resources;
register_acl_resources(resources, std::forward<Args>(args)...);
if (ctx.async_mode) {
auto task = std::make_unique<release_resource_task>(std::move(resources));
ctx.task_queue.submit_task(std::move(task));
}
}
/**
* @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
*
* @param ctx Backend context containing stream and async configuration.
* @param dst Destination memory address.
* @param src Source memory address.
* @param len Size of memory to copy (in bytes).
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
*/
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
void * dst,
const void * src,
size_t len,
aclrtMemcpyKind kind) {
if (ctx.async_mode) {
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
ctx.task_queue.submit_task(std::move(task));
} else {
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
}
}
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
void * dst,
const void * src,
size_t len,
aclrtMemcpyKind kind) {
if (ctx->async_mode) {
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
ctx->task_queue.submit_task(std::move(task));
} else {
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
}
}
/**
* @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
*
* @param ctx Backend context containing stream and async configuration.
* @param buffer Memory buffer to be set.
* @param size Size of the memory buffer (in bytes).
* @param value Value to set in the buffer.
*/
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
if (ctx.async_mode) {
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
ctx.task_queue.submit_task(std::move(task));
} else {
ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
}
}
/**
* @brief Performs sparse expert-based matrix multiplication using the CANN backend.
*
@@ -1129,15 +979,11 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
aclTensor * acl_src0;
aclTensor * acl_src1;
aclTensor * acl_dst;
acl_tensor_ptr acl_src0, acl_src1, acl_dst;
// Need bcast
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
binary_op(ctx, acl_src0, acl_src1, acl_dst);
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
bcast_shape(src0, src1, dst, acl_src0, acl_src1, acl_dst);
binary_op(ctx, acl_src0.get(), acl_src1.get(), acl_dst.get());
}
/**
@@ -1147,7 +993,7 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
* and stores the result in the destination tensor.
*
* @tparam unary_op A callable with the signature:
* void(ggml_backend_cann_context&, aclTensor*, aclTensor*)
* void(ggml_backend_cann_context&, aclTensor *, aclTensor *)
* where the first aclTensor is the source and the second is the destination.
* @param ctx The CANN backend context for managing resources and execution.
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
@@ -1156,11 +1002,10 @@ template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
aclTensor * acl_src = ggml_cann_create_tensor(src);
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
unary_op(ctx, acl_src, acl_dst);
ggml_cann_release_resources(ctx, acl_src, acl_dst);
unary_op(ctx, acl_src.get(), acl_dst.get());
}
/**
+19 -150
View File
@@ -23,26 +23,26 @@
#ifndef CANN_COMMON_H
#define CANN_COMMON_H
#include <acl/acl.h>
#include <cstdio>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>
#include <unistd.h>
#include <functional>
#include <optional>
#include <list>
#include "../ggml-impl.h"
#include "../include/ggml-cann.h"
#include "../include/ggml.h"
#include "../ggml-impl.h"
#include <acl/acl.h>
#include <unistd.h>
#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <thread>
#include <vector>
#define MATRIX_ROW_PADDING 512
#define GGML_CANN_MAX_STREAMS 8
@@ -214,130 +214,6 @@ struct ggml_cann_pool_alloc {
ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
};
/**
* @brief Function pointer type for ACLNN operator calls.
*/
using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
/**
* @brief Base class for all CANN tasks to be submitted to the task queue.
*
* Users should override the run_task() method with actual task logic.
*/
class cann_task {
public:
virtual void run_task() {}
};
/**
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
*/
class cann_task_queue {
public:
/**
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
*
* @param capacity Queue capacity. Must be a power of 2.
* @param device Target device ID (used for context setting).
*/
explicit cann_task_queue(size_t capacity, int32_t device) :
buffer_(capacity),
capacity_(capacity),
head_(0),
tail_(0),
running_(false),
device_(device) {
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
mask_ = capacity_ - 1;
}
/**
* @brief Attempts to enqueue a task into the queue.
*
* @param item Unique pointer to the task.
* @return true if the task was successfully enqueued, false if the queue was full.
*/
bool enqueue(std::unique_ptr<cann_task> && item) {
size_t next_tail = (tail_ + 1) & mask_;
if (next_tail == head_) {
return false;
}
buffer_[tail_] = std::move(item);
std::atomic_thread_fence(std::memory_order_release);
tail_ = next_tail;
return true;
}
/**
* @brief Submits a task to the queue, and starts the worker thread if not already running.
*
* @param task Task to be submitted.
*/
void submit_task(std::unique_ptr<cann_task> && task) {
while (!enqueue(std::move(task))) {
std::this_thread::yield();
continue;
}
if (!running_) {
running_ = true;
thread_ = std::thread(&cann_task_queue::execute, this);
}
}
/**
* @brief Waits until the queue is completely empty and no tasks are being processed.
*/
void wait() {
while (running_ && head_ != tail_) {
std::this_thread::yield();
continue;
}
}
/**
* @brief Stops the task queue and joins the worker thread.
*/
void stop() {
running_ = false;
if (thread_.joinable()) {
thread_.join();
}
}
private:
/**
* @brief Worker thread function that continuously dequeues and executes tasks.
*/
void execute() {
ggml_cann_set_device(device_);
while (running_) {
if (head_ == tail_) {
std::this_thread::yield();
continue;
}
std::atomic_thread_fence(std::memory_order_acquire);
buffer_[head_]->run_task();
buffer_[head_].reset();
head_ = (head_ + 1) & mask_;
}
}
std::vector<std::unique_ptr<cann_task>> buffer_;
const size_t capacity_;
size_t mask_;
size_t head_;
size_t tail_;
bool running_;
std::thread thread_;
int32_t device_;
};
#ifdef USE_ACL_GRAPH
struct ggml_graph_node_properties {
// dst tensor
@@ -474,7 +350,6 @@ struct ggml_backend_cann_context {
ggml_cann_graph_lru_cache graph_lru_cache;
bool acl_graph_mode = true;
#endif
cann_task_queue task_queue;
bool async_mode;
// Rope Cache
ggml_cann_rope_cache rope_cache;
@@ -488,15 +363,10 @@ struct ggml_backend_cann_context {
* @brief Constructor for initializing the context with a given device.
* @param device Device ID.
*/
explicit ggml_backend_cann_context(int device) :
device(device),
name("CANN" + std::to_string(device)),
task_queue(1024, device) {
explicit ggml_backend_cann_context(int device) : device(device), name("CANN" + std::to_string(device)) {
ggml_cann_set_device(device);
description = aclrtGetSocName();
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
#ifdef USE_ACL_GRAPH
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
@@ -509,7 +379,6 @@ struct ggml_backend_cann_context {
*/
~ggml_backend_cann_context() {
ggml_cann_set_device(device);
task_queue.stop();
if (copy_event != nullptr) {
ACL_CHECK(aclrtDestroyEvent(copy_event));
}
+25 -22
View File
@@ -22,24 +22,24 @@
#include "ggml-cann.h"
#include <acl/acl.h>
#include <stdarg.h>
#include <aclnnop/aclnn_trans_matmul_weight.h>
#include "ggml-backend-impl.h"
#include "ggml-cann/aclnn_ops.h"
#include "ggml-cann/common.h"
#include "ggml-impl.h"
#include "ggml.h"
#include <acl/acl.h>
#include <aclnnop/aclnn_trans_matmul_weight.h>
#include <stdarg.h>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <mutex>
#include <queue>
#include <chrono>
#include <unordered_set>
#include <optional>
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-cann/aclnn_ops.h"
#include "ggml-cann/common.h"
#include "ggml.h"
#include <queue>
#include <unordered_set>
#define GGML_COMMON_DECL_C
@@ -1177,19 +1177,18 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
* across calls. This reduces overhead from repeated memory allocation and deallocation.
*/
static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
aclTensor * weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0;
acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0;
aclOpExecutor * executor;
// TransMatmulWeight
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
// Avoid frequent malloc/free of the workspace.
g_nz_workspaces[device].realloc(workspaceSize);
void * g_nz_workspace = g_nz_workspaces[device].get();
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
ACL_CHECK(aclDestroyTensor(weightTransposed));
}
// TODO: need handle tensor which has paddings.
@@ -1641,7 +1640,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
},
/* .device = */
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
/* .context = */ nullptr,
};
@@ -1949,7 +1948,8 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
GGML_ASSERT(!ggml_is_quantized(tensor->type));
ggml_cann_async_memcpy(cann_ctx, (char *) tensor->data + offset, data, size, ACL_MEMCPY_HOST_TO_DEVICE);
ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
cann_ctx->stream()));
}
/**
@@ -1974,7 +1974,8 @@ static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
GGML_ASSERT(!ggml_is_quantized(tensor->type));
ggml_cann_async_memcpy(cann_ctx, data, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST);
ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
cann_ctx->stream()));
}
/**
@@ -2035,7 +2036,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
// wait for task_queue empty to keep task order.
cann_ctx_src->task_queue.wait();
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
cann_ctx_src->stream()));
// record event on src stream after the copy
@@ -2068,7 +2068,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
*/
static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
cann_ctx->task_queue.wait();
ggml_cann_set_device(cann_ctx->device);
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}
@@ -2485,6 +2484,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
if (op->src[0]->ne[0] > 896) {
return false;
}
#ifdef ASCEND_310P
if (!ggml_is_contiguous(op->src[0])) {
return false;
@@ -2521,10 +2523,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
// value of paddingW should be at most half of kernelW
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
}
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_DUP:
case GGML_OP_SUM:
case GGML_OP_IM2COL:
case GGML_OP_CONCAT:
case GGML_OP_REPEAT:
+29 -36
View File
@@ -145,26 +145,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
include(CheckCXXSourceRuns)
function(check_arm_feature tag code)
macro(check_arm_feature tag feature code)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
if (GGML_MACHINE_SUPPORTS_${tag})
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}" PARENT_SCOPE)
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}")
else()
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
if (GGML_MACHINE_SUPPORTS_no${tag})
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}" PARENT_SCOPE)
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}")
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_${feature})
endif()
endif()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
endfunction()
endmacro()
check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
check_arm_feature(dotprod DOTPROD "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
check_arm_feature(i8mm MATMUL_INT8 "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
check_arm_feature(sve SVE "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
check_arm_feature(sme SME "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
else()
@@ -216,35 +217,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
endif()
endif()
# show enabled features
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
set(FEAT_INPUT_FILE "NUL")
else()
set(FEAT_INPUT_FILE "/dev/null")
endif()
message(STATUS "Checking for ARM features using flags:")
foreach(flag IN LISTS ARCH_FLAGS)
message(STATUS " ${flag}")
endforeach()
execute_process(
COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E -
INPUT_FILE ${FEAT_INPUT_FILE}
OUTPUT_VARIABLE ARM_FEATURE
RESULT_VARIABLE ARM_FEATURE_RESULT
)
if (ARM_FEATURE_RESULT)
message(WARNING "Failed to get ARM features")
else()
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
if (NOT ${feature_pos} EQUAL -1)
# Special handling for MATMUL_INT8 when machine doesn't support i8mm
if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
else()
message(STATUS "ARM feature ${feature} enabled")
endif()
endif()
endforeach()
endif()
include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}")
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
set(ARM_FEATURE "HAVE_${feature}")
check_cxx_source_compiles(
"
#if !defined(__ARM_FEATURE_${feature})
# error \"Feature ${feature} is not defined\"
#endif
int main() { return 0; }
"
${ARM_FEATURE}
)
endforeach()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
message(STATUS "x86 detected")
+6 -6
View File
@@ -646,7 +646,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
int64_t xstart = 0;
int anr = nr - nr%16; // Used to align nr with boundary of 16
#ifdef __AVX512F__
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
int anc = nc - nc%16; // Used to align nc with boundary of 16
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
@@ -1041,7 +1041,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
xstart = anc/8;
y = 0;
}
#endif // __AVX512F__
#endif // __AVX512BW__ && __AVX512DQ__
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
@@ -1989,7 +1989,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
int64_t xstart = 0;
int anr = nr - nr % 16;; // Used to align nr with boundary of 16
#ifdef __AVX512F__
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
int anc = nc - nc % 16; // Used to align nc with boundary of 16
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
@@ -2727,7 +2727,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
xstart = anc/8;
y = 0;
}
#endif //AVX512F
#endif // __AVX512BW__ && __AVX512DQ__
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
for (; y < anr / 4; y += 4) {
@@ -3467,7 +3467,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
__m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);
scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);
#ifdef __AVX512F__
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
int anc = nc - nc % 16; // Used to align nc with boundary of 16
@@ -4947,7 +4947,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
y = 0;
}
#endif //AVX512F
#endif // __AVX512BW__ && __AVX512DQ__
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
for (; y < anr / 4; y += 4) {
+38
View File
@@ -318,6 +318,44 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
char base[256];
char name[256];
snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
char base[256];
char name[256];
snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
+2
View File
@@ -113,6 +113,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
+2 -1
View File
@@ -870,6 +870,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_SUM_ROWS:
case GGML_OP_CUMSUM:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
@@ -988,7 +989,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false;
}
case GGML_TYPE_I32:
return op->type == GGML_TYPE_F32;
return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32;
default:
return false;
};
+39
View File
@@ -612,6 +612,45 @@ typedef struct {
uint64_t nb3;
} ggml_metal_kargs_sum_rows;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t net0;
int64_t net1;
int64_t net2;
int64_t net3;
uint64_t nbt0;
uint64_t nbt1;
uint64_t nbt2;
uint64_t nbt3;
bool outb;
} ggml_metal_kargs_cumsum_blk;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t net0;
int64_t net1;
int64_t net2;
int64_t net3;
uint64_t nbt0;
uint64_t nbt1;
uint64_t nbt2;
uint64_t nbt3;
} ggml_metal_kargs_cumsum_add;
typedef struct {
int32_t ne00;
int32_t ne01;
+184 -71
View File
@@ -311,6 +311,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_sum_rows(ctx, idx);
} break;
case GGML_OP_CUMSUM:
{
n_fuse = ggml_metal_op_cumsum(ctx, idx);
} break;
case GGML_OP_SOFT_MAX:
{
n_fuse = ggml_metal_op_soft_max(ctx, idx);
@@ -539,7 +543,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
@@ -585,7 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
@@ -694,7 +698,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float scale;
float bias;
@@ -733,7 +737,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float min;
float max;
@@ -772,7 +776,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
int64_t n = ggml_nelements(op);
@@ -802,7 +806,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
if (op->src[1]) {
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
@@ -834,18 +838,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
//[encoder setComputePipelineState:pipeline];
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
//if (src1) {
// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
//} else {
// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
//}
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//[encoder setBytes:&args length:sizeof(args) atIndex:3];
//[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -907,7 +899,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@@ -941,14 +933,6 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
//[encoder setComputePipelineState:pipeline];
//[encoder setBytes:&args length:sizeof(args) atIndex:0];
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -961,6 +945,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
int nth = 1;
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
nth *= 2;
}
GGML_ASSERT(ne00 <= nth*nth);
const int64_t net0 = (ne00 + nth - 1) / nth;
const int64_t net1 = ne01;
const int64_t net2 = ne02;
const int64_t net3 = ne03;
const uint64_t nbt0 = sizeof(float);
const uint64_t nbt1 = net0*nbt0;
const uint64_t nbt2 = net1*nbt1;
const uint64_t nbt3 = net2*nbt2;
const size_t smem = GGML_PAD(32*sizeof(float), 16);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += ggml_nbytes(op);
{
ggml_metal_kargs_cumsum_blk args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.net0 =*/ net0,
/*.net1 =*/ net1,
/*.net2 =*/ net2,
/*.net3 =*/ net3,
/*.nbt0 =*/ nbt0,
/*.nbt1 =*/ nbt1,
/*.nbt2 =*/ nbt2,
/*.nbt3 =*/ nbt3,
/*.outb =*/ ne00 > nth,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
}
if (ne00 > nth) {
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_kargs_cumsum_blk args = {
/*.ne00 =*/ net0,
/*.ne01 =*/ net1,
/*.ne02 =*/ net2,
/*.ne03 =*/ net3,
/*.nb00 =*/ nbt0,
/*.nb01 =*/ nbt1,
/*.nb02 =*/ nbt2,
/*.nb03 =*/ nbt3,
/*.net0 =*/ net0,
/*.net1 =*/ net1,
/*.net2 =*/ net2,
/*.net3 =*/ net3,
/*.nbt0 =*/ nbt0,
/*.nbt1 =*/ nbt1,
/*.nbt2 =*/ nbt2,
/*.nbt3 =*/ nbt3,
/*.outb =*/ false,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
}
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
ggml_metal_kargs_cumsum_add args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.net0 =*/ net0,
/*.net1 =*/ net1,
/*.net2 =*/ net2,
/*.net3 =*/ net3,
/*.nbt0 =*/ nbt0,
/*.nbt1 =*/ nbt1,
/*.nbt2 =*/ nbt2,
/*.nbt3 =*/ nbt3,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_add);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
}
}
return 1;
}
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -972,7 +1099,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
@@ -1017,7 +1144,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
@@ -1081,7 +1208,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float scale;
float max_bias;
@@ -1169,7 +1296,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_ssm_conv args = {
/*.ne00 =*/ ne00,
@@ -1224,7 +1351,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const ggml_tensor * src3 = op->src[3];
const ggml_tensor * src4 = op->src[4];
@@ -1310,7 +1437,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
const int64_t T = op->src[0]->ne[2];
@@ -1351,7 +1478,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
@@ -1424,7 +1551,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t * opts = op->op_params;
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
@@ -1488,7 +1615,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ne00 == ne10);
@@ -1729,7 +1856,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
// src2 = ids
GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
@@ -2191,8 +2318,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (has_mask) {
@@ -2222,8 +2347,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
}
if (need_sync) {
@@ -2363,8 +2486,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (need_sync) {
@@ -2695,7 +2816,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
@@ -2743,7 +2864,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t ngrp = ((const int32_t *) op->op_params)[0];
@@ -2798,7 +2919,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
@@ -2934,7 +3055,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
// make sure we have one or more position id(ne10) per token(ne02)
GGML_ASSERT(ne10 % ne02 == 0);
@@ -3028,7 +3149,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@@ -3178,7 +3299,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
@@ -3223,7 +3344,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
@@ -3277,7 +3398,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float sf0 = (float)ne0/op->src[0]->ne[0];
const float sf1 = (float)ne1/op->src[0]->ne[1];
@@ -3330,7 +3451,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_pad args = {
/*.ne00 =*/ ne00,
@@ -3374,7 +3495,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_pad_reflect_1d args = {
/*.ne00 =*/ ne00,
@@ -3418,7 +3539,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float start;
float step;
@@ -3436,12 +3557,6 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
//[encoder setComputePipelineState:pipeline];
//[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
//[encoder setBytes:&args length:sizeof(args) atIndex:1];
//[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
@@ -3460,7 +3575,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int dim = op->op_params[0];
const int max_period = op->op_params[1];
@@ -3494,7 +3609,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_argmax args = {
/*.ne00 = */ ne00,
@@ -3535,7 +3650,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
@@ -3545,7 +3660,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
nth *= 2;
}
const int nptg = (ne00 + nth - 1)/nth;
const int npr = (ne00 + nth - 1)/nth;
// Metal kernels require the buffer size to be multiple of 16 bytes
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
@@ -3557,7 +3672,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += ggml_nbytes(op);
if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
std::swap(bid_dst, bid_tmp);
}
@@ -3579,7 +3694,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
@@ -3611,8 +3726,6 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
std::swap(bid_dst, bid_tmp);
@@ -3632,7 +3745,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float slope;
memcpy(&slope, op->op_params, sizeof(float));
@@ -3668,7 +3781,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
@@ -3704,7 +3817,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
+1
View File
@@ -52,6 +52,7 @@ int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
+1
View File
@@ -197,6 +197,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
} break;
case GGML_OP_CUMSUM:
case GGML_OP_ARGSORT:
{
res *= 2;
+214 -54
View File
@@ -1832,6 +1832,117 @@ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_cumsum_blk(
constant ggml_metal_kargs_cumsum_blk & args,
device const char * src0,
device char * tmp,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 +
args.nb01*i01 +
args.nb02*i02 +
args.nb03*i03);
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
float v = 0.0f;
if (i00 + tpitg.x < args.ne00) {
v = src0_row[i00 + tpitg.x];
}
float s = simd_prefix_inclusive_sum(v);
if (tiisg == N_SIMDWIDTH - 1) {
shmem_f32[sgitg] = s;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
s += shmem_f32[sgitg];
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] = s;
}
if (args.outb && tpitg.x == ntg.x - 1) {
device float * tmp_row = (device float *) tmp +
args.net0*i01 +
args.net0*args.net1*i02 +
args.net0*args.net1*args.net2*i03;
tmp_row[ib] = s;
}
}
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
template<typename T>
kernel void kernel_cumsum_add(
constant ggml_metal_kargs_cumsum_add & args,
device const char * tmp,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
if (ib == 0) {
return;
}
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * tmp_row = (device const float *) (tmp +
args.nbt1*i01 +
args.nbt2*i02 +
args.nbt3*i03);
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
}
}
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
template<typename T>
kernel void kernel_soft_max(
constant ggml_metal_kargs_soft_max & args,
@@ -4543,7 +4654,7 @@ typedef void (argsort_t)(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
@@ -4553,7 +4664,7 @@ kernel void kernel_argsort_f32_i32(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
@@ -4565,10 +4676,10 @@ kernel void kernel_argsort_f32_i32(
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
// initialize indices
smem_i32[col] = i00 + col;
shmem_i32[col] = i00 + col;
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -4577,20 +4688,20 @@ kernel void kernel_argsort_f32_i32(
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (smem_i32[col] >= args.ne00 ||
(smem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
if (shmem_i32[col] >= args.ne00 ||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
) {
SWAP(smem_i32[col], smem_i32[ixj]);
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
} else {
if (smem_i32[ixj] >= args.ne00 ||
(smem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
if (shmem_i32[ixj] >= args.ne00 ||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
) {
SWAP(smem_i32[col], smem_i32[ixj]);
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
}
}
@@ -4603,7 +4714,7 @@ kernel void kernel_argsort_f32_i32(
if (i00 + col < args.ne00) {
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
dst[col] = smem_i32[col];
dst[col] = shmem_i32[col];
}
}
@@ -4628,12 +4739,13 @@ kernel void kernel_argsort_merge_f32_i32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int im = tgpig[0] / args.ne01;
int i01 = tgpig[0] % args.ne01;
int i02 = tgpig[1];
int i03 = tgpig[2];
const int start = im * (2*args.len);
const int im = tgpig[0] / args.ne01;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
@@ -4657,54 +4769,101 @@ kernel void kernel_argsort_merge_f32_i32(
+ args.nb02*i02
+ args.nb03*i03);
for (int k = tpitg.x; k < (int) total; k += ntg.x) {
// find partition (i,j) such that i+j = k
int low = k > len1 ? k - len1 : 0;
int high = MIN(k, len0);
if (total == 0) {
return;
}
while (low < high) {
const int mid = (low + high) >> 1;
const int chunk = (total + ntg.x - 1) / ntg.x;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k - mid - 1];
const int k0 = tpitg.x * chunk;
const int k1 = min(k0 + chunk, total);
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
if (k0 >= total) {
return;
}
if (order == GGML_SORT_ORDER_ASC) {
if (val0 <= val1) {
low = mid + 1;
} else {
high = mid;
}
} else {
if (val0 >= val1) {
low = mid + 1;
} else {
high = mid;
}
}
int low = k0 > len1 ? k0 - len1 : 0;
int high = MIN(k0, len0);
// binary-search partition (i, j) such that i + j = k
while (low < high) {
const int mid = (low + high) >> 1;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k0 - mid - 1];
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
const int i = low;
const int j = k - i;
if (take_left) {
low = mid + 1;
} else {
high = mid;
}
}
int i = low;
int j = k0 - i;
// keep the merge fronts into registers
int32_t idx0 = 0;
float val0 = 0.0f;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
int32_t idx1 = 0;
float val1 = 0.0f;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
for (int k = k0; k < k1; ++k) {
int32_t out_idx;
if (i >= len0) {
out_idx = tmp1[j];
while (k < k1) {
dst[k++] = tmp1[j++];
}
break;
} else if (j >= len1) {
out_idx = tmp0[i];
while (k < k1) {
dst[k++] = tmp0[i++];
}
break;
} else {
const int32_t idx0 = tmp0[i];
const int32_t idx1 = tmp1[j];
bool take_left;
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
out_idx = (order == GGML_SORT_ORDER_ASC)
? (val0 <= val1 ? idx0 : idx1)
: (val0 >= val1 ? idx0 : idx1);
if (take_left) {
out_idx = idx0;
++i;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
} else {
out_idx = idx1;
++j;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
}
}
dst[k] = out_idx;
@@ -6401,6 +6560,7 @@ template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
+1
View File
@@ -119,6 +119,7 @@ set(GGML_OPENCL_KERNELS
pad
repeat
mul_mat_f16_f32
mul_mm_f16_f32_kq_kqv
conv2d
conv2d_f16_f32
flash_attn_f32_f16
+171 -1
View File
@@ -407,6 +407,8 @@ struct ggml_backend_opencl_context {
cl_program program_mul_mv_f32_f32;
cl_program program_mul;
cl_program program_mul_mat_f16_f32_tiled;
cl_program program_mul_mm_f16_f32_kqv;
cl_program program_mul_mm_f16_f32_kq;
cl_program program_div;
cl_program program_sub;
cl_program program_norm;
@@ -481,6 +483,8 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mat_f16_f32;
cl_kernel kernel_mul_mat_f16_f32_l4;
cl_kernel kernel_mul_mat_f16_f32_tiled;
cl_kernel kernel_mul_mm_f16_f32_kqv;
cl_kernel kernel_mul_mm_f16_f32_kq;
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
@@ -1235,6 +1239,25 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mm_f16_f32_kq_kqv
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mm_f16_f32_kq_kqv.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mm_f16_f32_kq_kqv.cl");
#endif
backend_ctx->program_mul_mm_f16_f32_kqv =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts+" -DKQV ");
backend_ctx->program_mul_mm_f16_f32_kq =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, "mul_mm_f16_f32_kqv", &err), err));
CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, "mul_mm_f16_f32_kq", &err), err));
GGML_LOG_CONT(".");
}
// mul
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -5682,7 +5705,7 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs, NULL));
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
@@ -6665,6 +6688,146 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}
static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const int ne10 = src1->ne[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const cl_ulong nb10 = src1->nb[0];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
GGML_ASSERT(ne00 == ne10);
cl_kernel kernel;
cl_context context = backend_ctx->context;
cl_int status;
cl_image_format img_fmt_1d;
cl_image_desc img_desc_1d;
cl_buffer_region region;
cl_mem A_image1d;
cl_mem A_sub_buffer;
cl_mem B_sub_buffer;
cl_mem D_image1d;
cl_mem D_sub_buffer;
int M = ne01;
int N = ne1;
int K = ne00;
if (nb01 > nb02) {
// KQ
kernel = backend_ctx->kernel_mul_mm_f16_f32_kq;
} else {
// KQV
kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv;
}
// create sub-buffer for A
// <--------------------------------------------> //
extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra;
region.origin = (extra0->offset);
if (nb01 > nb02) {
// KQ
region.size = nb01 * ne01;
} else {
// KQV
region.size = nb02 * ne02;
}
A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// <--------------------------------------------> //
// create sub-buffer for B
// <--------------------------------------------> //
region.origin = (extra1->offset);
region.size = nb10 * ne10 * ne11 * ne12;
B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// <--------------------------------------------> //
img_fmt_1d = {CL_RGBA, CL_FLOAT};
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
if (nb01 > nb02) {
img_desc_1d.image_width = (nb01 * ne01 / 4)/4;
}
else {
img_desc_1d.image_width = (nb02 * ne02 / 4)/4;
}
img_desc_1d.buffer = A_sub_buffer;
A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
CL_CHECK(status);
// create sub-buffer for output C
// <--------------------------------------------> //
region.origin = (extrad->offset);
region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes
D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// <--------------------------------------------> //
// create image for C output
// <--------------------------------------------> //
img_fmt_1d = {CL_R, CL_FLOAT};
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4;
img_desc_1d.buffer = D_sub_buffer;
D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
CL_CHECK(status);
// <--------------------------------------------> //
int offset_src0 = 0;
int offset_src1 = 0;
// set kernel args
// <--------------------------------------------> //
cl_uint k_arg = 0;
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src0));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_sub_buffer));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src1));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &extrad->offset));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01));
size_t global_work_size[3] = {64, static_cast<size_t>(((M+63)/64)), static_cast<size_t>(((N+31)/32)*ne12)};
size_t local_work_size[3] = {64, 1, 2};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
// deallocate sub buffers and images
// <--------------------------------------------> //
CL_CHECK(clReleaseMemObject(A_image1d));
CL_CHECK(clReleaseMemObject(D_image1d));
CL_CHECK(clReleaseMemObject(A_sub_buffer));
CL_CHECK(clReleaseMemObject(B_sub_buffer));
CL_CHECK(clReleaseMemObject(D_sub_buffer));
}
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -6731,6 +6894,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
cl_context context = backend_ctx->context;
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
return;
}
}
if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) {
// init CL objects
@@ -0,0 +1,273 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#define LM_FIRST_256B 0
#define LM_SECOND_256B 64
#define LM_THIRD_256B 128
#define LM_FOURTH_256B 192
inline float16 mm_load_a(
image1d_buffer_t matrix_A,
uint subMatrixAStartInElements,
int nb01,
int line_stride_matrix_A_in_bytes
) {
__private float8 regA;
size_t sub_block_id_m = get_local_id(0);
#ifdef KQV
uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);
#else // KQ
uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);
#endif
regA.s0123 = read_imagef(matrix_A, a_texCoord/4);
regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4);
return convert_float16(as_half16(regA));
}
inline float4 alu_32(
float16 regA,
__local float4* matrix_B_vec
) {
__private float4 rC = 0;
int i = get_sub_group_id() * 64;
rC += regA.s0 * matrix_B_vec[i];
rC += regA.s1 * matrix_B_vec[i + 16];
rC += regA.s4 * matrix_B_vec[i + 1];
rC += regA.s5 * matrix_B_vec[i + 17];
rC += regA.s8 * matrix_B_vec[i + 2];
rC += regA.s9 * matrix_B_vec[i + 18];
rC += regA.sc * matrix_B_vec[i + 3];
rC += regA.sd * matrix_B_vec[i + 19];
i += 32;
rC += regA.s2 * matrix_B_vec[i];
rC += regA.s3 * matrix_B_vec[i + 16];
rC += regA.s6 * matrix_B_vec[i + 1];
rC += regA.s7 * matrix_B_vec[i + 17];
rC += regA.sa * matrix_B_vec[i + 2];
rC += regA.sb * matrix_B_vec[i + 18];
rC += regA.se * matrix_B_vec[i + 3];
rC += regA.sf * matrix_B_vec[i + 19];
return rC;
}
inline float16 alu_16(
float16 regA,
__local float* matrix_B_local
) {
float16 out;
__local float4* matrix_B_vec = (__local float4*)matrix_B_local;
out.s0123 = alu_32(regA, matrix_B_vec);
out.s4567 = alu_32(regA, matrix_B_vec + 4);
out.s89ab = alu_32(regA, matrix_B_vec + 8);
out.scdef = alu_32(regA, matrix_B_vec + 12);
return out;
}
inline void mm_mad(
__local float* matrix_B_local,
float16 regA,
float8 regB,
uint b_localOffsetInWords,
float16* regC0_ptr,
float16* regC1_ptr
) {
int offset = b_localOffsetInWords + get_sub_group_id() * 256;
matrix_B_local[offset + LM_FIRST_256B] = regB.s0;
matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
matrix_B_local[offset + LM_THIRD_256B] = regB.s2;
matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
float16 add0 = alu_16(regA, matrix_B_local);
*regC0_ptr += add0;
matrix_B_local[offset + LM_FIRST_256B] = regB.s4;
matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
matrix_B_local[offset + LM_THIRD_256B] = regB.s6;
matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
float16 add1 = alu_16(regA, matrix_B_local);
*regC1_ptr += add1;
}
inline void mm_store_c_N(
__write_only image1d_buffer_t matrix_C,
float16 regC0,
float16 regC1,
uint subMatrixCStartInElements,
int line_stride_matrix_C_in_bytes,
int mask
) {
size_t sub_block_id_m = get_local_id(0);
uint strideInWords = line_stride_matrix_C_in_bytes/4;
uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m);
uint c_coordInWords_1 = c_coordInWords_0 + 1 * strideInWords;
uint c_coordInWords_2 = c_coordInWords_0 + 2 * strideInWords;
uint c_coordInWords_3 = c_coordInWords_0 + 3 * strideInWords;
uint c_coordInWords_4 = c_coordInWords_0 + 4 * strideInWords;
uint c_coordInWords_5 = c_coordInWords_0 + 5 * strideInWords;
uint c_coordInWords_6 = c_coordInWords_0 + 6 * strideInWords;
uint c_coordInWords_7 = c_coordInWords_0 + 7 * strideInWords;
uint c_coordInWords_8 = c_coordInWords_0 + 8 * strideInWords;
uint c_coordInWords_9 = c_coordInWords_0 + 9 * strideInWords;
uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords;
uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords;
uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords;
uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords;
uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords;
uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords;
uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords;
uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords;
uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords;
uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords;
uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords;
uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords;
uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords;
uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords;
uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords;
uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords;
uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords;
uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords;
uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords;
uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords;
uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords;
uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords;
if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC0.s0); }
if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC0.s1); }
if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC0.s2); }
if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC0.s3); }
if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC0.s4); }
if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC0.s5); }
if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC0.s6); }
if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC0.s7); }
if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC0.s8); }
if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC0.s9); }
if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); }
if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); }
if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); }
if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); }
if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); }
if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); }
if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); }
if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); }
if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); }
if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); }
if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); }
if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); }
if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); }
if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); }
if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); }
if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); }
if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); }
if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); }
if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); }
if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); }
if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); }
if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); }
}
#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32
#ifdef KQV
__kernel void mul_mm_f16_f32_kqv(
#else
__kernel void mul_mm_f16_f32_kq(
#endif
__read_only image1d_buffer_t matrix_A,
int offset0,
__global float* matrix_B,
int offset1,
__write_only image1d_buffer_t matrix_C,
int offsetd,
int M, int K, int N,
int D_A,
int D_B,
int nb01
) {
uint block_id_m = get_global_id(1);
uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);
uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);
__private float16 regA;
__private float8 regB;
__private float16 regC0;
__private float16 regC1;
const uint col = block_id_m * TILESIZE_M;
const uint row = block_id_n * TILESIZE_N;
const uint depth_A = block_id_d / (D_B/D_A);
const uint depth_B = block_id_d;
#ifdef KQV
int line_stride_matrix_A_in_bytes = nb01 * M;
int line_stride_matrix_B_in_bytes = K * N * 4;
#else
int line_stride_matrix_A_in_bytes = K * D_A * 2;
int line_stride_matrix_B_in_bytes = K * D_B * 4;
#endif
int line_stride_matrix_C_in_bytes = M * 4;
const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;
size_t sub_block_id_m = get_local_id(0);
uint b_localOffsetInWords = (sub_block_id_m/16)*16
+ ((((sub_block_id_m)>>0)&1)<<2)
+ ((((sub_block_id_m)>>1)&1)<<3)
+ ((((sub_block_id_m)>>2)&1)<<0)
+ ((((sub_block_id_m)>>3)&1)<<1);
uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};
uint b_globalOffsetInWords00, b_globalOffsetInWords16;
#ifdef KQV
b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
#else
b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
#endif
__local float matrix_B_local[1024];
for (uint step=0; step < K; step+=TILESIZE_K) {
size_t sub_block_id_m = get_local_id(0);
regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);
uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;
regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
regB.s4567 = vload4(b_coordInWords16/4, matrix_B);
mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);
subMatrixAStartInElements += TILESIZE_K;
subMatrixBStartInElements += TILESIZE_K;
}
uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));
}
+25 -10
View File
@@ -134,6 +134,15 @@ kernel void kernel_rms_norm_mul(
src1 = src1 + offset1;
dst = dst + offsetd;
// The size of sum is sizeof(float)*subgroup_size.
// Each subgroup writes its partial sum to this array.
// So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size.
// This is generally true -
// for subgroup size 64, workgroup size should be less than 4096 (the max is usually 1024).
if (get_sub_group_id() == 0) {
sum[get_sub_group_local_id()] = 0.0f;
}
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
@@ -148,24 +157,30 @@ kernel void kernel_rms_norm_mul(
sumf += dot(x[i00], x[i00]);
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if (get_sub_group_local_id() == 0) {
sum[get_sub_group_id()] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
if (get_local_id(0) < i) {
sum[get_local_id(0)] += sum[get_local_id(0) + i];
}
}
if (get_local_id(0) == 0) {
sum[0] /= ne00;
}
//for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
// if (get_local_id(0) < i) {
// sum[get_local_id(0)] += sum[get_local_id(0) + i];
// }
//}
//if (get_local_id(0) == 0) {
// sum[0] /= ne00;
//}
barrier(CLK_LOCAL_MEM_FENCE);
//barrier(CLK_LOCAL_MEM_FENCE);
float mean = sum[0];
sumf = sum[get_sub_group_local_id()];
sumf = sub_group_reduce_add(sumf);
float mean = sumf / ne00;
float scale = 1.0f/sqrt(mean + eps);
global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
+111 -249
View File
@@ -170,73 +170,31 @@ static __dpct_inline__ T op_trunc(T x) {
return sycl::trunc(x);
}
template<typename T>
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_sgn(x[i]);
}
}
template<typename T, typename F>
static void unary_op_generic_kernel(
const T * x,
T * dst,
const int k,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3,
const size_t nbd0, const size_t nbd1, const size_t nbd2, const size_t nbd3,
const sycl::nd_item<1> & item_ct1,
F func) {
template<typename T>
static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
(void) ne3;
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_abs(x[i]);
}
}
const int64_t i0 = i % ne0;
const int64_t i1 = (i / ne0) % ne1;
const int64_t i2 = (i / (ne0*ne1)) % ne2;
const int64_t i3 = i / (ne0*ne1*ne2);
template<typename T>
static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_elu(x[i]);
}
}
const char * src_base = (const char *) x;
char * dst_base = (char *) dst;
template<typename T>
static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_gelu(x[i]);
}
}
const T * srcp = (const T *)(src_base + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 );
T * dstp = (T *)(dst_base + i0*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3);
template<typename T>
static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_silu(x[i]);
}
}
template<typename T>
static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_gelu_quick(x[i]);
}
}
template<typename T>
static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_gelu_erf(x[i]);
}
}
template<typename T>
static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_tanh(x[i]);
}
}
template<typename T>
static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_relu(x[i]);
}
}
template<typename T>
static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_sigmoid(x[i]);
*dstp = func(*srcp);
}
}
@@ -261,27 +219,6 @@ static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::n
}
}
template<typename T>
static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_hardsigmoid(x[i]);
}
}
template<typename T>
static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_hardswish(x[i]);
}
}
template<typename T>
static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_exp(x[i]);
}
}
template<typename T>
static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -289,19 +226,6 @@ static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::n
}
}
template<typename T>
static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_neg(x[i]);
}
}
template<typename T>
static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_step(x[i]);
}
}
template<typename T>
static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
@@ -620,6 +544,48 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
}
}
template<typename F>
static inline void ggml_sycl_op_unary(
ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {
ggml_tensor * src0 = dst->src[0];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];
const size_t nb0 = src0->nb[0];
const size_t nb1 = src0->nb[1];
const size_t nb2 = src0->nb[2];
const size_t nb3 = src0->nb[3];
const size_t nbd0 = dst->nb[0];
const size_t nbd1 = dst->nb[1];
const size_t nbd2 = dst->nb[2];
const size_t nbd3 = dst->nb[3];
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[=](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_generic_kernel(
src, dst_ptr, k_elements,
ne0, ne1, ne2, ne3,
nb0, nb1, nb2, nb3,
nbd0, nbd1, nbd2, nbd3,
item_ct1,
func
);
});
});
}
static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -645,159 +611,75 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_sgn(x);
});
}
static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_abs(x);
});
}
static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_elu(x);
});
}
static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_silu(x);
});
}
static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_gelu(x);
});
}
static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_gelu_quick(x);
});
}
static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_gelu_erf(x);
});
}
static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_tanh(x);
});
}
static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_relu(x);
});
}
static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_hardsigmoid(x);
});
}
static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_hardswish(x);
});
}
static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_exp(x);
});
}
static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -814,42 +696,22 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
}
static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_neg(x);
});
}
static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_step(x);
});
}
static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_sigmoid(x);
});
}
static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+6 -5
View File
@@ -4360,21 +4360,22 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
}
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_ELU:
return true;
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,21 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(abs(float(data_a[i])));
}
@@ -7,6 +7,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable
#include "types.glsl"
#include "flash_attn_base.glsl"
@@ -108,6 +109,38 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, m);
} else {
masksh[c][r] = float(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
@@ -153,21 +186,6 @@ void main() {
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);
}
}
}
barrier();
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
@@ -7,6 +7,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
@@ -148,6 +149,37 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
float mask_cache[Bc * Br / WorkGroupSize];
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
mask_cache[idx / WorkGroupSize] = m;
max_mask = max(max_mask, m);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
@@ -208,7 +240,8 @@ void main() {
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
float f = mask_cache[idx / WorkGroupSize];
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
}
}
}
@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return max(x, y);
}
float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
return max(x, y);
}
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return x;
}
@@ -142,6 +146,44 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
if (nem1_bounds_check) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
@@ -158,31 +200,7 @@ void main() {
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
if (nem1_bounds_check) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
// Clear padding elements to -inf, so they don't contribute to rowmax
@@ -0,0 +1,18 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));
}
@@ -11,29 +11,7 @@
#define EXPERT_COUNT 8
#endif
#include "types.glsl"
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#endif
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
#ifdef MUL_MAT_ID
layout (binding = 4) readonly buffer IDS {int data_ids[];};
#endif
#include "mul_mat_vec_iface.glsl"
#include "dequant_funcs.glsl"
@@ -48,8 +26,7 @@ layout (push_constant) uniform parameter
uint batch_stride_b;
uint batch_stride_d;
uint enable_bias;
uint enable_scale;
uint fusion_flags;
#ifdef MUL_MAT_ID
uint nei0;
@@ -123,17 +100,24 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
if (p.enable_bias != 0) {
#ifdef MUL_MAT_ID
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
#else
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
#endif
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
#ifdef MUL_MAT_ID
if (p.enable_scale != 0) {
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
}
#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
@@ -171,17 +155,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
temp[j][n] += tmpsh[j][n][s];
}
if (p.enable_bias != 0) {
#ifdef MUL_MAT_ID
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
#else
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
#endif
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
#ifdef MUL_MAT_ID
if (p.enable_scale != 0) {
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
}
#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
@@ -209,17 +200,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
if (p.enable_bias != 0) {
#ifdef MUL_MAT_ID
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
#else
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
#endif
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
#ifdef MUL_MAT_ID
if (p.enable_scale != 0) {
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
}
#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
@@ -0,0 +1,33 @@
#include "types.glsl"
#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_VEC4)
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
#endif
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#endif
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 3) readonly buffer Fuse0 {D_TYPE data_fuse0[];};
layout (binding = 4) readonly buffer Fuse1 {D_TYPE data_fuse1[];};
#ifdef MUL_MAT_ID
layout (binding = 5) readonly buffer IDS {int data_ids[];};
#endif
@@ -8,14 +8,7 @@
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
#include "mul_mat_vec_iface.glsl"
layout (push_constant) uniform parameter
{
@@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
uint nb03;
uint nb13;
uint nb23;
uint enable_bias;
uint fusion_flags;
} p;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
@@ -120,9 +113,12 @@ void main() {
}
if (tid == 0) {
if (p.enable_bias != 0) {
tmp[0] += FLOAT_TYPE(data_bias[idst]);
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
tmp[0] += FLOAT_TYPE(data_fuse0[idst]);
}
dst[idst] = tmp[0];
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
tmp[0] += FLOAT_TYPE(data_fuse1[idst]);
}
data_d[idst] = tmp[0];
}
}
@@ -10,14 +10,7 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
#include "mul_mat_vec_iface.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 32;
// gqa_ratio is in the range [1,8]
@@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
uint nchannels_y;
uint b_offset;
uint d_offset;
uint enable_bias;
uint fusion_flags;
} p;
#if !USE_SUBGROUP_ADD
@@ -151,10 +144,13 @@ void main() {
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// dst is not transposed and not permuted
const uint idst = (channel + c)*nrows_dst + row_dst;
if (p.enable_bias != 0) {
temp[c] += FLOAT_TYPE(data_bias[idst]);
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[c] += FLOAT_TYPE(data_fuse0[idst]);
}
dst[idst] = temp[c];
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
temp[c] += FLOAT_TYPE(data_fuse1[idst]);
}
data_d[idst] = temp[c];
}
}
}
@@ -300,7 +300,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147
}
}
@@ -345,21 +345,22 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
// Repack 2x4 quants into one int
// Add the 3rd bit instead of subtracting it to allow packing the quants
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
// vec4 for unpack8 used due to #12147
const i8vec2 vals00 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
const i8vec2 vals01 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303)))).xy |
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
const i8vec2 vals10 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
const i8vec2 vals11 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303)))).xy |
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
}
@@ -516,15 +517,15 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals00 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))).xy |
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
const i8vec2 vals01 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))).xy |
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
}
@@ -0,0 +1,20 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(-float(data_a[i]));
}
@@ -816,6 +816,9 @@ void process_shaders() {
std::string suffix = rte ? "_rte" : "";
string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
}
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -827,6 +830,8 @@ void process_shaders() {
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("neg_f16", "neg.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("neg_f32", "neg.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
@@ -835,6 +840,8 @@ void process_shaders() {
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";
+33 -6
View File
@@ -5002,17 +5002,19 @@ struct test_mul_mat_vec_fusion : public test_case {
const bool b; // broadcast b matrix (only for use_id)
const bool with_bias;
const bool with_gate;
std::array<int64_t, 2> batch_dims;
test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
: type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true,
std::array<int64_t, 2> batch_dims = {4, 2})
: type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate), batch_dims(batch_dims) {
if (use_id) {
GGML_ASSERT(n_used <= n_mats);
}
}
std::string vars() override {
return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
return VARS_TO_STR12(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate, batch_dims);
}
std::string op_desc(ggml_tensor * t) override {
@@ -5038,8 +5040,8 @@ struct test_mul_mat_vec_fusion : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
if (!use_id) {
const int channels = 4;
const int samples = 2;
const int channels = batch_dims[0];
const int samples = batch_dims[1];
std::array<int64_t, 4> ne = { k, m, channels, samples };
std::array<int64_t, 4> ne0 = { k, n, channels, samples };
@@ -5062,6 +5064,11 @@ struct test_mul_mat_vec_fusion : public test_case {
}
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
std::array<int64_t, 4> bias2_ne = { out->ne[0], 1, channels, samples };
ggml_tensor * bias2 = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias2_ne.data());
out = ggml_add(ctx, out, bias2);
ggml_set_name(out, "out");
return out;
} else {
@@ -5089,6 +5096,11 @@ struct test_mul_mat_vec_fusion : public test_case {
}
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
std::array<int64_t, 4> scale_ne { 1, out->ne[1], out->ne[2], out->ne[3] };
ggml_tensor * scale = ggml_new_tensor(ctx, out->type, 4, scale_ne.data());
out = ggml_mul(ctx, out, scale);
ggml_set_name(out, "out");
return out;
}
@@ -7546,7 +7558,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_arange());
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
test_cases.emplace_back(new test_cumsum());
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 512, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1023, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2047, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));
test_cases.emplace_back(new test_xielu());
@@ -7645,6 +7670,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
use_id, 16, 8, b, with_bias, with_gate));
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
use_id, 16, 8, b, with_bias, with_gate, {1, 1}));
}
}
}
+1 -1
View File
@@ -285,7 +285,7 @@ int main(int argc, char ** argv) {
}
mtmd_cli_context ctx(params);
LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
LOG_INF("%s: loading model: %s\n", __func__, params.model.path.c_str());
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
+8 -1
View File
@@ -13,7 +13,14 @@ endif()
set(TARGET_SRCS
server.cpp
utils.hpp
server-http.cpp
server-http.h
server-task.cpp
server-task.h
server-queue.cpp
server-queue.h
server-common.cpp
server-common.h
)
set(PUBLIC_ASSETS
index.html.gz
Binary file not shown.
File diff suppressed because it is too large Load Diff
+349
View File
@@ -0,0 +1,349 @@
#pragma once
#include "common.h"
#include "log.h"
#include "llama.h"
#include "chat.h"
#include "mtmd.h"
#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
#include <string>
#include <vector>
#include <cinttypes>
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
using json = nlohmann::ordered_json;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
using raw_buffer = std::vector<uint8_t>;
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
// Fallback null to default value
if (body.contains(key) && !body.at(key).is_null()) {
try {
return body.at(key);
} catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) {
LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what());
return default_value;
}
} else {
return default_value;
}
}
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
ERROR_TYPE_INVALID_REQUEST,
ERROR_TYPE_AUTHENTICATION,
ERROR_TYPE_SERVER,
ERROR_TYPE_NOT_FOUND,
ERROR_TYPE_PERMISSION,
ERROR_TYPE_UNAVAILABLE, // custom error
ERROR_TYPE_NOT_SUPPORTED, // custom error
ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
};
// thin wrapper around common_grammar_trigger with (de)serialization functions
struct server_grammar_trigger {
common_grammar_trigger value;
server_grammar_trigger() = default;
server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
server_grammar_trigger(const json & in) {
value.type = (common_grammar_trigger_type) in.at("type").get<int>();
value.value = in.at("value").get<std::string>();
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
value.token = (llama_token) in.at("token").get<int>();
}
}
json to_json() const {
json out {
{"type", (int) value.type},
{"value", value.value},
};
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out["token"] = (int) value.token;
}
return out;
}
};
json format_error_response(const std::string & message, const enum error_type type);
//
// random string / id
//
std::string random_string();
std::string gen_chatcmplid();
std::string gen_tool_call_id();
//
// lora utils
//
// check whether the given lora set has only aloras activated (empty => false)
bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras);
// if the two sets of loras are different, they require a cache clear unless the
// change is only from aloras to aloras.
bool lora_should_clear_cache(
const std::vector<common_adapter_lora_info> & current,
const std::vector<common_adapter_lora_info> & next);
std::vector<common_adapter_lora_info> parse_lora_request(
const std::vector<common_adapter_lora_info> & lora_base,
const json & data);
bool are_lora_equal(
const std::vector<common_adapter_lora_info> & l1,
const std::vector<common_adapter_lora_info> & l2);
// get the ids of all enabled loras
std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras);
//
// server_tokens
//
/**
* server_tokens is a helper to manage the input tokens and image for the server.
* it is made this way to simplify the logic of KV cache management.
*/
struct server_tokens {
bool has_mtmd = false;
private: // disallow accessing these members directly, risking out-of-sync
// map a **start** index in tokens to the image chunk
// note: the order need to be in-sync with tokens
std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
// list of tokens
// if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
// otherwise, it is a normal text token
// note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
// note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
llama_tokens tokens;
// for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
// idx 0 1 2 3 4 5 6 7 8 9 10
// pos 0 1 2 3 4 5 5 5 7 7 7
// map_idx_to_media will contain: {5, img0}, {8, img1}
public:
server_tokens() = default;
~server_tokens() = default;
// Prevent copying
// TODO: server_tokens should be copyable - remove this:
server_tokens(const server_tokens&) = delete;
server_tokens& operator=(const server_tokens&) = delete;
// Allow moving (usually implicitly generated if members are movable)
server_tokens(server_tokens&&) = default;
server_tokens& operator=(server_tokens&&) = default;
// Allow accessing elements using [] operator
llama_token operator[](size_t index) { return tokens[index]; }
const llama_token& operator[](size_t index) const { return tokens[index]; }
server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd);
server_tokens(const llama_tokens & tokens, bool has_mtmd);
// for debugging
std::string str() const;
llama_pos pos_next() const;
const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
void push_back(llama_token tok);
// will create a copy of the chunk if it contains non-text data
void push_back(const mtmd_input_chunk * chunk);
// appends server tokens, updates the media map. copies media chunks.
void push_back(server_tokens & tokens);
// for compatibility with context shift and prompt truncation
void insert(const llama_tokens & inp_tokens);
// for compatibility with speculative decoding, ctx shift, slot save/load
const llama_tokens & get_text_tokens() const;
// for compatibility with speculative decoding
void set_token(llama_pos pos, llama_token id);
size_t size() const { return tokens.size(); }
bool empty() const { return tokens.empty(); }
void clear() {
map_idx_to_media.clear();
tokens.clear();
}
void keep_first(size_t n);
std::string detokenize(const llama_context * ctx, bool special) const;
size_t get_common_prefix(const server_tokens & b) const;
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const;
// encode and decode the image chunk
int32_t process_chunk(
llama_context * ctx,
mtmd_context * mctx,
size_t idx,
llama_pos pos,
int32_t seq_id,
size_t & n_tokens_out) const;
};
//
// tokenizer and input processing utils
//
bool json_is_array_of_numbers(const json & data);
// is array having BOTH numbers & strings?
bool json_is_array_of_mixed_numbers_strings(const json & data);
// does array have any individual integers/tokens?
bool json_is_array_and_contains_numbers(const json & data);
// get value by path(key1 / key2)
json json_get_nested_values(const std::vector<std::string> & paths, const json & js);
/**
* this handles 2 cases:
* - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/
llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special);
// return the last index of character that can form a valid string
// if the last character is potentially cut in half, return the index before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
size_t validate_utf8(const std::string& text);
// process mtmd prompt, return the server_tokens containing both text tokens and media chunks
server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files);
/**
* break the input "prompt" object into multiple prompt if needed, then tokenize them
* this supports these cases:
* - "prompt": "string"
* - "prompt": [12, 34, 56]
* - "prompt": [12, 34, "string", 56, 78]
* - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
* and multiple prompts (multi-tasks):
* - "prompt": ["string1", "string2"]
* - "prompt": ["string1", [12, 34, 56]]
* - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
*/
std::vector<server_tokens> tokenize_input_prompts(
const llama_vocab * vocab,
mtmd_context * mctx,
const json & json_prompt,
bool add_special,
bool parse_special);
//
// OAI utils
//
// used by /completions endpoint
json oaicompat_completion_params_parse(const json & body);
struct oaicompat_parser_options {
bool use_jinja;
bool prefill_assistant;
common_reasoning_format reasoning_format;
std::map<std::string,std::string> chat_template_kwargs;
common_chat_templates * tmpls;
bool allow_image;
bool allow_audio;
bool enable_thinking = true;
};
// used by /chat/completions endpoint
json oaicompat_chat_params_parse(
json & body, /* openai api json semantics */
const oaicompat_parser_options & opt,
std::vector<raw_buffer> & out_files);
// TODO: move it to server-task.cpp
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
// TODO: move it to server-task.cpp
json format_response_rerank(
const json & request,
const json & ranks,
bool is_tei_format,
std::vector<std::string> & texts,
int top_n);
//
// other utils
//
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx);
std::string safe_json_to_str(const json & data);
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
// format incomplete utf-8 multibyte character for output
std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);
// format server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
std::string format_sse(const json & data);
bool is_valid_utf8(const std::string & str);
//
// formatting output responses
// TODO: move these to server-task.cpp
//
llama_tokens format_prompt_infill(
const llama_vocab * vocab,
const json & input_prefix,
const json & input_suffix,
const json & input_extra,
const int n_batch,
const int n_predict,
const int n_ctx,
const bool spm_infill,
const llama_tokens & tokens_prompt);
// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
server_tokens format_prompt_rerank(
const struct llama_model * model,
const struct llama_vocab * vocab,
mtmd_context * mctx,
const std::string & query,
const std::string & doc);
+387
View File
@@ -0,0 +1,387 @@
#include "common.h"
#include "server-http.h"
#include "server-common.h"
#include <cpp-httplib/httplib.h>
#include <functional>
#include <string>
#include <thread>
// auto generated files (see README.md for details)
#include "index.html.gz.hpp"
#include "loading.html.hpp"
//
// HTTP implementation using cpp-httplib
//
class server_http_context::Impl {
public:
std::unique_ptr<httplib::Server> srv;
};
server_http_context::server_http_context()
: pimpl(std::make_unique<server_http_context::Impl>())
{}
server_http_context::~server_http_context() = default;
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
// skip GH copilot requests when using default port
if (req.path == "/v1/health") {
return;
}
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
SRV_DBG("request: %s\n", req.body.c_str());
SRV_DBG("response: %s\n", res.body.c_str());
}
bool server_http_context::init(const common_params & params) {
path_prefix = params.api_prefix;
port = params.port;
hostname = params.hostname;
auto & srv = pimpl->srv;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
srv.reset(
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
);
} else {
LOG_INF("Running without SSL\n");
srv.reset(new httplib::Server());
}
#else
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
LOG_ERR("Server is built without SSL support\n");
return false;
}
srv.reset(new httplib::Server());
#endif
srv->set_default_headers({{"Server", "llama.cpp"}});
srv->set_logger(log_server_request);
srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
// this is fail-safe; exceptions should already handled by `ex_wrapper`
std::string message;
try {
std::rethrow_exception(ep);
} catch (const std::exception & e) {
message = e.what();
} catch (...) {
message = "Unknown Exception";
}
res.status = 500;
res.set_content(message, "text/plain");
LOG_ERR("got exception: %s\n", message.c_str());
});
srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 404) {
res.set_content(
safe_json_to_str(json {
{"error", {
{"message", "File Not Found"},
{"type", "not_found_error"},
{"code", 404}
}}
}),
"application/json; charset=utf-8"
);
}
// for other error codes, we skip processing here because it's already done by res->error()
});
// set timeouts and change hostname and port
srv->set_read_timeout (params.timeout_read);
srv->set_write_timeout(params.timeout_write);
if (params.api_keys.size() == 1) {
auto key = params.api_keys[0];
std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
} else if (params.api_keys.size() > 1) {
LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
}
//
// Middlewares
//
auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
static const std::unordered_set<std::string> public_endpoints = {
"/health",
"/v1/health",
"/models",
"/v1/models",
"/api/tags"
};
// If API key is not set, skip validation
if (api_keys.empty()) {
return true;
}
// If path is public or is static file, skip validation
if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
return true;
}
// Check for API key in the header
auto auth_header = req.get_header_value("Authorization");
std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
return true; // API key is valid
}
}
// API key is invalid or not provided
res.status = 401;
res.set_content(
safe_json_to_str(json {
{"error", {
{"message", "Invalid API Key"},
{"type", "authentication_error"},
{"code", 401}
}}
}),
"application/json; charset=utf-8"
);
LOG_WRN("Unauthorized: Invalid API Key\n");
return false;
};
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
bool ready = is_ready.load();
if (!ready) {
auto tmp = string_split<std::string>(req.path, '.');
if (req.path == "/" || tmp.back() == "html") {
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
res.status = 503;
} else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
// allow the models endpoint to be accessed during loading
return true;
} else {
res.status = 503;
res.set_content(
safe_json_to_str(json {
{"error", {
{"message", "Loading model"},
{"type", "unavailable_error"},
{"code", 503}
}}
}),
"application/json; charset=utf-8"
);
}
return false;
}
return true;
};
// register server middlewares
srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
// If this is OPTIONS request, skip validation because browsers don't include Authorization header
if (req.method == "OPTIONS") {
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "GET, POST");
res.set_header("Access-Control-Allow-Headers", "*");
res.set_content("", "text/html"); // blank response, no data
return httplib::Server::HandlerResponse::Handled; // skip further processing
}
if (!middleware_server_state(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
if (!middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
int n_threads_http = params.n_threads_http;
if (n_threads_http < 1) {
// +2 threads for monitoring endpoints
n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
}
LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
//
// Web UI setup
//
if (!params.webui) {
LOG_INF("Web UI is disabled\n");
} else {
// register static assets routes
if (!params.public_path.empty()) {
// Set the base directory for serving static files
bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
if (!is_found) {
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
return 1;
}
} else {
// using embedded static index.html
srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
res.set_content("Error: gzip is not supported by this browser", "text/plain");
} else {
res.set_header("Content-Encoding", "gzip");
// COEP and COOP headers, required by pyodide (python interpreter)
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
}
return false;
});
}
}
return true;
}
bool server_http_context::start() {
// Bind and listen
auto & srv = pimpl->srv;
bool was_bound = false;
bool is_sock = false;
if (string_ends_with(std::string(hostname), ".sock")) {
is_sock = true;
LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
srv->set_address_family(AF_UNIX);
// bind_to_port requires a second arg, any value other than 0 should
// simply get ignored
was_bound = srv->bind_to_port(hostname, 8080);
} else {
LOG_INF("%s: binding port with default address family\n", __func__);
// bind HTTP listen port
if (port == 0) {
int bound_port = srv->bind_to_any_port(hostname);
was_bound = (bound_port >= 0);
if (was_bound) {
port = bound_port;
}
} else {
was_bound = srv->bind_to_port(hostname, port);
}
}
if (!was_bound) {
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
return false;
}
// run the HTTP server in a thread
thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
srv->wait_until_ready();
listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
: string_format("http://%s:%d", hostname.c_str(), port);
return true;
}
void server_http_context::stop() const {
if (pimpl->srv) {
pimpl->srv->stop();
}
}
static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
for (const auto & [key, value] : headers) {
res.set_header(key, value);
}
}
static std::map<std::string, std::string> get_params(const httplib::Request & req) {
std::map<std::string, std::string> params;
for (const auto & [key, value] : req.params) {
params[key] = value;
}
for (const auto & [key, value] : req.path_params) {
params[key] = value;
}
return params;
}
static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
std::map<std::string, std::string> headers;
for (const auto & [key, value] : req.headers) {
headers[key] = value;
}
return headers;
}
static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) {
if (response->is_stream()) {
res.status = response->status;
set_headers(res, response->headers);
std::string content_type = response->content_type;
// convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
std::shared_ptr<server_http_res> r_ptr = std::move(response);
const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
std::string chunk;
bool has_next = response->next(chunk);
if (!chunk.empty()) {
// TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed()
sink.write(chunk.data(), chunk.size());
SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
}
if (!has_next) {
sink.done();
SRV_DBG("%s", "http: stream ended\n");
}
return has_next;
};
const auto on_complete = [response = r_ptr](bool) mutable {
response.reset(); // trigger the destruction of the response object
};
res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
} else {
res.status = response->status;
set_headers(res, response->headers);
res.set_content(response->data, response->content_type);
}
}
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
server_http_res_ptr response = handler(server_http_req{
get_params(req),
get_headers(req),
req.path,
req.body,
req.is_connection_closed
});
process_handler_response(response, res);
});
}
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
server_http_res_ptr response = handler(server_http_req{
get_params(req),
get_headers(req),
req.path,
req.body,
req.is_connection_closed
});
process_handler_response(response, res);
});
}
+78
View File
@@ -0,0 +1,78 @@
#pragma once
#include <atomic>
#include <functional>
#include <map>
#include <string>
#include <thread>
struct common_params;
// generator-like API for HTTP response generation
// this object response with one of the 2 modes:
// 1) normal response: `data` contains the full response body
// 2) streaming response: each call to next(output) generates the next chunk
// when next(output) returns false, no more data after the current chunk
// note: some chunks can be empty, in which case no data is sent for that chunk
struct server_http_res {
std::string content_type = "application/json; charset=utf-8";
int status = 200;
std::string data;
std::map<std::string, std::string> headers;
// TODO: move this to a virtual function once we have proper polymorphism support
std::function<bool(std::string &)> next = nullptr;
bool is_stream() const {
return next != nullptr;
}
virtual ~server_http_res() = default;
};
// unique pointer, used by set_chunked_content_provider
// httplib requires the stream provider to be stored in heap
using server_http_res_ptr = std::unique_ptr<server_http_res>;
struct server_http_req {
std::map<std::string, std::string> params; // path_params + query_params
std::map<std::string, std::string> headers; // reserved for future use
std::string path; // reserved for future use
std::string body;
const std::function<bool()> & should_stop;
std::string get_param(const std::string & key, const std::string & def = "") const {
auto it = params.find(key);
if (it != params.end()) {
return it->second;
}
return def;
}
};
struct server_http_context {
class Impl;
std::unique_ptr<Impl> pimpl;
std::thread thread; // server thread
std::atomic<bool> is_ready = false;
std::string path_prefix;
std::string hostname;
int port;
server_http_context();
~server_http_context();
bool init(const common_params & params);
bool start();
void stop() const;
// note: the handler should never throw exceptions
using handler_t = std::function<server_http_res_ptr(const server_http_req & req)>;
void get(const std::string & path, const handler_t & handler) const;
void post(const std::string & path, const handler_t & handler) const;
// for debugging
std::string listening_address;
};
+268
View File
@@ -0,0 +1,268 @@
#include "server-task.h"
#include "server-queue.h"
#include "log.h"
#include <chrono>
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
//
// server_queue
//
int server_queue::post(server_task && task, bool front) {
std::unique_lock<std::mutex> lock(mutex_tasks);
GGML_ASSERT(task.id != -1);
// if this is cancel task make sure to clean up pending tasks
if (task.type == SERVER_TASK_TYPE_CANCEL) {
cleanup_pending_task(task.id_target);
}
const int task_id = task.id;
QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
if (front) {
queue_tasks.push_front(std::move(task));
} else {
queue_tasks.push_back(std::move(task));
}
condition_tasks.notify_one();
return task_id;
}
int server_queue::post(std::vector<server_task> && tasks, bool front) {
std::unique_lock<std::mutex> lock(mutex_tasks);
for (auto & task : tasks) {
if (task.id == -1) {
task.id = id++;
}
// if this is cancel task make sure to clean up pending tasks
if (task.type == SERVER_TASK_TYPE_CANCEL) {
cleanup_pending_task(task.id_target);
}
QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
if (front) {
queue_tasks.push_front(std::move(task));
} else {
queue_tasks.push_back(std::move(task));
}
}
condition_tasks.notify_one();
return 0;
}
void server_queue::defer(server_task && task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
QUE_DBG("defer task, id = %d\n", task.id);
queue_tasks_deferred.push_back(std::move(task));
condition_tasks.notify_one();
}
int server_queue::get_new_id() {
std::unique_lock<std::mutex> lock(mutex_tasks);
int new_id = id++;
return new_id;
}
void server_queue::on_new_task(std::function<void(server_task &&)> callback) {
callback_new_task = std::move(callback);
}
void server_queue::on_update_slots(std::function<void(void)> callback) {
callback_update_slots = std::move(callback);
}
void server_queue::pop_deferred_task() {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!queue_tasks_deferred.empty()) {
queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
queue_tasks_deferred.pop_front();
}
condition_tasks.notify_one();
}
void server_queue::terminate() {
std::unique_lock<std::mutex> lock(mutex_tasks);
running = false;
condition_tasks.notify_all();
}
void server_queue::start_loop() {
running = true;
while (true) {
QUE_DBG("%s", "processing new tasks\n");
while (true) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
}
if (queue_tasks.empty()) {
lock.unlock();
break;
}
server_task task = std::move(queue_tasks.front());
queue_tasks.pop_front();
lock.unlock();
QUE_DBG("processing task, id = %d\n", task.id);
callback_new_task(std::move(task));
}
// all tasks in the current loop is processed, slots data is now ready
QUE_DBG("%s", "update slots\n");
callback_update_slots();
QUE_DBG("%s", "waiting for new tasks\n");
{
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
}
if (queue_tasks.empty()) {
condition_tasks.wait(lock, [&]{
return (!queue_tasks.empty() || !running);
});
}
}
}
}
void server_queue::cleanup_pending_task(int id_target) {
// no need lock because this is called exclusively by post()
auto rm_func = [id_target](const server_task & task) {
return task.id == id_target;
};
queue_tasks.erase(
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
queue_tasks.end());
queue_tasks_deferred.erase(
std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
queue_tasks_deferred.end());
}
//
// server_response
//
void server_response::add_waiting_task_id(int id_task) {
RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.insert(id_task);
}
void server_response::add_waiting_tasks(const std::vector<server_task> & tasks) {
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto & task : tasks) {
RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
waiting_task_ids.insert(task.id);
}
}
void server_response::remove_waiting_task_id(int id_task) {
RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.erase(id_task);
// make sure to clean up all pending results
queue_results.erase(
std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
return res->id == id_task;
}),
queue_results.end());
}
void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto & id_task : id_tasks) {
RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
waiting_task_ids.erase(id_task);
}
}
server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&]{
if (!running) {
RES_DBG("%s : queue result stop\n", __func__);
std::terminate(); // we cannot return here since the caller is HTTP code
}
return !queue_results.empty();
});
for (size_t i = 0; i < queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
return res;
}
}
}
// should never reach here
}
server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
for (int i = 0; i < (int) queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
return res;
}
}
std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
if (!running) {
RES_DBG("%s : queue result stop\n", __func__);
std::terminate(); // we cannot return here since the caller is HTTP code
}
if (cr_res == std::cv_status::timeout) {
return nullptr;
}
}
// should never reach here
}
server_task_result_ptr server_response::recv(int id_task) {
std::unordered_set<int> id_tasks = {id_task};
return recv(id_tasks);
}
void server_response::send(server_task_result_ptr && result) {
RES_DBG("sending result for task id = %d\n", result->id);
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto & id_task : waiting_task_ids) {
if (result->id == id_task) {
RES_DBG("task id = %d pushed to result queue\n", result->id);
queue_results.emplace_back(std::move(result));
condition_results.notify_all();
return;
}
}
}
void server_response::terminate() {
running = false;
condition_results.notify_all();
}
+110
View File
@@ -0,0 +1,110 @@
#pragma once
#include "server-task.h"
#include <condition_variable>
#include <deque>
#include <mutex>
#include <unordered_set>
struct server_queue {
private:
int id = 0;
bool running;
// queues
std::deque<server_task> queue_tasks;
std::deque<server_task> queue_tasks_deferred;
std::mutex mutex_tasks;
std::condition_variable condition_tasks;
// callback functions
std::function<void(server_task &&)> callback_new_task;
std::function<void(void)> callback_update_slots;
public:
// Add a new task to the end of the queue
int post(server_task && task, bool front = false);
// multi-task version of post()
int post(std::vector<server_task> && tasks, bool front = false);
// Add a new task, but defer until one slot is available
void defer(server_task && task);
// Get the next id for creating a new task
int get_new_id();
// Register function to process a new task
void on_new_task(std::function<void(server_task &&)> callback);
// Register the function to be called when all slots data is ready to be processed
void on_update_slots(std::function<void(void)> callback);
// Call when the state of one slot is changed, it will move one task from deferred to main queue
void pop_deferred_task();
// end the start_loop routine
void terminate();
/**
* Main loop consists of these steps:
* - Wait until a new task arrives
* - Process the task (i.e. maybe copy data into slot)
* - Check if multitask is finished
* - Update all slots
*/
void start_loop();
// for metrics
size_t queue_tasks_deferred_size() {
std::unique_lock<std::mutex> lock(mutex_tasks);
return queue_tasks_deferred.size();
}
private:
void cleanup_pending_task(int id_target);
};
struct server_response {
private:
bool running = true;
// for keeping track of all tasks waiting for the result
std::unordered_set<int> waiting_task_ids;
// the main result queue (using ptr for polymorphism)
std::vector<server_task_result_ptr> queue_results;
std::mutex mutex_results;
std::condition_variable condition_results;
public:
// add the id_task to the list of tasks waiting for response
void add_waiting_task_id(int id_task);
void add_waiting_tasks(const std::vector<server_task> & tasks);
// when the request is finished, we can remove task associated with it
void remove_waiting_task_id(int id_task);
// remove multiple tasks from waiting list
void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);
// This function blocks the thread until there is a response for one of the id_tasks
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);
// same as recv(), but have timeout in seconds
// if timeout is reached, nullptr is returned
server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);
// single-task version of recv()
server_task_result_ptr recv(int id_task);
// Send a new result to a waiting id_task
void send(server_task_result_ptr && result);
// terminate the waiting loop
void terminate();
};
File diff suppressed because it is too large Load Diff
+453
View File
@@ -0,0 +1,453 @@
#pragma once
#include "common.h"
#include "llama.h"
#include <string>
#include <unordered_set>
#include <list>
// TODO: prevent including the whole server-common.h as we only use server_tokens
#include "server-common.h"
using json = nlohmann::ordered_json;
enum server_task_type {
SERVER_TASK_TYPE_COMPLETION,
SERVER_TASK_TYPE_EMBEDDING,
SERVER_TASK_TYPE_RERANK,
SERVER_TASK_TYPE_INFILL,
SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS,
SERVER_TASK_TYPE_SLOT_SAVE,
SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE,
SERVER_TASK_TYPE_SET_LORA,
};
// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
enum oaicompat_type {
OAICOMPAT_TYPE_NONE,
OAICOMPAT_TYPE_CHAT,
OAICOMPAT_TYPE_COMPLETION,
OAICOMPAT_TYPE_EMBEDDING,
};
enum stop_type {
STOP_TYPE_NONE,
STOP_TYPE_EOS,
STOP_TYPE_WORD,
STOP_TYPE_LIMIT,
};
struct task_params {
bool stream = true;
bool include_usage = false;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
bool return_progress = false;
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<common_adapter_lora_info> lora;
std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
json to_json(bool only_metrics = false) const;
};
struct server_task {
int id = -1; // to be filled by server_queue
int index = -1; // used when there are multiple prompts (batch request)
// used by SERVER_TASK_TYPE_CANCEL
int id_target = -1;
int id_slot = -1;
// used by SERVER_TASK_TYPE_INFERENCE
task_params params;
server_tokens tokens;
server_task_type type;
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
struct slot_action {
int slot_id;
std::string filename;
std::string filepath;
};
slot_action slot_action;
// used by SERVER_TASK_TYPE_METRICS
bool metrics_reset_bucket = false;
// used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_adapter_lora_info> set_lora;
server_task() = default;
server_task(server_task_type type) : type(type) {}
int32_t n_tokens() const {
return tokens.size();
}
static task_params params_from_json_cmpl(
const llama_context * ctx,
const common_params & params_base,
const json & data);
// utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids.insert(tasks[i].id);
}
return ids;
}
};
struct result_timings {
int32_t cache_n = -1;
int32_t prompt_n = -1;
double prompt_ms;
double prompt_per_token_ms;
double prompt_per_second;
int32_t predicted_n = -1;
double predicted_ms;
double predicted_per_token_ms;
double predicted_per_second;
// Optional speculative metrics - only included when > 0
int32_t draft_n = 0;
int32_t draft_n_accepted = 0;
json to_json() const;
};
struct result_prompt_progress {
int32_t total = 0;
int32_t cache = 0;
int32_t processed = 0;
int64_t time_ms = 0;
json to_json() const;
};
struct server_task_result {
int id = -1;
int id_slot = -1;
virtual bool is_error() {
// only used by server_task_result_error
return false;
}
virtual bool is_stop() {
// only used by server_task_result_cmpl_*
return true;
}
virtual int get_index() {
return -1;
}
virtual json to_json() = 0;
virtual ~server_task_result() = default;
};
// using shared_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
struct completion_token_output {
llama_token tok;
float prob;
std::string text_to_send;
struct prob_info {
llama_token tok;
std::string txt;
float prob;
};
std::vector<prob_info> probs;
json to_json(bool post_sampling_probs) const;
static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
static float logarithm(float x);
static std::vector<unsigned char> str_to_bytes(const std::string & str);
};
struct server_task_result_cmpl_final : server_task_result {
int index = 0;
std::string content;
llama_tokens tokens;
bool stream;
bool include_usage;
result_timings timings;
std::string prompt;
bool truncated;
int32_t n_decoded;
int32_t n_prompt_tokens;
int32_t n_tokens_cached;
bool has_new_line;
std::string stopping_word;
stop_type stop = STOP_TYPE_NONE;
bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
std::vector<std::string> response_fields;
task_params generation_params;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;
}
virtual bool is_stop() override {
return true; // in stream mode, final responses are considered stop
}
virtual json to_json() override;
json to_json_non_oaicompat();
json to_json_oaicompat();
json to_json_oaicompat_chat();
json to_json_oaicompat_chat_stream();
};
struct server_task_result_cmpl_partial : server_task_result {
int index = 0;
std::string content;
llama_tokens tokens;
int32_t n_decoded;
int32_t n_prompt_tokens;
bool post_sampling_probs;
bool is_progress = false;
completion_token_output prob_output;
result_timings timings;
result_prompt_progress progress;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;
}
virtual bool is_stop() override {
return false; // in stream mode, partial responses are not considered stop
}
virtual json to_json() override;
json to_json_non_oaicompat();
json to_json_oaicompat();
json to_json_oaicompat_chat();
};
struct server_task_result_embd : server_task_result {
int index = 0;
std::vector<std::vector<float>> embedding;
int32_t n_tokens;
// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
virtual int get_index() override {
return index;
}
virtual json to_json() override;
json to_json_non_oaicompat();
json to_json_oaicompat();
};
struct server_task_result_rerank : server_task_result {
int index = 0;
float score = -1e6;
int32_t n_tokens;
virtual int get_index() override {
return index;
}
virtual json to_json() override;
};
struct server_task_result_error : server_task_result {
int index = 0;
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
// for ERROR_TYPE_EXCEED_CONTEXT_SIZE
int32_t n_prompt_tokens = 0;
int32_t n_ctx = 0;
virtual bool is_error() override {
return true;
}
virtual json to_json() override;
};
struct server_task_result_metrics : server_task_result {
int n_idle_slots;
int n_processing_slots;
int n_tasks_deferred;
int64_t t_start;
// TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
uint64_t n_prompt_tokens_processed_total = 0;
uint64_t t_prompt_processing_total = 0;
uint64_t n_tokens_predicted_total = 0;
uint64_t t_tokens_generation_total = 0;
uint64_t n_tokens_max = 0;
uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0;
uint64_t n_tokens_predicted = 0;
uint64_t t_tokens_generation = 0;
uint64_t n_decode_total = 0;
uint64_t n_busy_slots_total = 0;
// while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
// therefore, we use json to temporarily store the slot.to_json() result
json slots_data = json::array();
virtual json to_json() override;
};
struct server_task_result_slot_save_load : server_task_result {
std::string filename;
bool is_save; // true = save, false = load
size_t n_tokens;
size_t n_bytes;
double t_ms;
virtual json to_json() override;
};
struct server_task_result_slot_erase : server_task_result {
size_t n_erased;
virtual json to_json() override;
};
struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override;
};
struct server_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
};
struct server_prompt {
server_tokens tokens;
std::vector<uint8_t> data;
std::list<server_prompt_checkpoint> checkpoints;
size_t size() const {
size_t res = data.size();
for (const auto & checkpoint : checkpoints) {
res += checkpoint.size();
}
return res;
}
int n_tokens() const {
return tokens.size();
}
};
struct server_prompt_cache {
server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
this->limit_tokens = limit_tokens;
}
std::list<server_prompt> states;
// in bytes, 0 = no limit
size_t limit_size = 0;
// in tokens, 0 = no limit
size_t limit_tokens = 0;
size_t size() const;
size_t n_tokens() const;
server_prompt * alloc(const server_prompt & prompt, size_t state_size);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
void update();
};
+800 -2796
View File
File diff suppressed because it is too large Load Diff
@@ -72,12 +72,6 @@
}
}
function handleScroll() {
if (isOpen) {
updateMenuPosition();
}
}
async function handleSelect(value: string | undefined) {
if (!value) return;
@@ -259,7 +253,7 @@
}
</script>
<svelte:window onresize={handleResize} onscroll={handleScroll} />
<svelte:window onresize={handleResize} />
<svelte:document onpointerdown={handlePointerDown} onkeydown={handleKeydown} />
@@ -2,6 +2,7 @@
import { getDeletionInfo } from '$lib/stores/chat.svelte';
import { copyToClipboard } from '$lib/utils/copy';
import { isIMEComposing } from '$lib/utils/is-ime-composing';
import type { ApiChatCompletionToolCall } from '$lib/types/api';
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
import ChatMessageUser from './ChatMessageUser.svelte';
@@ -54,6 +55,29 @@
return null;
});
let toolCallContent = $derived.by((): ApiChatCompletionToolCall[] | string | null => {
if (message.role === 'assistant') {
const trimmedToolCalls = message.toolCalls?.trim();
if (!trimmedToolCalls) {
return null;
}
try {
const parsed = JSON.parse(trimmedToolCalls);
if (Array.isArray(parsed)) {
return parsed as ApiChatCompletionToolCall[];
}
} catch {
// Harmony-only path: fall back to the raw string so issues surface visibly.
}
return trimmedToolCalls;
}
return null;
});
function handleCancelEdit() {
isEditing = false;
editedContent = message.content;
@@ -171,5 +195,6 @@
{showDeleteDialog}
{siblingInfo}
{thinkingContent}
{toolCallContent}
/>
{/if}
@@ -11,7 +11,8 @@
Gauge,
Clock,
WholeWord,
ChartNoAxesColumn
ChartNoAxesColumn,
Wrench
} from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Checkbox } from '$lib/components/ui/checkbox';
@@ -21,6 +22,7 @@
import { config } from '$lib/stores/settings.svelte';
import { modelName as serverModelName } from '$lib/stores/server.svelte';
import { copyToClipboard } from '$lib/utils/copy';
import type { ApiChatCompletionToolCall } from '$lib/types/api';
interface Props {
class?: string;
@@ -51,6 +53,7 @@
siblingInfo?: ChatMessageSiblingInfo | null;
textareaElement?: HTMLTextAreaElement;
thinkingContent: string | null;
toolCallContent: ApiChatCompletionToolCall[] | string | null;
}
let {
@@ -76,9 +79,15 @@
shouldBranchAfterEdit = false,
siblingInfo = null,
textareaElement = $bindable(),
thinkingContent
thinkingContent,
toolCallContent = null
}: Props = $props();
const toolCalls = $derived(
Array.isArray(toolCallContent) ? (toolCallContent as ApiChatCompletionToolCall[]) : null
);
const fallbackToolCalls = $derived(typeof toolCallContent === 'string' ? toolCallContent : null);
const processingState = useProcessingState();
let currentConfig = $derived(config());
let serverModel = $derived(serverModelName());
@@ -97,6 +106,58 @@
void copyToClipboard(model ?? '');
}
function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
const callNumber = index + 1;
const functionName = toolCall.function?.name?.trim();
const label = functionName || `Call #${callNumber}`;
const payload: Record<string, unknown> = {};
const id = toolCall.id?.trim();
if (id) {
payload.id = id;
}
const type = toolCall.type?.trim();
if (type) {
payload.type = type;
}
if (toolCall.function) {
const fnPayload: Record<string, unknown> = {};
const name = toolCall.function.name?.trim();
if (name) {
fnPayload.name = name;
}
const rawArguments = toolCall.function.arguments?.trim();
if (rawArguments) {
try {
fnPayload.arguments = JSON.parse(rawArguments);
} catch {
fnPayload.arguments = rawArguments;
}
}
if (Object.keys(fnPayload).length > 0) {
payload.function = fnPayload;
}
}
const formattedPayload = JSON.stringify(payload, null, 2);
return {
label,
tooltip: formattedPayload,
copyValue: formattedPayload
};
}
function handleCopyToolCall(payload: string) {
void copyToClipboard(payload, 'Tool call copied to clipboard');
}
</script>
<div
@@ -189,6 +250,47 @@
</span>
{/if}
{#if config().showToolCalls}
{#if (toolCalls && toolCalls.length > 0) || fallbackToolCalls}
<span class="inline-flex flex-wrap items-center gap-2 text-xs text-muted-foreground">
<span class="inline-flex items-center gap-1">
<Wrench class="h-3.5 w-3.5" />
<span>Tool calls:</span>
</span>
{#if toolCalls && toolCalls.length > 0}
{#each toolCalls as toolCall, index (toolCall.id ?? `${index}`)}
{@const badge = formatToolCallBadge(toolCall, index)}
<button
type="button"
class="tool-call-badge inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
title={badge.tooltip}
aria-label={`Copy tool call ${badge.label}`}
onclick={() => handleCopyToolCall(badge.copyValue)}
>
{badge.label}
<Copy class="ml-1 h-3 w-3" />
</button>
{/each}
{:else if fallbackToolCalls}
<button
type="button"
class="tool-call-badge tool-call-badge--fallback inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
title={fallbackToolCalls}
aria-label="Copy tool call payload"
onclick={() => handleCopyToolCall(fallbackToolCalls)}
>
{fallbackToolCalls}
<Copy class="ml-1 h-3 w-3" />
</button>
{/if}
</span>
{/if}
{/if}
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
{@const tokensPerSecond = (message.timings.predicted_n / message.timings.predicted_ms) * 1000}
<span class="inline-flex items-center gap-2 text-xs text-muted-foreground">
@@ -287,4 +389,17 @@
white-space: pre-wrap;
word-break: break-word;
}
.tool-call-badge {
max-width: 12rem;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.tool-call-badge--fallback {
max-width: 20rem;
white-space: normal;
word-break: break-word;
}
</style>
@@ -76,10 +76,10 @@
});
</script>
<div class="chat-processing-info-container" class:visible={showSlotsInfo}>
<div class="chat-processing-info-container pointer-events-none" class:visible={showSlotsInfo}>
<div class="chat-processing-info-content">
{#each processingDetails as detail (detail)}
<span class="chat-processing-info-detail">{detail}</span>
<span class="chat-processing-info-detail pointer-events-auto">{detail}</span>
{/each}
</div>
</div>
@@ -92,7 +92,6 @@
padding: 1.5rem 1rem;
opacity: 0;
transform: translateY(50%);
pointer-events: none;
transition:
opacity 300ms ease-out,
transform 300ms ease-out;
@@ -100,7 +99,6 @@
.chat-processing-info-container.visible {
opacity: 1;
pointer-events: auto;
transform: translateY(0);
}
@@ -226,6 +226,11 @@
label: 'Enable model selector',
type: 'checkbox'
},
{
key: 'showToolCalls',
label: 'Show tool call labels',
type: 'checkbox'
},
{
key: 'disableReasoningFormat',
label: 'Show raw LLM output',
@@ -6,6 +6,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
theme: 'system',
showTokensPerSecond: false,
showThoughtInProgress: false,
showToolCalls: false,
disableReasoningFormat: false,
keepStatsVisible: false,
showMessageStats: true,
@@ -80,6 +81,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
custom: 'Custom JSON parameters to send to the API. Must be valid JSON format.',
showTokensPerSecond: 'Display generation speed in tokens per second during streaming.',
showThoughtInProgress: 'Expand thought process by default when generating messages.',
showToolCalls:
'Display tool call labels and payloads from Harmony-compatible delta.tool_calls data below assistant messages.',
disableReasoningFormat:
'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.',
keepStatsVisible: 'Keep processing statistics visible after generation finishes.',
+161 -7
View File
@@ -1,6 +1,25 @@
import { config } from '$lib/stores/settings.svelte';
import { selectedModelName } from '$lib/stores/models.svelte';
import { slotsService } from './slots';
import type {
ApiChatCompletionRequest,
ApiChatCompletionResponse,
ApiChatCompletionStreamChunk,
ApiChatCompletionToolCall,
ApiChatCompletionToolCallDelta,
ApiChatMessageData
} from '$lib/types/api';
import type {
DatabaseMessage,
DatabaseMessageExtra,
DatabaseMessageExtraAudioFile,
DatabaseMessageExtraImageFile,
DatabaseMessageExtraLegacyContext,
DatabaseMessageExtraPdfFile,
DatabaseMessageExtraTextFile
} from '$lib/types/database';
import type { ChatMessagePromptProgress, ChatMessageTimings } from '$lib/types/chat';
import type { SettingsChatServiceOptions } from '$lib/types/settings';
/**
* ChatService - Low-level API communication layer for llama.cpp server interactions
*
@@ -53,6 +72,7 @@ export class ChatService {
onComplete,
onError,
onReasoningChunk,
onToolCallChunk,
onModel,
onFirstValidChunk,
// Generation parameters
@@ -201,6 +221,7 @@ export class ChatService {
onComplete,
onError,
onReasoningChunk,
onToolCallChunk,
onModel,
onFirstValidChunk,
conversationId,
@@ -208,7 +229,13 @@ export class ChatService {
);
return;
} else {
return this.handleNonStreamResponse(response, onComplete, onError, onModel);
return this.handleNonStreamResponse(
response,
onComplete,
onError,
onToolCallChunk,
onModel
);
}
} catch (error) {
if (error instanceof Error && error.name === 'AbortError') {
@@ -264,10 +291,12 @@ export class ChatService {
onComplete?: (
response: string,
reasoningContent?: string,
timings?: ChatMessageTimings
timings?: ChatMessageTimings,
toolCalls?: string
) => void,
onError?: (error: Error) => void,
onReasoningChunk?: (chunk: string) => void,
onToolCallChunk?: (chunk: string) => void,
onModel?: (model: string) => void,
onFirstValidChunk?: () => void,
conversationId?: string,
@@ -282,11 +311,53 @@ export class ChatService {
const decoder = new TextDecoder();
let aggregatedContent = '';
let fullReasoningContent = '';
let aggregatedToolCalls: ApiChatCompletionToolCall[] = [];
let hasReceivedData = false;
let lastTimings: ChatMessageTimings | undefined;
let streamFinished = false;
let modelEmitted = false;
let firstValidChunkEmitted = false;
let toolCallIndexOffset = 0;
let hasOpenToolCallBatch = false;
const finalizeOpenToolCallBatch = () => {
if (!hasOpenToolCallBatch) {
return;
}
toolCallIndexOffset = aggregatedToolCalls.length;
hasOpenToolCallBatch = false;
};
const processToolCallDelta = (toolCalls?: ApiChatCompletionToolCallDelta[]) => {
if (!toolCalls || toolCalls.length === 0) {
return;
}
aggregatedToolCalls = this.mergeToolCallDeltas(
aggregatedToolCalls,
toolCalls,
toolCallIndexOffset
);
if (aggregatedToolCalls.length === 0) {
return;
}
hasOpenToolCallBatch = true;
const serializedToolCalls = JSON.stringify(aggregatedToolCalls);
if (!serializedToolCalls) {
return;
}
hasReceivedData = true;
if (!abortSignal?.aborted) {
onToolCallChunk?.(serializedToolCalls);
}
};
try {
let chunk = '';
@@ -325,6 +396,7 @@ export class ChatService {
const content = parsed.choices[0]?.delta?.content;
const reasoningContent = parsed.choices[0]?.delta?.reasoning_content;
const toolCalls = parsed.choices[0]?.delta?.tool_calls;
const timings = parsed.timings;
const promptProgress = parsed.prompt_progress;
@@ -342,6 +414,7 @@ export class ChatService {
}
if (content) {
finalizeOpenToolCallBatch();
hasReceivedData = true;
aggregatedContent += content;
if (!abortSignal?.aborted) {
@@ -350,12 +423,15 @@ export class ChatService {
}
if (reasoningContent) {
finalizeOpenToolCallBatch();
hasReceivedData = true;
fullReasoningContent += reasoningContent;
if (!abortSignal?.aborted) {
onReasoningChunk?.(reasoningContent);
}
}
processToolCallDelta(toolCalls);
} catch (e) {
console.error('Error parsing JSON chunk:', e);
}
@@ -368,12 +444,26 @@ export class ChatService {
if (abortSignal?.aborted) return;
if (streamFinished) {
if (!hasReceivedData && aggregatedContent.length === 0) {
finalizeOpenToolCallBatch();
if (
!hasReceivedData &&
aggregatedContent.length === 0 &&
aggregatedToolCalls.length === 0
) {
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
const finalToolCalls =
aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined;
onComplete?.(
aggregatedContent,
fullReasoningContent || undefined,
lastTimings,
finalToolCalls
);
}
} catch (error) {
const err = error instanceof Error ? error : new Error('Stream error');
@@ -386,6 +476,54 @@ export class ChatService {
}
}
private mergeToolCallDeltas(
existing: ApiChatCompletionToolCall[],
deltas: ApiChatCompletionToolCallDelta[],
indexOffset = 0
): ApiChatCompletionToolCall[] {
const result = existing.map((call) => ({
...call,
function: call.function ? { ...call.function } : undefined
}));
for (const delta of deltas) {
const index =
typeof delta.index === 'number' && delta.index >= 0
? delta.index + indexOffset
: result.length;
while (result.length <= index) {
result.push({ function: undefined });
}
const target = result[index]!;
if (delta.id) {
target.id = delta.id;
}
if (delta.type) {
target.type = delta.type;
}
if (delta.function) {
const fn = target.function ? { ...target.function } : {};
if (delta.function.name) {
fn.name = delta.function.name;
}
if (delta.function.arguments) {
fn.arguments = (fn.arguments ?? '') + delta.function.arguments;
}
target.function = fn;
}
}
return result;
}
/**
* Handles non-streaming response from the chat completion API.
* Parses the JSON response and extracts the generated content.
@@ -401,9 +539,11 @@ export class ChatService {
onComplete?: (
response: string,
reasoningContent?: string,
timings?: ChatMessageTimings
timings?: ChatMessageTimings,
toolCalls?: string
) => void,
onError?: (error: Error) => void,
onToolCallChunk?: (chunk: string) => void,
onModel?: (model: string) => void
): Promise<string> {
try {
@@ -423,17 +563,31 @@ export class ChatService {
const content = data.choices[0]?.message?.content || '';
const reasoningContent = data.choices[0]?.message?.reasoning_content;
const toolCalls = data.choices[0]?.message?.tool_calls;
if (reasoningContent) {
console.log('Full reasoning content:', reasoningContent);
}
if (!content.trim()) {
let serializedToolCalls: string | undefined;
if (toolCalls && toolCalls.length > 0) {
const mergedToolCalls = this.mergeToolCallDeltas([], toolCalls);
if (mergedToolCalls.length > 0) {
serializedToolCalls = JSON.stringify(mergedToolCalls);
if (serializedToolCalls) {
onToolCallChunk?.(serializedToolCalls);
}
}
}
if (!content.trim() && !serializedToolCalls) {
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
onComplete?.(content, reasoningContent);
onComplete?.(content, reasoningContent, undefined, serializedToolCalls);
return content;
} catch (error) {
@@ -205,6 +205,7 @@ class ChatStore {
type,
timestamp: Date.now(),
thinking: '',
toolCalls: '',
children: [],
extra: extras
},
@@ -360,6 +361,7 @@ class ChatStore {
): Promise<void> {
let streamedContent = '';
let streamedReasoningContent = '';
let streamedToolCallContent = '';
let resolvedModel: string | null = null;
let modelPersisted = false;
@@ -468,6 +470,20 @@ class ChatStore {
this.updateMessageAtIndex(messageIndex, { thinking: streamedReasoningContent });
},
onToolCallChunk: (toolCallChunk: string) => {
const chunk = toolCallChunk.trim();
if (!chunk) {
return;
}
streamedToolCallContent = chunk;
const messageIndex = this.findMessageIndex(assistantMessage.id);
this.updateMessageAtIndex(messageIndex, { toolCalls: streamedToolCallContent });
},
onModel: (modelName: string) => {
recordModel(modelName);
},
@@ -475,18 +491,21 @@ class ChatStore {
onComplete: async (
finalContent?: string,
reasoningContent?: string,
timings?: ChatMessageTimings
timings?: ChatMessageTimings,
toolCallContent?: string
) => {
slotsService.stopStreaming();
const updateData: {
content: string;
thinking: string;
toolCalls: string;
timings?: ChatMessageTimings;
model?: string;
} = {
content: finalContent || streamedContent,
thinking: reasoningContent || streamedReasoningContent,
toolCalls: toolCallContent || streamedToolCallContent,
timings: timings
};
@@ -499,7 +518,11 @@ class ChatStore {
const messageIndex = this.findMessageIndex(assistantMessage.id);
const localUpdateData: { timings?: ChatMessageTimings; model?: string } = {
const localUpdateData: {
timings?: ChatMessageTimings;
model?: string;
toolCalls?: string;
} = {
timings: timings
};
@@ -507,6 +530,10 @@ class ChatStore {
localUpdateData.model = updateData.model;
}
if (updateData.toolCalls !== undefined) {
localUpdateData.toolCalls = updateData.toolCalls;
}
this.updateMessageAtIndex(messageIndex, localUpdateData);
await DatabaseStore.updateCurrentNode(assistantMessage.convId, assistantMessage.id);
@@ -620,6 +647,7 @@ class ChatStore {
content: '',
timestamp: Date.now(),
thinking: '',
toolCalls: '',
children: [],
model: null
},
@@ -1443,6 +1471,7 @@ class ChatStore {
role: messageToEdit.role,
content: newContent,
thinking: messageToEdit.thinking || '',
toolCalls: messageToEdit.toolCalls || '',
children: [],
model: messageToEdit.model // Preserve original model info when branching
},
@@ -1518,6 +1547,7 @@ class ChatStore {
role: messageToEdit.role,
content: newContent,
thinking: messageToEdit.thinking || '',
toolCalls: messageToEdit.toolCalls || '',
children: [],
extra: messageToEdit.extra ? JSON.parse(JSON.stringify(messageToEdit.extra)) : undefined,
model: messageToEdit.model // Preserve original model info when branching
@@ -1589,6 +1619,7 @@ class ChatStore {
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
@@ -1647,6 +1678,7 @@ class ChatStore {
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
@@ -114,6 +114,7 @@ export class DatabaseStore {
...message,
id: uuid(),
parent: parentId,
toolCalls: message.toolCalls ?? '',
children: []
};
@@ -154,6 +155,7 @@ export class DatabaseStore {
content: '',
parent: null,
thinking: '',
toolCalls: '',
children: []
};
+19
View File
@@ -183,6 +183,23 @@ export interface ApiChatCompletionRequest {
samplers?: string[];
// Custom parameters (JSON string)
custom?: Record<string, unknown>;
timings_per_token?: boolean;
}
export interface ApiChatCompletionToolCallFunctionDelta {
name?: string;
arguments?: string;
}
export interface ApiChatCompletionToolCallDelta {
index?: number;
id?: string;
type?: string;
function?: ApiChatCompletionToolCallFunctionDelta;
}
export interface ApiChatCompletionToolCall extends ApiChatCompletionToolCallDelta {
function?: ApiChatCompletionToolCallFunctionDelta & { arguments?: string };
}
export interface ApiChatCompletionStreamChunk {
@@ -195,6 +212,7 @@ export interface ApiChatCompletionStreamChunk {
content?: string;
reasoning_content?: string;
model?: string;
tool_calls?: ApiChatCompletionToolCallDelta[];
};
}>;
timings?: {
@@ -216,6 +234,7 @@ export interface ApiChatCompletionResponse {
content: string;
reasoning_content?: string;
model?: string;
tool_calls?: ApiChatCompletionToolCallDelta[];
};
}>;
}
+1
View File
@@ -60,6 +60,7 @@ export interface DatabaseMessage {
content: string;
parent: string;
thinking: string;
toolCalls?: string;
children: string[];
extra?: DatabaseMessageExtra[];
timings?: ChatMessageTimings;
+8 -1
View File
@@ -38,12 +38,19 @@ export interface SettingsChatServiceOptions {
samplers?: string | string[];
// Custom parameters
custom?: string;
timings_per_token?: boolean;
// Callbacks
onChunk?: (chunk: string) => void;
onReasoningChunk?: (chunk: string) => void;
onToolCallChunk?: (chunk: string) => void;
onModel?: (model: string) => void;
onFirstValidChunk?: () => void;
onComplete?: (response: string, reasoningContent?: string, timings?: ChatMessageTimings) => void;
onComplete?: (
response: string,
reasoningContent?: string,
timings?: ChatMessageTimings,
toolCalls?: string
) => void;
onError?: (error: Error) => void;
}