mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 17:47:40 +02:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 96fe9badfc | |||
| bde188d60f | |||
| 9d0229967a | |||
| c4c10bfb86 | |||
| 817d743cc1 | |||
| bd4ef13476 | |||
| 87a2084c45 | |||
| 3659aa28e9 | |||
| 2a73f81f8a | |||
| 7dba049b07 | |||
| 83c1171529 | |||
| 0d1324856f | |||
| a67ef0f47f | |||
| ef75a89fdb | |||
| d8b5cdc4fe |
+22
-22
@@ -1602,33 +1602,33 @@ jobs:
|
||||
run: |
|
||||
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
# ggml-ci-x64-amd-vulkan:
|
||||
# runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# vulkaninfo --summary
|
||||
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
# ggml-ci-x64-amd-rocm:
|
||||
# runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# amd-smi static
|
||||
# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-metal:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
+6
-5
@@ -72,6 +72,12 @@ if (MSVC)
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
|
||||
endif()
|
||||
|
||||
if (LLAMA_STANDALONE)
|
||||
# enable parallel builds for msbuild
|
||||
list(APPEND CMAKE_VS_GLOBALS UseMultiToolTask=true)
|
||||
list(APPEND CMAKE_VS_GLOBALS EnforceProcessCountAcrossBuilds=true)
|
||||
endif()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
|
||||
set(LLAMA_TOOLS_INSTALL_DEFAULT OFF)
|
||||
else()
|
||||
@@ -193,11 +199,6 @@ if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
|
||||
# ... otherwise assume ggml is added by a parent CMakeLists.txt
|
||||
endif()
|
||||
|
||||
if (MINGW)
|
||||
# Target Windows 8 for PrefetchVirtualMemory
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
#
|
||||
# build the library
|
||||
#
|
||||
|
||||
+5
-21
@@ -39,26 +39,10 @@ if(Git_FOUND)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(MSVC)
|
||||
set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
|
||||
if (CMAKE_VS_PLATFORM_NAME)
|
||||
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
|
||||
else()
|
||||
set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
endif()
|
||||
else()
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_C_COMPILER} --version
|
||||
OUTPUT_VARIABLE OUT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
string(REGEX REPLACE " *\n.*" "" OUT "${OUT}")
|
||||
set(BUILD_COMPILER ${OUT})
|
||||
set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
|
||||
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_C_COMPILER} -dumpmachine
|
||||
OUTPUT_VARIABLE OUT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
set(BUILD_TARGET ${OUT})
|
||||
if(CMAKE_VS_PLATFORM_NAME)
|
||||
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
|
||||
else()
|
||||
set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
endif()
|
||||
|
||||
+1
-1
@@ -427,7 +427,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER) {
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
|
||||
|
||||
+20
-2
@@ -786,11 +786,29 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
#include <iostream>
|
||||
|
||||
|
||||
#ifdef _WIN32
|
||||
static std::wstring utf8_to_wstring(const std::string & str) {
|
||||
if (str.empty()) {
|
||||
return std::wstring();
|
||||
}
|
||||
|
||||
int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
|
||||
|
||||
if (size <= 0) {
|
||||
return std::wstring();
|
||||
}
|
||||
|
||||
std::wstring wstr(size, 0);
|
||||
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
|
||||
|
||||
return wstr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// returns true if successful, false otherwise
|
||||
bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#ifdef _WIN32
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
std::wstring wpath = converter.from_bytes(path);
|
||||
std::wstring wpath = utf8_to_wstring(path);
|
||||
|
||||
// if the path already exists, check whether it's a directory
|
||||
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
||||
|
||||
@@ -12,6 +12,10 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||
#define _WIN32_WINNT 0x0A00
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
#else
|
||||
|
||||
+16
-4
@@ -2341,9 +2341,18 @@ class LlamaModel(TextModel):
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(False)
|
||||
|
||||
template_dir = Path(__file__).parent / "models/templates/"
|
||||
local_template_file_path = self.dir_model / "chat_template.jinja"
|
||||
|
||||
if self.is_mistral_format and local_template_file_path.is_file():
|
||||
# Ministral-3 and other new Mistral models come with chat templates.
|
||||
# ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
|
||||
logger.info("Using an existing Mistral local chat template.")
|
||||
|
||||
with open(local_template_file_path, "r", encoding="utf-8") as f:
|
||||
template = f.read()
|
||||
elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
|
||||
template_dir = Path(__file__).parent / "models/templates/"
|
||||
|
||||
if not self.is_mistral_format or not self.disable_mistral_community_chat_template:
|
||||
# Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
|
||||
if self.is_mistral_format:
|
||||
logger.info(
|
||||
@@ -2351,9 +2360,12 @@ class LlamaModel(TextModel):
|
||||
"Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
|
||||
)
|
||||
template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
|
||||
self.gguf_writer.add_chat_template(template)
|
||||
else:
|
||||
logger.info("Not using a Mistral community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
|
||||
logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
|
||||
template = None
|
||||
|
||||
if template is not None:
|
||||
self.gguf_writer.add_chat_template(template)
|
||||
|
||||
def set_vocab(self):
|
||||
if self.is_mistral_format:
|
||||
|
||||
@@ -18,6 +18,7 @@ cd llama.cpp
|
||||
cmake -S . -B build
|
||||
cmake --build build
|
||||
cmake --install build --prefix inst
|
||||
```
|
||||
|
||||
### Build simple-cmake-pkg
|
||||
|
||||
|
||||
@@ -175,11 +175,6 @@ option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requi
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
|
||||
|
||||
|
||||
if (MINGW)
|
||||
set(GGML_WIN_VER "0xA00" CACHE STRING "ggml: Windows version")
|
||||
endif()
|
||||
|
||||
# ggml core
|
||||
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
|
||||
option(GGML_CPU "ggml: enable CPU backend" ON)
|
||||
|
||||
@@ -204,6 +204,10 @@
|
||||
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||
# define _WIN32_WINNT 0x0A00
|
||||
#endif
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
@@ -127,10 +127,6 @@ if (NOT MSVC)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (MINGW)
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
#
|
||||
# POSIX conformance
|
||||
#
|
||||
|
||||
@@ -505,7 +505,6 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
@@ -645,7 +644,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
|
||||
@@ -463,6 +463,53 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template<typename T, int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
|
||||
const int lane_id = threadIdx.x % width;
|
||||
#pragma unroll
|
||||
for (int offset = 1; offset < width; offset <<= 1) {
|
||||
const T t = __shfl_up_sync(0xffffffff, x, offset, width);
|
||||
if (lane_id >= offset) {
|
||||
x += t;
|
||||
}
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
|
||||
const int lane_id = threadIdx.x % width;
|
||||
#pragma unroll
|
||||
for (int offset = 1; offset < width; offset <<= 1) {
|
||||
const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
|
||||
const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
|
||||
if (lane_id >= offset) {
|
||||
a.x += t_x;
|
||||
a.y += t_y;
|
||||
}
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
const int lane_id = threadIdx.x % width;
|
||||
#pragma unroll
|
||||
for (int offset = 1; offset < width; offset <<= 1) {
|
||||
const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
|
||||
if (lane_id >= offset) {
|
||||
a = __hadd2(a, t);
|
||||
}
|
||||
}
|
||||
return a;
|
||||
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
return a;
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
#include <algorithm>
|
||||
#include "cumsum.cuh"
|
||||
#include "convert.cuh"
|
||||
#include "ggml-cuda/common.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/device/device_scan.cuh>
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
template<typename T, int BLOCK_SIZE>
|
||||
static __global__ void cumsum_cub_kernel(
|
||||
const T * __restrict__ src,
|
||||
T * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
|
||||
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
__shared__ T block_carry; // carry from previous tile
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const int64_t i1 = blockIdx.x;
|
||||
const int64_t i2 = blockIdx.y;
|
||||
const int64_t i3 = blockIdx.z;
|
||||
|
||||
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
|
||||
return;
|
||||
}
|
||||
|
||||
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
|
||||
if (tid == 0) {
|
||||
block_carry = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
|
||||
int64_t idx = start + tid;
|
||||
T x = (idx < ne00) ? src_row[idx] : T(0);
|
||||
|
||||
T inclusive;
|
||||
T block_total;
|
||||
BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
T final_val = inclusive + block_carry;
|
||||
|
||||
// store result
|
||||
if (idx < ne00) {
|
||||
dst_row[idx] = final_val;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) {
|
||||
block_carry += block_total;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
}
|
||||
|
||||
// Fallback kernel implementation (original)
|
||||
template<typename T>
|
||||
static __global__ void cumsum_kernel(
|
||||
const T * src, T * dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
|
||||
GGML_UNUSED_VARS(s00, s0);
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int lane = tid % warp_size;
|
||||
const int warp = tid / warp_size;
|
||||
const int warps_per_block = blockDim.x / warp_size;
|
||||
|
||||
extern __shared__ float smem[];
|
||||
float * s_vals = smem;
|
||||
float * s_warp_sums = smem + blockDim.x;
|
||||
float * s_carry = smem + blockDim.x + warps_per_block;
|
||||
float * s_chunk_total = s_carry + 1;
|
||||
|
||||
// Initialize carry
|
||||
if (tid == 0) {
|
||||
*s_carry = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const int64_t i3 = blockIdx.z;
|
||||
const int64_t i2 = blockIdx.y;
|
||||
const int64_t i1 = blockIdx.x;
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
|
||||
for (int64_t start = 0; start < ne00; start += blockDim.x) {
|
||||
int64_t idx = start + tid;
|
||||
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
|
||||
|
||||
// 1. Warp inclusive scan
|
||||
val = warp_prefix_inclusive_sum<T, warp_size>(val);
|
||||
s_vals[tid] = val;
|
||||
|
||||
// Store warp total
|
||||
if (lane == warp_size - 1) {
|
||||
s_warp_sums[warp] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 2. Exclusive scan of warp sums (warp 0 only)
|
||||
if (warp == 0) {
|
||||
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
|
||||
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
|
||||
if (tid < warps_per_block) {
|
||||
s_warp_sums[tid] = inc - w; // exclusive sum
|
||||
}
|
||||
if (tid == warps_per_block - 1) {
|
||||
*s_chunk_total = inc; // total sum of this chunk
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float carry = *s_carry;
|
||||
float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
|
||||
if (idx < ne00) {
|
||||
dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Update carry for next chunk
|
||||
if (tid == 0) {
|
||||
*s_carry += *s_chunk_total;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void cumsum_cuda(
|
||||
const T * src, T * dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
|
||||
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const size_t type_size = sizeof(T);
|
||||
bool use_cub = false;
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
// Check if we can use CUB (data must be contiguous along innermost dimension)
|
||||
const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
|
||||
|
||||
if (is_contiguous) {
|
||||
use_cub = true;
|
||||
}
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
dim3 grid_dims(ne01, ne02, ne03);
|
||||
const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
|
||||
const int warp_size = info.warp_size;
|
||||
const int num_warps = (ne00 + warp_size - 1) / warp_size;
|
||||
int block_size = num_warps * warp_size;
|
||||
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
|
||||
dim3 block_dims(block_size, 1, 1);
|
||||
const int warps_per_block = block_size / warp_size;
|
||||
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
|
||||
|
||||
if (use_cub) {
|
||||
cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
|
||||
src, dst,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01 / type_size, nb02 / type_size, nb03 / type_size,
|
||||
nb1 / type_size, nb2 / type_size, nb3 / type_size
|
||||
);
|
||||
} else {
|
||||
cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
|
||||
src, dst,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
|
||||
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
switch(src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
cumsum_cuda(
|
||||
(const float *)src0->data, (float *)dst->data,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
stream
|
||||
);
|
||||
} break;
|
||||
// We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms
|
||||
/*case GGML_TYPE_F16:
|
||||
{
|
||||
cumsum_cuda(
|
||||
(const half *)src0->data, (half *)dst->data,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
stream
|
||||
);
|
||||
} break;
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
cumsum_cuda(
|
||||
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
stream
|
||||
);
|
||||
} break;*/
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_CUMSUM_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -54,6 +54,8 @@
|
||||
#include "ggml-cuda/set-rows.cuh"
|
||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||
#include "ggml-cuda/solve_tri.cuh"
|
||||
#include "ggml-cuda/tri.cuh"
|
||||
#include "ggml-cuda/cumsum.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -2701,6 +2703,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CUMSUM:
|
||||
ggml_cuda_op_cumsum(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
ggml_cuda_op_tri(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
ggml_cuda_op_rwkv_wkv6(ctx, dst);
|
||||
break;
|
||||
@@ -4609,6 +4617,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_TRI:
|
||||
return true;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
#include "common.cuh"
|
||||
#include "convert.cuh"
|
||||
#include "tri.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
template<typename T, bool prefix_keep, int add_to_split>
|
||||
static __global__ void tri_kernel(
|
||||
const T * src, T * dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
|
||||
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
|
||||
const int64_t i3 = blockIdx.z;
|
||||
const int64_t i2 = blockIdx.y;
|
||||
const int64_t i1 = blockIdx.x;
|
||||
const int64_t split_point = i1 + add_to_split;
|
||||
|
||||
GGML_UNUSED_VARS(nb00, nb0);
|
||||
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
|
||||
T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;
|
||||
|
||||
if constexpr (prefix_keep) {
|
||||
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
|
||||
dst_row[i0] = src_row[i0];
|
||||
}
|
||||
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
|
||||
dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
|
||||
dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
|
||||
}
|
||||
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
|
||||
dst_row[i0] = src_row[i0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void tri_cuda(
|
||||
const T * src, T * dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
|
||||
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
|
||||
const ggml_tri_type ttype,
|
||||
cudaStream_t stream) {
|
||||
|
||||
dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
|
||||
dim3 grid_dims(ne01, ne02, ne03);
|
||||
const size_t type_size = sizeof(T);
|
||||
|
||||
const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
|
||||
const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
|
||||
|
||||
if (prefix_keep) {
|
||||
if (add_to_split == 0) {
|
||||
tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
|
||||
src, dst,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
|
||||
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
|
||||
);
|
||||
} else { // only 0 and 1 supported
|
||||
tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
|
||||
src, dst,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
|
||||
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
|
||||
);
|
||||
}
|
||||
} else {
|
||||
if (add_to_split == 0) {
|
||||
tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
|
||||
src, dst,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
|
||||
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
|
||||
);
|
||||
} else {
|
||||
tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
|
||||
src, dst,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
|
||||
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
|
||||
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
|
||||
switch(src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
tri_cuda(
|
||||
(const float *)src0->data, (float *)dst->data,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
ttype, stream
|
||||
);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
tri_cuda(
|
||||
(const half *)src0->data, (half *)dst->data,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
ttype, stream
|
||||
);
|
||||
} break;
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
tri_cuda(
|
||||
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
ttype, stream
|
||||
);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_TRI_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
File diff suppressed because it is too large
Load Diff
@@ -35,20 +35,6 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
|
||||
ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
|
||||
void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg);
|
||||
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0);
|
||||
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1);
|
||||
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem);
|
||||
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
// a collection of pipelines
|
||||
typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
|
||||
|
||||
@@ -58,6 +44,19 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
|
||||
void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
|
||||
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
|
||||
|
||||
struct ggml_metal_pipeline_with_params {
|
||||
ggml_metal_pipeline_t pipeline;
|
||||
|
||||
int nsg;
|
||||
|
||||
int nr0;
|
||||
int nr1;
|
||||
|
||||
size_t smem;
|
||||
};
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
|
||||
|
||||
//
|
||||
// MTLCommandBuffer wrapper
|
||||
//
|
||||
@@ -76,7 +75,7 @@ void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
|
||||
void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
|
||||
void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
|
||||
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline);
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline);
|
||||
|
||||
void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
|
||||
void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
|
||||
@@ -100,66 +99,67 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||
|
||||
void ggml_metal_library_free(ggml_metal_library_t lib);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
||||
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
int32_t ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t nqptg,
|
||||
int32_t ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
@@ -169,7 +169,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
bool has_kvpad,
|
||||
int32_t nsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
@@ -180,7 +180,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
int32_t nsg,
|
||||
int32_t nwg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t dv,
|
||||
|
||||
@@ -75,14 +75,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) {
|
||||
|
||||
struct ggml_metal_pipeline {
|
||||
id<MTLComputePipelineState> obj;
|
||||
|
||||
// suggested dispatch sizes
|
||||
int nsg;
|
||||
|
||||
int nr0;
|
||||
int nr1;
|
||||
|
||||
size_t smem;
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
|
||||
@@ -90,10 +82,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
|
||||
|
||||
*res = (struct ggml_metal_pipeline) {
|
||||
/*.obj =*/ nil,
|
||||
/*.nsg =*/ 0,
|
||||
/*.nr0 =*/ 0,
|
||||
/*.nr1 =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
};
|
||||
|
||||
return res;
|
||||
@@ -105,40 +93,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {
|
||||
free(pipeline);
|
||||
}
|
||||
|
||||
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) {
|
||||
pipeline->nsg = nsg;
|
||||
}
|
||||
|
||||
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->nsg;
|
||||
}
|
||||
|
||||
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) {
|
||||
pipeline->nr0 = nr0;
|
||||
}
|
||||
|
||||
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->nr0;
|
||||
}
|
||||
|
||||
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) {
|
||||
pipeline->nr1 = nr1;
|
||||
}
|
||||
|
||||
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->nr1;
|
||||
}
|
||||
|
||||
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) {
|
||||
pipeline->smem = smem;
|
||||
}
|
||||
|
||||
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->smem;
|
||||
}
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->obj.maxTotalThreadsPerThreadgroup;
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) {
|
||||
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
|
||||
}
|
||||
|
||||
struct ggml_metal_library {
|
||||
@@ -389,28 +345,42 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||
free(lib);
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
||||
[lib->lock lock];
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
struct ggml_metal_pipeline_with_params res = {
|
||||
/*.pipeline =*/ nil,
|
||||
/*.nr0 =*/ 0,
|
||||
/*.nr1 =*/ 0,
|
||||
/*.nsg =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
};
|
||||
|
||||
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
||||
struct ggml_metal_pipeline_with_params res = {
|
||||
/*.pipeline =*/ nil,
|
||||
/*.nr0 =*/ 0,
|
||||
/*.nr1 =*/ 0,
|
||||
/*.nsg =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
};
|
||||
|
||||
[lib->lock lock];
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
if (res) {
|
||||
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
if (res.pipeline) {
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_pipeline_init();
|
||||
|
||||
@autoreleasepool {
|
||||
NSError * error = nil;
|
||||
|
||||
@@ -432,26 +402,43 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
|
||||
}
|
||||
|
||||
return nil;
|
||||
return res;
|
||||
}
|
||||
|
||||
res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
|
||||
id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
|
||||
|
||||
[mtl_function release];
|
||||
|
||||
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
|
||||
(int) res->obj.maxTotalThreadsPerThreadgroup,
|
||||
(int) res->obj.threadExecutionWidth);
|
||||
if (!obj) {
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name);
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name,
|
||||
(void *) obj,
|
||||
(int) obj.maxTotalThreadsPerThreadgroup,
|
||||
(int) obj.threadExecutionWidth);
|
||||
|
||||
if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) {
|
||||
[obj release];
|
||||
|
||||
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
|
||||
|
||||
return nil;
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipelines_add(lib->pipelines, name, res);
|
||||
res.pipeline = ggml_metal_pipeline_init();
|
||||
res.pipeline->obj = obj;
|
||||
|
||||
ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline);
|
||||
}
|
||||
|
||||
[lib->lock unlock];
|
||||
@@ -496,8 +483,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) {
|
||||
[encoder->obj popDebugGroup];
|
||||
}
|
||||
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) {
|
||||
[encoder->obj setComputePipelineState:pipeline->obj];
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) {
|
||||
[encoder->obj setComputePipelineState:pipeline.pipeline->obj];
|
||||
}
|
||||
|
||||
void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) {
|
||||
@@ -622,8 +609,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
|
||||
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
|
||||
dev->props.has_tensor = false;
|
||||
} else {
|
||||
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
|
||||
if (!ppl) {
|
||||
struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
|
||||
if (!ppl.pipeline) {
|
||||
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
|
||||
dev->props.has_tensor = false;
|
||||
}
|
||||
@@ -672,8 +659,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
|
||||
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
|
||||
dev->props.has_bfloat = false;
|
||||
} else {
|
||||
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
|
||||
if (!ppl) {
|
||||
struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
|
||||
if (!ppl.pipeline) {
|
||||
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
|
||||
dev->props.has_bfloat = false;
|
||||
}
|
||||
@@ -831,6 +818,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
@@ -863,6 +852,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
@@ -880,6 +870,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_TRI:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_MEAN:
|
||||
|
||||
@@ -182,6 +182,10 @@ typedef struct {
|
||||
float bias;
|
||||
} ggml_metal_kargs_scale;
|
||||
|
||||
typedef struct {
|
||||
float val;
|
||||
} ggml_metal_kargs_fill;
|
||||
|
||||
typedef struct {
|
||||
float min;
|
||||
float max;
|
||||
@@ -831,6 +835,25 @@ typedef struct {
|
||||
float slope;
|
||||
} ggml_metal_kargs_leaky_relu;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_tri;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
|
||||
@@ -286,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_scale(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
n_fuse = ggml_metal_op_fill(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_CLAMP:
|
||||
{
|
||||
n_fuse = ggml_metal_op_clamp(ctx, idx);
|
||||
@@ -414,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tri(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
||||
@@ -524,7 +532,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
||||
/*.dim =*/ dim,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -550,7 +558,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
||||
|
||||
ggml_metal_kargs_repeat args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -616,7 +624,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
// TODO: make a simpler cpy_bytes kernel
|
||||
|
||||
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ ne00,
|
||||
@@ -679,7 +687,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
/*.o1 =*/ { 0 },
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -721,7 +729,42 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const float val = ggml_get_op_params_f32(op, 0);
|
||||
|
||||
ggml_metal_kargs_fill args = {
|
||||
/*.val =*/ val
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -760,7 +803,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -789,7 +832,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
|
||||
@@ -817,7 +860,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
|
||||
|
||||
const int32_t swp = ggml_get_op_params_i32(op, 1);
|
||||
const float alpha = ggml_get_op_params_f32(op, 2);
|
||||
@@ -870,7 +913,7 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
||||
/*.np =*/ n,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
@@ -925,7 +968,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
@@ -936,7 +979,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -963,7 +1006,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
||||
auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
||||
|
||||
int nth = 1;
|
||||
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
|
||||
@@ -1060,7 +1103,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
||||
auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
||||
|
||||
ggml_metal_kargs_cumsum_add args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1106,7 +1149,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
||||
|
||||
ggml_metal_kargs_get_rows args = {
|
||||
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
|
||||
@@ -1151,7 +1194,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
||||
|
||||
const int32_t nk0 = ne0/ggml_blck_size(op->type);
|
||||
|
||||
@@ -1252,7 +1295,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
/*.n_head_log2 =*/ n_head_log2,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
@@ -1266,7 +1309,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
}
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||
@@ -1322,7 +1365,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb2 =*/ nb2,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||
@@ -1409,11 +1452,11 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb0 =*/ nb0,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
||||
|
||||
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -1426,7 +1469,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
||||
|
||||
@@ -1449,7 +1492,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
||||
const int64_t C = op->ne[0];
|
||||
const int64_t H = op->src[0]->ne[1];
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
|
||||
|
||||
int ida = 0;
|
||||
|
||||
@@ -1485,7 +1528,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
|
||||
|
||||
@@ -1592,7 +1635,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
||||
/* .np = */ np
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
|
||||
const int ntg = (np + nth - 1) / nth;
|
||||
@@ -1701,7 +1744,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ABORT("unsupported ne11");
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
|
||||
|
||||
ggml_metal_kargs_mul_mv_ext args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1748,7 +1791,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
// default: break;
|
||||
//}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1773,18 +1816,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
|
||||
} else {
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
||||
|
||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||
const int nr0 = pipeline.nr0;
|
||||
const int nr1 = pipeline.nr1;
|
||||
const int nsg = pipeline.nsg;
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_kargs_mul_mv args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1915,9 +1958,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
nb21,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
@@ -1938,7 +1981,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1967,20 +2010,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
|
||||
}
|
||||
} else {
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
||||
|
||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||
const int nr0 = pipeline.nr0;
|
||||
const int nr1 = pipeline.nr1;
|
||||
const int nsg = pipeline.nsg;
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_kargs_mul_mv_id args = {
|
||||
/*.nei0 =*/ ne20,
|
||||
@@ -2064,7 +2107,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb21 =*/ nb21,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -2308,7 +2351,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
@@ -2339,7 +2382,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb33 =*/ nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
||||
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
@@ -2424,7 +2467,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -2476,7 +2519,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
@@ -2578,7 +2621,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
||||
|
||||
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
@@ -2630,7 +2673,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
nrows,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
|
||||
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
@@ -2762,7 +2805,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
// the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
|
||||
bid_src1.offs = 0;
|
||||
|
||||
ggml_metal_pipeline_t pipeline = nullptr;
|
||||
struct ggml_metal_pipeline_with_params pipeline;
|
||||
|
||||
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
@@ -2835,7 +2878,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
||||
|
||||
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
@@ -2844,7 +2887,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00/4);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
||||
|
||||
@@ -2887,7 +2930,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
//while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
@@ -2897,7 +2940,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
//nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
//nth = std::min(nth, ne00/4);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3022,7 +3065,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
@@ -3033,7 +3076,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, args.ne00_t);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3127,7 +3170,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||
/* src2 =*/ op->src[2] != nullptr,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3199,7 +3242,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||
/*.KHW =*/ KH * KW,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
|
||||
|
||||
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
@@ -3270,7 +3313,7 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
|
||||
/*.d1 =*/ d1,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
|
||||
|
||||
int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
|
||||
nth = std::min(nth, 256);
|
||||
@@ -3325,7 +3368,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb1 =*/ nb1,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3377,7 +3420,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb2 =*/ nb2,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3433,7 +3476,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
/*.sf3 =*/ sf3
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
|
||||
@@ -3477,7 +3520,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb3 =*/ nb3
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
|
||||
|
||||
const int nth = std::min(1024, ne0);
|
||||
|
||||
@@ -3523,7 +3566,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
||||
/*.p1 =*/ ((const int32_t *)(op->op_params))[1]
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
|
||||
|
||||
const int nth = std::min(1024, ne0);
|
||||
|
||||
@@ -3560,7 +3603,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
const int nth = std::min(1024, ne0);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3591,7 +3634,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
||||
/*.max_period =*/ max_period,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
|
||||
|
||||
const int nth = std::max(1, std::min(1024, dim/2));
|
||||
|
||||
@@ -3621,7 +3664,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb01 = */ nb01,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
|
||||
|
||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
||||
|
||||
@@ -3630,7 +3673,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3657,7 +3700,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
||||
|
||||
// bitonic sort requires the number of elements to be power of 2
|
||||
int nth = 1;
|
||||
@@ -3706,7 +3749,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
||||
auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
||||
|
||||
int len = nth;
|
||||
|
||||
@@ -3764,7 +3807,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
|
||||
|
||||
// bitonic sort requires the number of elements to be power of 2
|
||||
int nth = 1;
|
||||
@@ -3818,7 +3861,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
|
||||
auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
|
||||
|
||||
int len = args.top_k;
|
||||
|
||||
@@ -3881,7 +3924,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
/*.slope =*/ slope
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
@@ -3899,6 +3942,57 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_tri args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -3910,7 +4004,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_adamw args = {
|
||||
@@ -3946,7 +4040,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_sgd args = {
|
||||
|
||||
@@ -47,6 +47,7 @@ int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
@@ -83,6 +84,7 @@ int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
|
||||
@@ -1249,6 +1249,22 @@ kernel void kernel_scale_f32_4(
|
||||
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
||||
}
|
||||
|
||||
kernel void kernel_fill_f32(
|
||||
constant ggml_metal_kargs_fill & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
}
|
||||
|
||||
kernel void kernel_fill_f32_4(
|
||||
constant ggml_metal_kargs_fill & args,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
}
|
||||
|
||||
kernel void kernel_clamp_f32(
|
||||
constant ggml_metal_kargs_clamp & args,
|
||||
device const float * src0,
|
||||
@@ -1595,6 +1611,36 @@ kernel void kernel_exp_f32_4(
|
||||
dst[tpig] = exp(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_reglu_f32(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
@@ -1943,6 +1989,75 @@ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
|
||||
|
||||
|
||||
template<uint32_t ttype>
|
||||
bool _ggml_vec_tri_cmp(const int i, const int r);
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
|
||||
return i < r;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
|
||||
return i <= r;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
|
||||
return i > r;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
|
||||
return i >= r;
|
||||
}
|
||||
|
||||
template<typename T, int ttype>
|
||||
kernel void kernel_tri(
|
||||
constant ggml_metal_kargs_tri & args,
|
||||
device const char * src0,
|
||||
device const char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
// Each thread is a single element of the row if ne00 < max threads per
|
||||
// threadgroup, so this will loop once for each index that this thread is
|
||||
// responsible for
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
// Use the comparison as a mask for branchless
|
||||
dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
|
||||
|
||||
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
|
||||
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
|
||||
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
|
||||
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
|
||||
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
|
||||
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
|
||||
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
|
||||
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
|
||||
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
|
||||
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
|
||||
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
|
||||
+5
-7
@@ -726,21 +726,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// sanity checks for models that have attention layers
|
||||
if (qs.n_attention_wv != 0 && !is_clip_model)
|
||||
{
|
||||
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
|
||||
// attention layers have a non-zero number of kv heads
|
||||
int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
||||
int32_t n_layer_all = model.hparams.n_layer;
|
||||
if (llama_model_has_encoder(&model)) {
|
||||
// now n_layer_attn is the number of attention layers in the encoder
|
||||
// now n_layer_all is the number of attention layers in the encoder
|
||||
// for each decoder block, there are 2 attention layers
|
||||
n_layer_attn += 2 * model.hparams.dec_n_layer;
|
||||
n_layer_all += 2 * model.hparams.dec_n_layer;
|
||||
}
|
||||
|
||||
// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
|
||||
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w);
|
||||
LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w);
|
||||
|
||||
GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
|
||||
GGML_ASSERT((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
|
||||
}
|
||||
|
||||
size_t total_size_org = 0;
|
||||
|
||||
@@ -7725,6 +7725,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
|
||||
@@ -7954,6 +7955,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
|
||||
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
|
||||
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 }));
|
||||
|
||||
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
|
||||
for (ggml_type type_a : all_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
|
||||
@@ -2,11 +2,6 @@ set(TARGET llama-server)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
if (MINGW)
|
||||
# fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
if (NOT LLAMA_HTTPLIB)
|
||||
message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF")
|
||||
endif()
|
||||
|
||||
@@ -101,8 +101,6 @@ struct server_slot {
|
||||
std::string generated_text;
|
||||
llama_tokens generated_tokens;
|
||||
|
||||
common_chat_msg chat_msg;
|
||||
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
|
||||
bool has_next_token = true;
|
||||
@@ -153,9 +151,6 @@ struct server_slot {
|
||||
|
||||
llama_token sampled;
|
||||
|
||||
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
|
||||
@@ -183,13 +178,10 @@ struct server_slot {
|
||||
stop = STOP_TYPE_NONE;
|
||||
stopping_word = "";
|
||||
n_sent_text = 0;
|
||||
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
generated_tokens.clear();
|
||||
generated_token_probs.clear();
|
||||
chat_msg = {};
|
||||
json_schema = json();
|
||||
generated_tool_call_ids.clear();
|
||||
|
||||
// clear speculative decoding stats
|
||||
n_draft_total = 0;
|
||||
@@ -302,23 +294,6 @@ struct server_slot {
|
||||
return timings;
|
||||
}
|
||||
|
||||
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
auto previous_msg = chat_msg;
|
||||
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
|
||||
auto new_msg = common_chat_parse(
|
||||
generated_text,
|
||||
/* is_partial= */ stop != STOP_TYPE_EOS,
|
||||
task->params.oaicompat_chat_syntax);
|
||||
if (!new_msg.empty()) {
|
||||
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
|
||||
chat_msg = new_msg;
|
||||
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
|
||||
}
|
||||
return chat_msg;
|
||||
}
|
||||
|
||||
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
@@ -1284,8 +1259,6 @@ struct server_context_impl {
|
||||
} else {
|
||||
res->content = tkn.text_to_send;
|
||||
res->tokens = { tkn.tok };
|
||||
|
||||
slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
}
|
||||
|
||||
res->n_decoded = slot.n_decoded;
|
||||
@@ -1317,8 +1290,14 @@ struct server_context_impl {
|
||||
res->id_slot = slot.id;
|
||||
|
||||
res->index = slot.task->index;
|
||||
res->content = slot.generated_text;
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
// in stream mode, content and tokens are already in last partial chunk
|
||||
if (slot.task->params.stream) {
|
||||
res->content = "";
|
||||
res->tokens = llama_tokens{};
|
||||
} else {
|
||||
res->content = std::move(slot.generated_text);
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
}
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = slot.task->tokens.detokenize(ctx, true);
|
||||
res->response_fields = std::move(slot.task->params.response_fields);
|
||||
@@ -1338,7 +1317,6 @@ struct server_context_impl {
|
||||
res->res_type = slot.task->params.res_type;
|
||||
res->oaicompat_model = slot.task->params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
|
||||
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
|
||||
// populate res.probs_output
|
||||
if (slot.task->params.sampling.n_probs > 0) {
|
||||
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
try {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
// tracking generation state and partial tool calls
|
||||
std::vector<task_result_state> states;
|
||||
|
||||
const auto & prompt = data.at("prompt");
|
||||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
tasks.reserve(inputs.size());
|
||||
states.reserve(inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
task.params.res_type = res_type;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.oaicompat_model = ctx_server.model_name;
|
||||
states.push_back(task.params.oaicompat_chat_syntax);
|
||||
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
rd.set_states(std::move(states));
|
||||
rd.post_tasks(std::move(tasks));
|
||||
} catch (const std::exception & e) {
|
||||
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
// if single request, return single object instead of array
|
||||
res->ok(arr.size() == 1 ? arr[0] : arr);
|
||||
}
|
||||
|
||||
} else {
|
||||
// in streaming mode, the first error must be treated as non-stream response
|
||||
// this is to match the OAI API behavior
|
||||
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
}
|
||||
|
||||
// next responses are streamed
|
||||
// to be sent immediately
|
||||
json first_result_json = first_result->to_json();
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
res->data = format_anthropic_sse(first_result->to_json());
|
||||
res->data = format_anthropic_sse(first_result_json);
|
||||
} else {
|
||||
res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
|
||||
res->data = format_oai_sse(first_result_json);
|
||||
}
|
||||
res->status = 200;
|
||||
res->content_type = "text/event-stream";
|
||||
res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
|
||||
if (should_stop()) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
if (!res_this->data.empty()) {
|
||||
// flush the first chunk
|
||||
output = std::move(res_this->data);
|
||||
res_this->data.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
server_response_reader & rd = res_this->rd;
|
||||
|
||||
// check if there is more data
|
||||
if (!rd.has_next()) {
|
||||
static auto format_error = [](task_response_type res_type, const json & res_json) {
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
// Anthropic doesn't send [DONE], message_stop was already sent
|
||||
output = "";
|
||||
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
|
||||
output = "data: [DONE]\n\n";
|
||||
} else {
|
||||
output = "";
|
||||
}
|
||||
SRV_DBG("%s", "all results received, terminating stream\n");
|
||||
return false; // no more data, terminate
|
||||
}
|
||||
|
||||
// receive subsequent results
|
||||
auto result = rd.next(should_stop);
|
||||
if (result == nullptr) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
// send the results
|
||||
json res_json = result->to_json();
|
||||
if (result->is_error()) {
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
output = format_anthropic_sse({
|
||||
return format_anthropic_sse({
|
||||
{"event", "error"},
|
||||
{"data", res_json},
|
||||
});
|
||||
} else {
|
||||
output = format_oai_sse(json {{ "error", res_json }});
|
||||
return format_oai_sse(json {{ "error", res_json }});
|
||||
}
|
||||
SRV_DBG("%s", "error received during streaming, terminating stream\n");
|
||||
return false; // terminate on error
|
||||
} else {
|
||||
GGML_ASSERT(
|
||||
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|
||||
);
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
output = format_anthropic_sse(res_json);
|
||||
} else {
|
||||
output = format_oai_sse(res_json);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// has next data, continue
|
||||
return true;
|
||||
try {
|
||||
if (should_stop()) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
if (!res_this->data.empty()) {
|
||||
// flush the first chunk
|
||||
output = std::move(res_this->data);
|
||||
res_this->data.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
server_response_reader & rd = res_this->rd;
|
||||
|
||||
// check if there is more data
|
||||
if (!rd.has_next()) {
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
// Anthropic doesn't send [DONE], message_stop was already sent
|
||||
output = "";
|
||||
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
|
||||
output = "data: [DONE]\n\n";
|
||||
} else {
|
||||
output = "";
|
||||
}
|
||||
SRV_DBG("%s", "all results received, terminating stream\n");
|
||||
return false; // no more data, terminate
|
||||
}
|
||||
|
||||
// receive subsequent results
|
||||
auto result = rd.next(should_stop);
|
||||
if (result == nullptr) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
// send the results
|
||||
if (result->is_error()) {
|
||||
json res_json = result->to_json();
|
||||
output = format_error(res_type, res_json);
|
||||
SRV_DBG("%s", "error received during streaming, terminating stream\n");
|
||||
return false; // terminate on error
|
||||
} else {
|
||||
GGML_ASSERT(
|
||||
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|
||||
);
|
||||
json res_json = result->to_json();
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
output = format_anthropic_sse(res_json);
|
||||
} else {
|
||||
output = format_oai_sse(res_json);
|
||||
}
|
||||
}
|
||||
|
||||
// has next data, continue
|
||||
return true;
|
||||
|
||||
} catch (const std::exception & e) {
|
||||
json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
|
||||
output = format_error(res_type, error_json);
|
||||
|
||||
// terminate on exception
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -900,6 +900,7 @@ static bool should_strip_proxy_header(const std::string & header_name) {
|
||||
// Headers that get duplicated when router forwards child responses
|
||||
if (header_name == "server" ||
|
||||
header_name == "transfer-encoding" ||
|
||||
header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
|
||||
header_name == "keep-alive") {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -271,6 +271,10 @@ void server_response::terminate() {
|
||||
// server_response_reader
|
||||
//
|
||||
|
||||
void server_response_reader::set_states(std::vector<task_result_state> && states) {
|
||||
this->states = std::move(states);
|
||||
}
|
||||
|
||||
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
|
||||
id_tasks = server_task::get_list_id(tasks);
|
||||
queue_results.add_waiting_tasks(tasks);
|
||||
@@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
|
||||
SRV_DBG("%s", "received error result, stopping further processing\n");
|
||||
return result;
|
||||
}
|
||||
if (!states.empty()) {
|
||||
// update the generation state if needed
|
||||
size_t idx = result->get_index();
|
||||
GGML_ASSERT(idx < states.size());
|
||||
result->update(states[idx]);
|
||||
}
|
||||
if (result->is_stop()) {
|
||||
received_count++;
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include <mutex>
|
||||
#include <unordered_set>
|
||||
|
||||
// struct for managing server tasks
|
||||
// in most cases, use server_response_reader to post new tasks and retrieve results
|
||||
struct server_queue {
|
||||
private:
|
||||
int id = 0;
|
||||
@@ -67,6 +69,8 @@ private:
|
||||
void cleanup_pending_task(int id_target);
|
||||
};
|
||||
|
||||
// struct for managing server responses
|
||||
// in most cases, use server_response_reader to retrieve results
|
||||
struct server_response {
|
||||
private:
|
||||
bool running = true;
|
||||
@@ -120,6 +124,10 @@ struct server_response_reader {
|
||||
bool cancelled = false;
|
||||
int polling_interval_seconds;
|
||||
|
||||
// tracking generation state and partial tool calls
|
||||
// only used by streaming completions
|
||||
std::vector<task_result_state> states;
|
||||
|
||||
// should_stop function will be called each polling_interval_seconds
|
||||
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
|
||||
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
|
||||
@@ -127,6 +135,7 @@ struct server_response_reader {
|
||||
stop();
|
||||
}
|
||||
|
||||
void set_states(std::vector<task_result_state> && states);
|
||||
void post_tasks(std::vector<server_task> && tasks);
|
||||
bool has_next() const;
|
||||
|
||||
|
||||
@@ -565,6 +565,7 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
|
||||
// server_task_result_cmpl_final
|
||||
//
|
||||
json server_task_result_cmpl_final::to_json() {
|
||||
GGML_ASSERT(is_updated && "update() must be called before to_json()");
|
||||
switch (res_type) {
|
||||
case TASK_RESPONSE_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
@@ -582,8 +583,8 @@ json server_task_result_cmpl_final::to_json() {
|
||||
json server_task_result_cmpl_final::to_json_non_oaicompat() {
|
||||
json res = json {
|
||||
{"index", index},
|
||||
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"tokens", stream ? llama_tokens {} : tokens},
|
||||
{"content", content},
|
||||
{"tokens", tokens},
|
||||
{"id_slot", id_slot},
|
||||
{"stop", true},
|
||||
{"model", oaicompat_model},
|
||||
@@ -619,7 +620,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() {
|
||||
json res = json {
|
||||
{"choices", json::array({
|
||||
json{
|
||||
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"text", content},
|
||||
{"index", index},
|
||||
{"logprobs", logprobs},
|
||||
{"finish_reason", finish_reason},
|
||||
@@ -700,6 +701,25 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
|
||||
return res;
|
||||
}
|
||||
|
||||
common_chat_msg task_result_state::update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs) {
|
||||
generated_text += text_added;
|
||||
auto msg_prv_copy = chat_msg;
|
||||
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
|
||||
auto new_msg = common_chat_parse(
|
||||
generated_text,
|
||||
is_partial,
|
||||
oaicompat_chat_syntax);
|
||||
if (!new_msg.empty()) {
|
||||
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
|
||||
chat_msg = new_msg;
|
||||
diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg);
|
||||
}
|
||||
return chat_msg;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
||||
std::time_t t = std::time(0);
|
||||
std::string finish_reason = "length";
|
||||
@@ -956,6 +976,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
|
||||
// server_task_result_cmpl_partial
|
||||
//
|
||||
json server_task_result_cmpl_partial::to_json() {
|
||||
GGML_ASSERT(is_updated && "update() must be called before to_json()");
|
||||
switch (res_type) {
|
||||
case TASK_RESPONSE_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
|
||||
@@ -161,6 +161,25 @@ struct result_prompt_progress {
|
||||
json to_json() const;
|
||||
};
|
||||
|
||||
// struct for tracking the state of a task (e.g., for streaming)
|
||||
struct task_result_state {
|
||||
// tracking diffs for partial tool calls
|
||||
std::vector<common_chat_msg_diff> diffs;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
common_chat_msg chat_msg;
|
||||
std::string generated_text; // append new chunks of generated text here
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
|
||||
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
|
||||
|
||||
// parse partial tool calls and update the internal state
|
||||
common_chat_msg update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs);
|
||||
};
|
||||
|
||||
struct server_task_result {
|
||||
int id = -1;
|
||||
int id_slot = -1;
|
||||
@@ -175,6 +194,9 @@ struct server_task_result {
|
||||
virtual int get_index() {
|
||||
return -1;
|
||||
}
|
||||
virtual void update(task_result_state &) {
|
||||
// only used by server_task_result_cmpl_*
|
||||
}
|
||||
virtual json to_json() = 0;
|
||||
virtual ~server_task_result() = default;
|
||||
};
|
||||
@@ -233,9 +255,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg;
|
||||
common_chat_msg oaicompat_msg; // to be populated by update()
|
||||
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||
bool is_updated = false;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
@@ -247,6 +270,11 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
|
||||
virtual json to_json() override;
|
||||
|
||||
virtual void update(task_result_state & state) override {
|
||||
is_updated = true;
|
||||
oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
@@ -280,7 +308,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||
bool is_updated = false;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
@@ -292,6 +321,11 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
|
||||
virtual json to_json() override;
|
||||
|
||||
virtual void update(task_result_state & state) override {
|
||||
is_updated = true;
|
||||
state.update_chat_msg(content, true, oaicompat_msg_diffs);
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
|
||||
Reference in New Issue
Block a user