mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 17:47:40 +02:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8303e8b0fb | |||
| 7ad0779f5d | |||
| f777a73e18 | |||
| af7747c95a | |||
| a28e0d5eb1 | |||
| 36c258ee92 | |||
| f3e64859ed | |||
| 5fa07c2f93 | |||
| 335eb04a91 | |||
| cf756d6e0a | |||
| d70908421f | |||
| de8b5a3624 | |||
| 51f311e057 | |||
| 586d5fe6eb | |||
| ecc8e3aeff | |||
| 0b3863ff95 |
@@ -173,7 +173,15 @@ jobs:
|
||||
name: llama-bin-macos-x64.zip
|
||||
|
||||
ubuntu-cpu-cmake:
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- build: 'x64'
|
||||
os: ubuntu-22.04
|
||||
- build: 'arm64'
|
||||
os: ubuntu-22.04-arm
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -239,14 +247,14 @@ jobs:
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip ./build/bin/*
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/*
|
||||
|
||||
- name: Upload artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip
|
||||
name: llama-bin-ubuntu-x64.zip
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip
|
||||
name: llama-bin-ubuntu-${{ matrix.build }}.zip
|
||||
|
||||
ubuntu-latest-cmake-sanitizer:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
- Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`)
|
||||
- If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends)
|
||||
- If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
|
||||
- Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR
|
||||
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
|
||||
- If your PR becomes stale, don't hesitate to ping the maintainers in the comments
|
||||
|
||||
|
||||
@@ -680,6 +680,10 @@ ifdef GGML_CUDA_CCBIN
|
||||
MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN)
|
||||
endif # GGML_CUDA_CCBIN
|
||||
|
||||
ifdef GGML_CUDA_NO_FA
|
||||
MK_NVCCFLAGS += -DGGML_CUDA_NO_FA
|
||||
endif # GGML_CUDA_NO_FA
|
||||
|
||||
ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
|
||||
endif # GGML_CUDA_FA_ALL_QUANTS
|
||||
@@ -800,6 +804,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
|
||||
HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
|
||||
endif # GGML_CUDA_NO_PEER_COPY
|
||||
|
||||
ifdef GGML_CUDA_NO_FA
|
||||
HIPFLAGS += -DGGML_CUDA_NO_FA
|
||||
endif # GGML_CUDA_NO_FA
|
||||
|
||||
OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o
|
||||
OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
|
||||
OBJ_GGML_EXT += $(OBJ_CUDA_TMPL)
|
||||
@@ -847,7 +855,7 @@ ifdef GGML_MUSA
|
||||
CXX := $(MUSA_PATH)/bin/clang++
|
||||
MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc
|
||||
|
||||
MUSAFLAGS = -x musa -mtgpu
|
||||
MUSAFLAGS = -fsigned-char -x musa -mtgpu
|
||||
MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch))
|
||||
|
||||
ifdef GGML_CUDA_FORCE_MMQ
|
||||
@@ -876,6 +884,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
|
||||
MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
|
||||
endif # GGML_CUDA_NO_PEER_COPY
|
||||
|
||||
ifdef GGML_CUDA_NO_FA
|
||||
MUSAFLAGS += -DGGML_CUDA_NO_FA
|
||||
endif # GGML_CUDA_NO_FA
|
||||
|
||||
ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
|
||||
endif # GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
@@ -206,6 +206,14 @@ This provides GPU acceleration using the MUSA cores of your Moore Threads MTT GP
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
For static build:
|
||||
|
||||
```bash
|
||||
cmake -B build -DGGML_MUSA=ON \
|
||||
-DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
The environment variable [`MUSA_VISIBLE_DEVICES`](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) can be used to specify which GPU(s) will be used.
|
||||
|
||||
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted.
|
||||
|
||||
@@ -124,15 +124,26 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
}.sheet(isPresented: $showingHelp) { // Sheet for help modal
|
||||
VStack(alignment: .leading) {
|
||||
NavigationView {
|
||||
VStack(alignment: .leading) {
|
||||
Text("1. Make sure the model is in GGUF Format")
|
||||
.padding()
|
||||
Text("2. Copy the download link of the quantized model")
|
||||
.padding()
|
||||
VStack(alignment: .leading) {
|
||||
Text("1. Make sure the model is in GGUF Format")
|
||||
.padding()
|
||||
Text("2. Copy the download link of the quantized model")
|
||||
.padding()
|
||||
}
|
||||
Spacer()
|
||||
}
|
||||
Spacer()
|
||||
}
|
||||
.navigationTitle("Help")
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .navigationBarTrailing) {
|
||||
Button("Done") {
|
||||
showingHelp = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1729,11 +1729,11 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) {
|
||||
}
|
||||
}
|
||||
|
||||
static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
|
||||
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
|
||||
img->nx = nx;
|
||||
img->ny = ny;
|
||||
img->buf.resize(3 * nx * ny);
|
||||
memcpy(img->buf.data(), data, img->buf.size());
|
||||
memcpy(img->buf.data(), rgb_pixels, img->buf.size());
|
||||
}
|
||||
|
||||
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
|
||||
@@ -1743,7 +1743,7 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
|
||||
LOG_ERR("%s: failed to load image '%s'\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
build_clip_img_from_data(data, nx, ny, img);
|
||||
clip_build_img_from_pixels(data, nx, ny, img);
|
||||
stbi_image_free(data);
|
||||
return true;
|
||||
}
|
||||
@@ -1755,7 +1755,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length
|
||||
LOG_ERR("%s: failed to decode image bytes\n", __func__);
|
||||
return false;
|
||||
}
|
||||
build_clip_img_from_data(data, nx, ny, img);
|
||||
clip_build_img_from_pixels(data, nx, ny, img);
|
||||
stbi_image_free(data);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -73,6 +73,9 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
|
||||
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
|
||||
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
|
||||
|
||||
/** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */
|
||||
CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img);
|
||||
|
||||
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
||||
|
||||
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
|
||||
|
||||
+47
-47
@@ -323,25 +323,17 @@ class File {
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string read_all(const std::string & filename){
|
||||
open(filename, "r");
|
||||
lock();
|
||||
if (!file) {
|
||||
printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string to_string() {
|
||||
fseek(file, 0, SEEK_END);
|
||||
size_t size = ftell(file);
|
||||
const size_t size = ftell(file);
|
||||
fseek(file, 0, SEEK_SET);
|
||||
|
||||
std::string out;
|
||||
out.resize(size);
|
||||
size_t read_size = fread(&out[0], 1, size, file);
|
||||
const size_t read_size = fread(&out[0], 1, size, file);
|
||||
if (read_size != size) {
|
||||
printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
|
||||
return "";
|
||||
printe("Error reading file: %s", strerror(errno));
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -985,7 +977,8 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||
}
|
||||
|
||||
static int read_user_input(std::string & user_input) {
|
||||
static const char * prompt_prefix = "> ";
|
||||
static const char * prompt_prefix_env = std::getenv("LLAMA_PROMPT_PREFIX");
|
||||
static const char * prompt_prefix = prompt_prefix_env ? prompt_prefix_env : "> ";
|
||||
#ifdef WIN32
|
||||
printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix);
|
||||
|
||||
@@ -1098,59 +1091,66 @@ static int get_user_input(std::string & user_input, const std::string & user) {
|
||||
|
||||
// Reads a chat template file to be used
|
||||
static std::string read_chat_template_file(const std::string & chat_template_file) {
|
||||
if(chat_template_file.empty()){
|
||||
return "";
|
||||
}
|
||||
|
||||
File file;
|
||||
std::string chat_template = "";
|
||||
chat_template = file.read_all(chat_template_file);
|
||||
if(chat_template.empty()){
|
||||
if (!file.open(chat_template_file, "r")) {
|
||||
printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
|
||||
return "";
|
||||
}
|
||||
return chat_template;
|
||||
|
||||
return file.to_string();
|
||||
}
|
||||
|
||||
static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
|
||||
const common_chat_templates_ptr & chat_templates, int & prev_len,
|
||||
const bool stdout_a_terminal) {
|
||||
add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
|
||||
int new_len;
|
||||
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
|
||||
std::string response;
|
||||
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!opt.user.empty()) {
|
||||
return 2;
|
||||
}
|
||||
|
||||
add_message("assistant", response, llama_data);
|
||||
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Main chat loop function
|
||||
static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
|
||||
static int chat_loop(LlamaData & llama_data, const Opt & opt) {
|
||||
int prev_len = 0;
|
||||
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
|
||||
|
||||
std::string chat_template = "";
|
||||
if(!chat_template_file.empty()){
|
||||
chat_template = read_chat_template_file(chat_template_file);
|
||||
std::string chat_template;
|
||||
if (!opt.chat_template_file.empty()) {
|
||||
chat_template = read_chat_template_file(opt.chat_template_file);
|
||||
}
|
||||
auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
|
||||
|
||||
common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
|
||||
static const bool stdout_a_terminal = is_stdout_a_terminal();
|
||||
while (true) {
|
||||
// Get user input
|
||||
std::string user_input;
|
||||
if (get_user_input(user_input, user) == 1) {
|
||||
if (get_user_input(user_input, opt.user) == 1) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
add_message("user", user.empty() ? user_input : user, llama_data);
|
||||
int new_len;
|
||||
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
|
||||
const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
|
||||
if (ret == 1) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
|
||||
std::string response;
|
||||
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!user.empty()) {
|
||||
} else if (ret == 2) {
|
||||
break;
|
||||
}
|
||||
|
||||
add_message("assistant", response, llama_data);
|
||||
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
@@ -1208,7 +1208,7 @@ int main(int argc, const char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
|
||||
if (chat_loop(llama_data, opt)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
// increase max payload length to allow use of larger context size
|
||||
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||
// disable Nagle's algorithm
|
||||
#define CPPHTTPLIB_TCP_NODELAY true
|
||||
#include "httplib.h"
|
||||
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
|
||||
@@ -122,6 +122,7 @@ endif()
|
||||
option(GGML_LASX "ggml: enable lasx" ON)
|
||||
option(GGML_LSX "ggml: enable lsx" ON)
|
||||
option(GGML_RVV "ggml: enable rvv" ON)
|
||||
option(GGML_VXE "ggml: enable vxe" ON)
|
||||
|
||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
@@ -151,6 +152,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
|
||||
"ggml: max. batch size for using peer access")
|
||||
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
||||
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
||||
option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
|
||||
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
||||
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ extern "C" {
|
||||
// other
|
||||
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_vsx (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_vxe (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
|
||||
|
||||
|
||||
@@ -310,6 +310,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (GGML_RVV)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
message(STATUS "s390x detected")
|
||||
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
|
||||
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
|
||||
|
||||
# TODO: Separation to determine activation of VX/VXE/VXE2
|
||||
if (${S390X_M} MATCHES "8561|8562")
|
||||
message(STATUS "z15 target")
|
||||
list(APPEND ARCH_FLAGS -march=z15 -mtune=z15)
|
||||
elseif (${S390X_M} MATCHES "3931")
|
||||
message(STATUS "z16 target")
|
||||
list(APPEND ARCH_FLAGS -march=z16 -mtune=z16)
|
||||
else()
|
||||
message(STATUS "Unknown target")
|
||||
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
|
||||
list(APPEND ARCH_FLAGS -march=native -mtune=native)
|
||||
endif()
|
||||
|
||||
if (GGML_VXE)
|
||||
list(APPEND ARCH_FLAGS -mvx -mzvector)
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Unknown architecture")
|
||||
endif()
|
||||
|
||||
@@ -59,6 +59,15 @@ struct ggml_compute_params {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__s390x__) && defined(__VEC__)
|
||||
#ifndef __VXE__
|
||||
#define __VXE__
|
||||
#endif
|
||||
#ifndef __VXE2__
|
||||
#define __VXE2__
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
#include <arm_sve.h>
|
||||
#include <sys/prctl.h>
|
||||
@@ -359,6 +368,148 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
#include <vecintrin.h>
|
||||
|
||||
#define vec_neg(a) (-(a)) // Vector Negate
|
||||
#define vec_add(a, b) ((a) + (b)) // Vector Add
|
||||
#define vec_sub(a, b) ((a) - (b)) // Vector Subtract
|
||||
#define vec_mul(a, b) ((a) * (b)) // Vector Multiply
|
||||
#define vec_div(a, b) ((a) / (b)) // Vector Divide
|
||||
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
|
||||
#define vec_sra(a, b) ((a) >> (b)) // Vector Shift Right
|
||||
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
|
||||
#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
|
||||
#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet
|
||||
|
||||
#ifndef vec_and
|
||||
#define vec_and(a, b) ((a) & (b)) // Vector AND
|
||||
#endif
|
||||
|
||||
#ifndef vec_or
|
||||
#define vec_or(a, b) ((a) | (b)) // Vector OR
|
||||
#endif
|
||||
|
||||
#ifndef vec_xor
|
||||
#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
|
||||
#endif
|
||||
|
||||
typedef signed char char8x16_t __attribute__((vector_size(16)));
|
||||
typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef int8_t int8x16_t __attribute__((vector_size(16)));
|
||||
typedef int16_t int16x8_t __attribute__((vector_size(16)));
|
||||
typedef int32_t int32x4_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef uint8_t uint8x16_t __attribute__((vector_size(16)));
|
||||
typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
|
||||
typedef uint32_t uint32x4_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef float float32x4_t __attribute__((vector_size(16)));
|
||||
typedef double double64x2_t __attribute((vector_size(16)));
|
||||
|
||||
typedef signed long long long64x2_t __attribute((vector_size(16)));
|
||||
typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef struct ggml_uint8x16x2_t {
|
||||
uint8x16_t val[2];
|
||||
} ggml_uint8x16x2_t;
|
||||
|
||||
inline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) {
|
||||
ggml_uint8x16x2_t res;
|
||||
|
||||
res.val[0] = vec_xl( 0, ptr);
|
||||
res.val[1] = vec_xl(16, ptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef struct ggml_uint8x16x4_t {
|
||||
uint8x16_t val[4];
|
||||
} ggml_uint8x16x4_t;
|
||||
|
||||
inline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) {
|
||||
ggml_uint8x16x4_t res;
|
||||
|
||||
res.val[0] = vec_xl( 0, ptr);
|
||||
res.val[1] = vec_xl(16, ptr);
|
||||
res.val[2] = vec_xl(32, ptr);
|
||||
res.val[3] = vec_xl(48, ptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef struct ggml_int8x16x4_t {
|
||||
int8x16_t val[4];
|
||||
} ggml_int8x16x4_t;
|
||||
|
||||
inline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) {
|
||||
ggml_int8x16x4_t res;
|
||||
|
||||
res.val[0] = vec_xl( 0, ptr);
|
||||
res.val[1] = vec_xl(16, ptr);
|
||||
res.val[2] = vec_xl(32, ptr);
|
||||
res.val[3] = vec_xl(48, ptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef struct ggml_int16x8x2_t {
|
||||
int16x8_t val[2];
|
||||
} ggml_int16x8x2_t;
|
||||
|
||||
inline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) {
|
||||
ggml_int16x8x2_t res;
|
||||
|
||||
res.val[0] = vec_xl( 0, ptr);
|
||||
res.val[1] = vec_xl(16, ptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/*
|
||||
! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs
|
||||
! or iq4_nl for example implementation.
|
||||
*/
|
||||
inline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) {
|
||||
int8x16_t res;
|
||||
|
||||
res[ 0] = a[b[ 0]];
|
||||
res[ 1] = a[b[ 1]];
|
||||
res[ 2] = a[b[ 2]];
|
||||
res[ 3] = a[b[ 3]];
|
||||
res[ 4] = a[b[ 4]];
|
||||
res[ 5] = a[b[ 5]];
|
||||
res[ 6] = a[b[ 6]];
|
||||
res[ 7] = a[b[ 7]];
|
||||
res[ 8] = a[b[ 8]];
|
||||
res[ 9] = a[b[ 9]];
|
||||
res[10] = a[b[10]];
|
||||
res[11] = a[b[11]];
|
||||
res[12] = a[b[12]];
|
||||
res[13] = a[b[13]];
|
||||
res[14] = a[b[14]];
|
||||
res[15] = a[b[15]];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
|
||||
const uchar8x16_t v_maske = { 0, 1, 4, 5, 8, 9, 12, 13,
|
||||
16, 17, 20, 21, 24, 25, 28, 29 };
|
||||
|
||||
const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b);
|
||||
const int16x8_t v_abe = vec_perm(a, b, v_maske);
|
||||
return v_abo + v_abe;
|
||||
}
|
||||
|
||||
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
||||
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
|
||||
return acc + (vec_unpackh(p) + vec_unpackl(p));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(__loongarch_asx)
|
||||
/* float type data load instructions */
|
||||
static __m128 __lsx_vreplfr2vr_s(const float val) {
|
||||
|
||||
@@ -1011,6 +1011,38 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
|
||||
__lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
|
||||
|
||||
}
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
__vector float srcv [8];
|
||||
__vector float asrcv[8];
|
||||
__vector float amaxv[8];
|
||||
|
||||
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
|
||||
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
|
||||
for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
|
||||
for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
|
||||
for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
|
||||
|
||||
const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
|
||||
vec_extract(amaxv[0], 1)),
|
||||
MAX(vec_extract(amaxv[0], 2),
|
||||
vec_extract(amaxv[0], 3)));
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f / d : 0.0f;
|
||||
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
|
||||
for (int j = 0; j < 8; j++) {
|
||||
const __vector float v = vec_mul(srcv[j], vec_splats(id));
|
||||
const __vector int32_t vi = vec_signed(v);
|
||||
|
||||
y[i].qs[4*j + 0] = vec_extract(vi, 0);
|
||||
y[i].qs[4*j + 1] = vec_extract(vi, 1);
|
||||
y[i].qs[4*j + 2] = vec_extract(vi, 2);
|
||||
y[i].qs[4*j + 3] = vec_extract(vi, 3);
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(nb);
|
||||
// scalar
|
||||
@@ -1337,6 +1369,44 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
|
||||
__lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
|
||||
__lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
|
||||
}
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
__vector float srcv [8];
|
||||
__vector float asrcv[8];
|
||||
__vector float amaxv[8];
|
||||
|
||||
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
|
||||
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
|
||||
for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
|
||||
for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
|
||||
for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
|
||||
|
||||
const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
|
||||
vec_extract(amaxv[0], 1)),
|
||||
MAX(vec_extract(amaxv[0], 2),
|
||||
vec_extract(amaxv[0], 3)));
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f / d : 0.0f;
|
||||
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
|
||||
__vector int32_t acc = vec_splats(0);
|
||||
|
||||
for (int j = 0; j < 8; j++) {
|
||||
const __vector float v = vec_mul(srcv[j], vec_splats(id));
|
||||
const __vector int32_t vi = vec_signed(v);
|
||||
|
||||
y[i].qs[4*j + 0] = vec_extract(vi, 0);
|
||||
y[i].qs[4*j + 1] = vec_extract(vi, 1);
|
||||
y[i].qs[4*j + 2] = vec_extract(vi, 2);
|
||||
y[i].qs[4*j + 3] = vec_extract(vi, 3);
|
||||
|
||||
acc = vec_add(acc, vi);
|
||||
}
|
||||
|
||||
y[i].s = GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3]));
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(nb);
|
||||
// scalar
|
||||
@@ -2488,6 +2558,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||
}
|
||||
|
||||
sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
__vector float acc = vec_splats(0.0f);
|
||||
|
||||
const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F);
|
||||
const __vector int8_t v_s = vec_splats( (const int8_t)0x08);
|
||||
|
||||
for (; ib < nb; ++ib) {
|
||||
const __vector uint8_t v_x = vec_xl(0, x[ib].qs);
|
||||
const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m);
|
||||
const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4);
|
||||
|
||||
const __vector int8_t v_xls = vec_sub(v_xl, v_s);
|
||||
const __vector int8_t v_xhs = vec_sub(v_xh, v_s);
|
||||
|
||||
const __vector int8_t v_yl = vec_xl(0 , y[ib].qs);
|
||||
const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
|
||||
|
||||
const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl);
|
||||
const __vector int16_t v_xylse = vec_mule(v_xls, v_yl);
|
||||
const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh);
|
||||
const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh);
|
||||
|
||||
__vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
|
||||
|
||||
const __vector float v_xy = vec_float(vec_unpackh(v_xy_));
|
||||
const __vector float v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
acc = vec_madd(v_xy, v_d, acc);
|
||||
}
|
||||
|
||||
sumf = acc[0] + acc[1] + acc[2] + acc[3];
|
||||
#endif
|
||||
for (; ib < nb; ++ib) {
|
||||
int sumi0 = 0;
|
||||
@@ -2781,6 +2882,35 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
||||
}
|
||||
|
||||
sumf = hsum_float_8(acc) + summs;
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
float summs = 0;
|
||||
float32x4_t acc = vec_splats(0.0f);
|
||||
|
||||
const uint8x16_t v_m = vec_splat_u8(0x0F);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib < nb; ++ib) {
|
||||
__builtin_prefetch(x[ib].qs, 0, 1);
|
||||
__builtin_prefetch(y[ib].qs, 0, 1);
|
||||
|
||||
summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
|
||||
|
||||
const uint8x16_t v_x = vec_xl(0, x[ib].qs);
|
||||
const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);
|
||||
const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);
|
||||
|
||||
const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs);
|
||||
|
||||
const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
|
||||
const float32x4_t v_xy = vec_float(v_xy_);
|
||||
|
||||
const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
acc = vec_madd(v_xy, v_d, acc);
|
||||
}
|
||||
|
||||
sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs;
|
||||
#endif
|
||||
for (; ib < nb; ++ib) {
|
||||
int sumi0 = 0;
|
||||
@@ -3915,6 +4045,27 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
||||
}
|
||||
|
||||
sumf = hsum_float_8(acc);
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
__vector float acc = vec_splats(0.0f);
|
||||
|
||||
#pragma GCC unroll 8
|
||||
for (; ib < nb; ++ib) {
|
||||
__builtin_prefetch(x[ib].qs, 0, 1);
|
||||
__builtin_prefetch(y[ib].qs, 0, 1);
|
||||
|
||||
const int8x16_t v_xl = vec_xl(0 , x[ib].qs);
|
||||
const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs);
|
||||
const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
|
||||
|
||||
const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
|
||||
const float32x4_t v_xy = vec_float(v_xy_);
|
||||
const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
acc = vec_madd(v_xy, v_d, acc);
|
||||
}
|
||||
|
||||
sumf = acc[0] + acc[1] + acc[2] + acc[3];
|
||||
#endif
|
||||
for (; ib < nb; ++ib) {
|
||||
int sumi = 0;
|
||||
@@ -6797,6 +6948,77 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||
|
||||
|
||||
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
const uint8x16_t v_lm = vec_splat_u8(0x0F);
|
||||
const int32x4_t v_z = vec_splat_s32(0);
|
||||
|
||||
uint8x16_t v_x[2];
|
||||
int8x16_t v_xl[2];
|
||||
int8x16_t v_y[2];
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
||||
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
||||
|
||||
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
|
||||
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
|
||||
const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
|
||||
|
||||
memcpy(utmp, x[i].scales, 12);
|
||||
|
||||
uint32x4_t v_mins8 = { 0 };
|
||||
v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0);
|
||||
v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1);
|
||||
|
||||
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
||||
utmp[0] &= kmask1;
|
||||
|
||||
const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
|
||||
|
||||
const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
|
||||
const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
|
||||
const int32x4_t v_mins = v_minso + v_minse;
|
||||
sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
|
||||
|
||||
const uint8_t * scales = (const uint8_t *)utmp;
|
||||
const uint8_t * restrict x0 = x[i].qs;
|
||||
const int8_t * restrict y0 = y[i].qs;
|
||||
|
||||
int32_t sumi1 = 0;
|
||||
int32_t sumi2 = 0;
|
||||
|
||||
for (int j = 0; j < QK_K/64; ++j) {
|
||||
v_x[0] = vec_xl(0 , x0);
|
||||
v_x[1] = vec_xl(16, x0);
|
||||
x0 += 32;
|
||||
|
||||
v_y[0] = vec_xl(0 , y0);
|
||||
v_y[1] = vec_xl(16, y0);
|
||||
y0 += 32;
|
||||
|
||||
v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm);
|
||||
v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
|
||||
|
||||
const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
|
||||
sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0];
|
||||
|
||||
v_y[0] = vec_xl(0 , y0);
|
||||
v_y[1] = vec_xl(16, y0);
|
||||
y0 += 32;
|
||||
|
||||
v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4);
|
||||
v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
|
||||
|
||||
const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
|
||||
sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1];
|
||||
}
|
||||
|
||||
sumf += d * (sumi1 + sumi2);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#else
|
||||
|
||||
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
||||
@@ -7526,7 +7748,94 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||
acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
|
||||
|
||||
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
const uint8x16_t v_lm = vec_splat_u8(0x0F);
|
||||
const uint8x16_t v_1m = vec_splat_u8(0x01);
|
||||
const uint8x16_t v_2m = vec_splat_u8(0x02);
|
||||
|
||||
const int32x4_t v_z = vec_splat_s32(0);
|
||||
|
||||
const uchar8x16_t v_minsm = {
|
||||
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
|
||||
};
|
||||
|
||||
int8x16_t q5b[4];
|
||||
uint8x16_t q5h[4];
|
||||
|
||||
uint8x16_t v_xl[2];
|
||||
uint8x16_t v_xh[2];
|
||||
int8x16_t v_y[4];
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
||||
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
||||
|
||||
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
|
||||
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
|
||||
const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
|
||||
|
||||
memcpy(utmp, x[i].scales, 12);
|
||||
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux = utmp[1] & kmask1;
|
||||
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
||||
utmp[2] = uaux;
|
||||
utmp[0] &= kmask1;
|
||||
|
||||
const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp);
|
||||
const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm);
|
||||
const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);
|
||||
|
||||
const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
|
||||
const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
|
||||
const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
|
||||
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
|
||||
|
||||
const uint8_t * scales = (const uint8_t *)utmp;
|
||||
const uint8_t * restrict x0l = x[i].qs;
|
||||
const uint8_t * restrict x0h = x[i].qh;
|
||||
const int8_t * restrict y0 = y[i].qs;
|
||||
|
||||
v_xh[0] = vec_xl(0 , x0h);
|
||||
v_xh[1] = vec_xl(16, x0h);
|
||||
|
||||
int32_t sumi = 0;
|
||||
for (int j = 0; j < QK_K/64; ++j) {
|
||||
v_xl[0] = vec_xl(0 , x0l);
|
||||
v_xl[1] = vec_xl(16, x0l);
|
||||
x0l += 32;
|
||||
|
||||
v_y[0] = vec_xl(0 , y0);
|
||||
v_y[1] = vec_xl(16, y0);
|
||||
v_y[2] = vec_xl(32, y0);
|
||||
v_y[3] = vec_xl(48, y0);
|
||||
y0 += 64;
|
||||
|
||||
q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4);
|
||||
q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4);
|
||||
q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3);
|
||||
q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3);
|
||||
v_xh[0] = vec_sr(v_xh[0], 2);
|
||||
v_xh[1] = vec_sr(v_xh[1], 2);
|
||||
|
||||
q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]);
|
||||
q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]);
|
||||
q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]);
|
||||
q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]);
|
||||
|
||||
int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
|
||||
int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);
|
||||
|
||||
sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++;
|
||||
sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++;
|
||||
}
|
||||
|
||||
sumf += d * sumi - dmin * mins;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#else
|
||||
|
||||
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
||||
@@ -8243,7 +8552,130 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
float sum = 0;
|
||||
|
||||
// Lower 4-bit and upper 2-bit masks
|
||||
const uint8x16_t v_lm = vec_splat_u8(0x0F);
|
||||
const uint8x16_t v_um = vec_splat_u8(0x03);
|
||||
|
||||
const int32x4_t v_z = vec_splat_s32(0);
|
||||
|
||||
int8x16_t q6b[4];
|
||||
uint8x16_t q6h[4];
|
||||
|
||||
uint8x16_t v_xl[4];
|
||||
uint8x16_t v_xh[2];
|
||||
int8x16_t v_y[4];
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d_all = GGML_FP16_TO_FP32(x[i].d);
|
||||
|
||||
const uint8_t * restrict x0l = x[i].ql;
|
||||
const uint8_t * restrict x0h = x[i].qh;
|
||||
const int8_t * restrict y0 = y[i].qs;
|
||||
|
||||
const int8_t * restrict scale = x[i].scales;
|
||||
|
||||
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
|
||||
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
|
||||
|
||||
const int8x16_t v_scale = vec_xl(0, scale);
|
||||
const int16x8_t v_scalel = vec_unpackh(v_scale);
|
||||
const int16x8_t v_scaleh = vec_unpackl(v_scale);
|
||||
|
||||
const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
|
||||
const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
|
||||
const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
|
||||
const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
|
||||
const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
|
||||
|
||||
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
|
||||
|
||||
int32_t isum = 0;
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
// Load model upper 2 bits
|
||||
v_xh[0] = vec_xl(0 , x0h);
|
||||
v_xh[1] = vec_xl(16, x0h);
|
||||
x0h += 32;
|
||||
|
||||
// Load model lower 4 bits
|
||||
v_xl[0] = vec_xl(0 , x0l);
|
||||
v_xl[1] = vec_xl(16, x0l);
|
||||
v_xl[2] = vec_xl(32, x0l);
|
||||
v_xl[3] = vec_xl(48, x0l);
|
||||
x0l += 64;
|
||||
|
||||
// Load activation quants
|
||||
v_y[0] = vec_xl(0 , y0);
|
||||
v_y[1] = vec_xl(16, y0);
|
||||
v_y[2] = vec_xl(32, y0);
|
||||
v_y[3] = vec_xl(48, y0);
|
||||
y0 += 64;
|
||||
|
||||
q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4);
|
||||
q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4);
|
||||
uint8x16_t shifted = vec_sr(v_xh[0], 2);
|
||||
q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
|
||||
shifted = vec_sr(v_xh[1], 2);
|
||||
q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
|
||||
|
||||
q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0]));
|
||||
q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1]));
|
||||
q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2]));
|
||||
q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3]));
|
||||
|
||||
int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
|
||||
int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
|
||||
int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
|
||||
int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
|
||||
|
||||
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
|
||||
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
|
||||
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
|
||||
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
|
||||
|
||||
scale += 4;
|
||||
|
||||
|
||||
// Load activation quants
|
||||
v_y[0] = vec_xl(0 , y0);
|
||||
v_y[1] = vec_xl(16, y0);
|
||||
v_y[2] = vec_xl(32, y0);
|
||||
v_y[3] = vec_xl(48, y0);
|
||||
y0 += 64;
|
||||
|
||||
shifted = vec_sr(v_xh[0], 4);
|
||||
q6h[0] = vec_sl(vec_and(v_um, shifted), 4);
|
||||
shifted = vec_sr(v_xh[1], 4);
|
||||
q6h[1] = vec_sl(vec_and(v_um, shifted), 4);
|
||||
shifted = vec_sr(v_xh[0], 6);
|
||||
q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
|
||||
shifted = vec_sr(v_xh[1], 6);
|
||||
q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
|
||||
|
||||
q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0]));
|
||||
q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1]));
|
||||
q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2]));
|
||||
q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3]));
|
||||
|
||||
summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
|
||||
summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
|
||||
summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
|
||||
summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
|
||||
|
||||
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
|
||||
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
|
||||
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
|
||||
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
|
||||
|
||||
scale += 4;
|
||||
}
|
||||
|
||||
sum += d_all * y[i].d * (isum - 32 * mins);
|
||||
}
|
||||
|
||||
*s = sum;
|
||||
#else
|
||||
|
||||
int8_t aux8[QK_K];
|
||||
@@ -8604,7 +9036,57 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
|
||||
}
|
||||
|
||||
*s = 0.125f * hsum_float_8(accumf);
|
||||
|
||||
//#elif defined(__VXE__) || defined(__VXE2__)
|
||||
// const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
||||
//
|
||||
// uint32_t aux32[4];
|
||||
// const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
//
|
||||
// float sumf = 0;
|
||||
//
|
||||
// for (int i = 0; i < nb; ++i) {
|
||||
// const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
// const uint16_t * restrict q2 = x[i].qs;
|
||||
// const int8_t * restrict q8 = y[i].qs;
|
||||
//
|
||||
// float sumf1 = 0, sumf2 = 0;
|
||||
//
|
||||
// for (int ib32 = 0; ib32 < QK_K/32; ib += 2) {
|
||||
// int8x16_t q8b0 = vec_xl( 0, q8);
|
||||
// int8x16_t qb81 = vec_xl(16, q8);
|
||||
// int8x16_t q8b2 = vec_xl(32, q8);
|
||||
// int8x16_t q8b3 = vec_xl(48, q8);
|
||||
// q8 += 64;
|
||||
//
|
||||
// memcpy(aux32, q2, 4 * sizeof(uint32_t));
|
||||
// q2 += 8;
|
||||
//
|
||||
// int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) };
|
||||
// int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) };
|
||||
// int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) };
|
||||
// int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) };
|
||||
//
|
||||
// int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127)) };
|
||||
// int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) };
|
||||
// int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127)) };
|
||||
// int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) };
|
||||
//
|
||||
// q2u0 = vec_mul(q2u0, q2s0);
|
||||
// q2u1 = vec_mul(q2u1, q2s1);
|
||||
// q2u2 = vec_mul(q2u2, q2s2);
|
||||
// q2u3 = vec_mul(q2u3, q2s3);
|
||||
//
|
||||
// const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1);
|
||||
// const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3);
|
||||
//
|
||||
// sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28));
|
||||
// sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28));
|
||||
// }
|
||||
//
|
||||
// sumf += d * (sumf1 + sumf2);
|
||||
// }
|
||||
//
|
||||
// *s = 0.25f * sumf;
|
||||
#else
|
||||
|
||||
uint32_t aux32[2];
|
||||
@@ -11365,6 +11847,27 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||
|
||||
sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
|
||||
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
|
||||
const uint8x16_t v_m = vec_splat_u8(0x0F);
|
||||
|
||||
for (; ib < nb; ++ib) {
|
||||
const block_iq4_nl * restrict x0 = &x[ib];
|
||||
const block_q8_0 * restrict y0 = &y[ib];
|
||||
|
||||
const uint8x16_t v_x = vec_xl(0, x0->qs);
|
||||
int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
|
||||
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
|
||||
|
||||
v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
|
||||
v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);
|
||||
|
||||
const int8x16_t v_yl = vec_xl(0 , y0->qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
|
||||
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
|
||||
|
||||
sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
|
||||
}
|
||||
#endif
|
||||
for (; ib < nb; ++ib) {
|
||||
const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
|
||||
@@ -11643,6 +12146,56 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
||||
}
|
||||
|
||||
*s = hsum_float_8(accum);
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
|
||||
const uint8x16_t v_m = vec_splat_u8(0x0F);
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
for (int ibl = 0; ibl < nb; ++ibl) {
|
||||
const uint8_t * restrict q4 = x[ibl].qs;
|
||||
const int8_t * restrict q8 = y[ibl].qs;
|
||||
|
||||
uint16_t h = x[ibl].scales_h;
|
||||
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
for (int ib = 0; ib < QK_K/64; ++ib) {
|
||||
const uint8x16_t v_x0 = vec_xl(0 , q4);
|
||||
const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4);
|
||||
q4 += 32;
|
||||
|
||||
int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
|
||||
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
|
||||
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
|
||||
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
|
||||
|
||||
v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
|
||||
v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
|
||||
v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
|
||||
v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
|
||||
|
||||
const int8x16_t v_y0 = vec_xl( 0, q8);
|
||||
const int8x16_t v_y1 = vec_xl(16, q8);
|
||||
const int8x16_t v_y2 = vec_xl(32, q8);
|
||||
const int8x16_t v_y3 = vec_xl(48, q8);
|
||||
q8 += 64;
|
||||
|
||||
int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1);
|
||||
int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3);
|
||||
|
||||
int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32;
|
||||
int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
|
||||
|
||||
h >>= 4;
|
||||
|
||||
sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1;
|
||||
sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2;
|
||||
}
|
||||
|
||||
sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
|
||||
#else
|
||||
float sumf = 0;
|
||||
|
||||
@@ -237,6 +237,8 @@ typedef pthread_t ggml_thread_t;
|
||||
#else
|
||||
#if defined(__POWER9_VECTOR__)
|
||||
#define CACHE_LINE_SIZE 128
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
#define CACHE_LINE_SIZE 256
|
||||
#else
|
||||
#define CACHE_LINE_SIZE 64
|
||||
#endif
|
||||
@@ -1211,6 +1213,87 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
||||
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
|
||||
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
|
||||
#define GGML_SIMD
|
||||
|
||||
// F32 s390x
|
||||
|
||||
#define GGML_F32_STEP 32
|
||||
#define GGML_F32_EPR 4
|
||||
|
||||
#define GGML_F32x4 __vector float
|
||||
#define GGML_F32x4_ZERO vec_splats(0.0f)
|
||||
#define GGML_F32x4_SET1 vec_splats
|
||||
#define GGML_F32x4_LOAD(p) vec_xl(0, p)
|
||||
#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
|
||||
#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
|
||||
#define GGML_F32x4_ADD vec_add
|
||||
#define GGML_F32x4_MUL vec_mul
|
||||
#define GGML_F32x4_REDUCE(res, x) \
|
||||
{ \
|
||||
int offset = GGML_F32_ARR >> 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = vec_add(x[i], x[offset + i]); \
|
||||
} \
|
||||
offset >>= 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = vec_add(x[i], x[offset + i]); \
|
||||
} \
|
||||
offset >>= 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = vec_add(x[i], x[offset + i]); \
|
||||
} \
|
||||
res = vec_extract(x[0], 0) + \
|
||||
vec_extract(x[0], 1) + \
|
||||
vec_extract(x[0], 2) + \
|
||||
vec_extract(x[0], 3); \
|
||||
}
|
||||
|
||||
#define GGML_F32_VEC GGML_F32x4
|
||||
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
|
||||
#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
|
||||
#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
|
||||
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
|
||||
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
|
||||
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
|
||||
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
// F16 s390x
|
||||
#define GGML_F16_STEP GGML_F32_STEP
|
||||
#define GGML_F16_EPR GGML_F32_EPR
|
||||
|
||||
static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) {
|
||||
float tmp[4];
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tmp[i] = GGML_FP16_TO_FP32(x[i]);
|
||||
}
|
||||
|
||||
return vec_xl(0, tmp);
|
||||
}
|
||||
|
||||
static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) {
|
||||
float arr[4];
|
||||
|
||||
vec_xst(y, 0, arr);
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
x[i] = GGML_FP32_TO_FP16(arr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#define GGML_F16_VEC GGML_F32x4
|
||||
#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
|
||||
#define GGML_F16_VEC_SET1 GGML_F32x4_SET1
|
||||
#define GGML_F16_VEC_LOAD(p, i) __lzs_f16cx4_load(p)
|
||||
#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
|
||||
#define GGML_F16_VEC_FMA GGML_F32x4_FMA
|
||||
#define GGML_F16_VEC_ADD GGML_F32x4_ADD
|
||||
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
#endif
|
||||
|
||||
// GGML_F32_ARR / GGML_F16_ARR
|
||||
@@ -14419,6 +14502,14 @@ int ggml_cpu_has_vsx(void) {
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_vxe(void) {
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_neon(void) {
|
||||
#if defined(__ARM_ARCH) && defined(__ARM_NEON)
|
||||
return ggml_arm_arch_features.has_neon;
|
||||
|
||||
@@ -557,6 +557,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
||||
if (ggml_cpu_has_vsx()) {
|
||||
features.push_back({ "VSX", "1" });
|
||||
}
|
||||
if (ggml_cpu_has_vxe()) {
|
||||
features.push_back({ "VXE", "1" });
|
||||
}
|
||||
if (ggml_cpu_has_wasm_simd()) {
|
||||
features.push_back({ "WASM_SIMD", "1" });
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ if (CUDAToolkit_FOUND)
|
||||
|
||||
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
||||
# native == GPUs available at build time
|
||||
# 52 == Maxwell, lowest CUDA 12 standard
|
||||
# 50 == Maxwell, lowest CUDA 12 standard
|
||||
# 60 == P100, FP16 CUDA intrinsics
|
||||
# 61 == Pascal, __dp4a instruction (per-byte integer dot product)
|
||||
# 70 == V100, FP16 tensor cores
|
||||
@@ -17,7 +17,7 @@ if (CUDAToolkit_FOUND)
|
||||
elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
|
||||
else()
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;80")
|
||||
set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80")
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
|
||||
add_compile_definitions(GGML_CUDA_NO_VMM)
|
||||
endif()
|
||||
|
||||
if (NOT GGML_CUDA_FA)
|
||||
add_compile_definitions(GGML_CUDA_NO_FA)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
|
||||
add_compile_definitions(GGML_CUDA_F16)
|
||||
endif()
|
||||
|
||||
@@ -204,9 +204,9 @@ typedef float2 dfloat2;
|
||||
#define CP_ASYNC_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
|
||||
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#define FLASH_ATTN_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
|
||||
static bool fp16_available(const int cc) {
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
|
||||
@@ -411,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||
|
||||
#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
|
||||
return __dp4a(a, b, c);
|
||||
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
|
||||
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
|
||||
const int8_t * a8 = (const int8_t *) &a;
|
||||
const int8_t * b8 = (const int8_t *) &b;
|
||||
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
|
||||
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co
|
||||
} else
|
||||
#endif // CUDART_VERSION >= 11040
|
||||
{
|
||||
asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
|
||||
asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
|
||||
: : "r"(dst), "l"(src));
|
||||
}
|
||||
#else
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "cpy.cuh"
|
||||
#include "dequantize.cuh"
|
||||
|
||||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||
|
||||
@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
||||
const block_q8_0 * xi = (const block_q8_0 *) cxi;
|
||||
float * dsti = (float *) cdsti;
|
||||
float * cdstf = (float *)(cdsti);
|
||||
|
||||
const float d = (float)xi->d;
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
dsti[j] = xi->qs[j] * d;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < QK8_0; j += 2) {
|
||||
dfloat2 dq;
|
||||
dequantize_q8_0(cxi, 0, j, dq);
|
||||
*(cdstf + j) = dq.x;
|
||||
*(cdstf + j + 1) = dq.y;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
|
||||
memcpy(dsti->qh, &qh, sizeof(qh));
|
||||
}
|
||||
|
||||
template<dequantize_kernel_t dequant, int qk>
|
||||
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
|
||||
float * cdstf = (float *)(cdsti);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < qk/2; j++) {
|
||||
dfloat2 dq;
|
||||
dequant(cxi, 0, j, dq);
|
||||
*(cdstf + j) = dq.x;
|
||||
*(cdstf + j + qk/2) = dq.y;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
|
||||
if (x <= val[0]) return 0;
|
||||
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_0_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_1_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_0_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_1_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02,
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
|
||||
@@ -123,13 +123,13 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
||||
|
||||
if (nbytes_shared <= smpbo) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
if (!shared_memory_limit_raised[id]) {
|
||||
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
||||
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
||||
shared_memory_limit_raised[id] = true;
|
||||
}
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
} else {
|
||||
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
@@ -175,13 +175,13 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
if (nbytes_shared <= smpbo) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
if (!shared_memory_limit_raised[id]) {
|
||||
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
||||
shared_memory_limit_raised[id] = true;
|
||||
}
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
||||
} else {
|
||||
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
||||
|
||||
@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
||||
nullptr;
|
||||
}
|
||||
|
||||
// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
|
||||
template<int D, int ncols, int KQ_stride> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / KQ_stride;
|
||||
const int iter_j = (ne01 + (ncols - 1)) / ncols;
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
const int j = blockIdx.y;
|
||||
const int c = blockIdx.z;
|
||||
const int jc = j*ncols2 + c;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
@@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
const int channel = kbc0 / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
||||
|
||||
dst += jt*ncols*ne02*D + channel*D;
|
||||
if (jt*ncols1 + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val[ncols] = {0.0f};
|
||||
float max_val[ncols] = {0.0f};
|
||||
float rowsum[ncols] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
break;
|
||||
}
|
||||
dst_val[j] = dst[j*ne02*D + threadIdx.x];
|
||||
float dst_val = 0.0f;
|
||||
float max_val = 0.0f;
|
||||
float rowsum = 0.0f;
|
||||
{
|
||||
dst_val = *dst;
|
||||
|
||||
const float2 tmp = dst_fixup[bidx0*ncols + j];
|
||||
max_val[j] = tmp.x;
|
||||
rowsum[j] = tmp.y;
|
||||
const float2 tmp = dst_fixup[bidx0*ncols + jc];
|
||||
max_val = tmp.x;
|
||||
rowsum = tmp.y;
|
||||
}
|
||||
|
||||
// Iterate over previous blocks and compute the combined results.
|
||||
@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
|
||||
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
continue;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
break;
|
||||
}
|
||||
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
|
||||
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
|
||||
|
||||
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
|
||||
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
|
||||
|
||||
// Scale the current and new value accumulators depending on the max. values.
|
||||
const float max_val_new = fmaxf(max_val[j], tmp.x);
|
||||
// Scale the current and new value accumulators depending on the max. values.
|
||||
const float max_val_new = fmaxf(max_val, tmp.x);
|
||||
|
||||
const float diff_val = max_val[j] - max_val_new;
|
||||
const float diff_add = tmp.x - max_val_new;
|
||||
const float diff_val = max_val - max_val_new;
|
||||
const float diff_add = tmp.x - max_val_new;
|
||||
|
||||
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
||||
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
||||
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
||||
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
||||
|
||||
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
|
||||
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
|
||||
dst_val = scale_val*dst_val + scale_add*dst_add;
|
||||
rowsum = scale_val*rowsum + scale_add*tmp.y;
|
||||
|
||||
max_val[j] = max_val_new;
|
||||
}
|
||||
max_val = max_val_new;
|
||||
|
||||
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
||||
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
|
||||
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
}
|
||||
|
||||
// Write back final result:
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
|
||||
}
|
||||
*dst = dst_val / rowsum;
|
||||
}
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
|
||||
}
|
||||
|
||||
// parallel_blocks == 0 is stream-k decomposition
|
||||
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
|
||||
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
||||
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
@@ -763,25 +747,26 @@ void launch_fattn(
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
|
||||
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||
dim3 blocks_num;
|
||||
if (parallel_blocks == 0) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
|
||||
const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
|
||||
const int max_blocks = 2*nsm;
|
||||
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
|
||||
|
||||
const int nblocks_stream_k = 2*nsm;
|
||||
const int nblocks_stream_k = max_blocks;
|
||||
|
||||
const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
|
||||
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
|
||||
|
||||
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
|
||||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
|
||||
} else {
|
||||
blocks_num.x = parallel_blocks*ntiles_x;
|
||||
blocks_num.y = Q->ne[2];
|
||||
@@ -793,7 +778,6 @@ void launch_fattn(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
@@ -832,9 +816,9 @@ void launch_fattn(
|
||||
if constexpr (parallel_blocks == 0) {
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine = blocks_num;
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
|
||||
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
|
||||
flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
@@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
@@ -302,14 +297,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
@@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
@@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
@@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
@@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
@@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
@@ -310,7 +305,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
||||
@@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
@@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
@@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
||||
@@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
@@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
return;
|
||||
}
|
||||
if (2*blocks_num_pb1 < 2*nsm) {
|
||||
@@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
return;
|
||||
}
|
||||
constexpr int parallel_blocks = 1;
|
||||
@@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
+56
-17
@@ -8,28 +8,50 @@
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
template <int cols_per_block>
|
||||
template <int D, int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if (Q->ne[1] <= 8/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
|
||||
}
|
||||
|
||||
template <int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const float use_gqa_opt = mask && max_bias == 0.0f;
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
|
||||
if (use_gqa_opt && gqa_ratio == 4) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
|
||||
if (use_gqa_opt && gqa_ratio == 2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
|
||||
}
|
||||
|
||||
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
@@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
@@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
|
||||
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
return;
|
||||
|
||||
@@ -261,6 +261,12 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
|
||||
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
|
||||
device_vmm ? "yes" : "no", prop.warpSize);
|
||||
#elif defined(GGML_USE_MUSA)
|
||||
// TODO: refine the .cc to reflect MUSA's actual CC capabilities
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
#else
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
@@ -1782,9 +1788,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
}
|
||||
}
|
||||
#else
|
||||
#ifdef GGML_USE_MUSA
|
||||
GGML_ASSERT(false);
|
||||
#else // !GGML_USE_MUSA
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
// use cublasGemmStridedBatchedEx
|
||||
@@ -1827,7 +1830,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
cu_compute_type,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
#endif // GGML_USE_MUSA
|
||||
#endif
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
@@ -3073,15 +3075,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
|
||||
return true;
|
||||
}
|
||||
@@ -3189,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
return false;
|
||||
#endif
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -73,6 +73,8 @@ namespace ggml_cuda_mma {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l / 2) * 8 + threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
@@ -85,6 +87,8 @@ namespace ggml_cuda_mma {
|
||||
return 4 * l + threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return 2 * (threadIdx.x % 4) + l % 2;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
@@ -289,6 +293,42 @@ namespace ggml_cuda_mma {
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
||||
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
@@ -316,4 +356,39 @@ namespace ggml_cuda_mma {
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16);
|
||||
@@ -1,10 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32);
|
||||
@@ -1,10 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 64);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 64);
|
||||
@@ -1,10 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 64, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 8);
|
||||
@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
|
||||
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
|
||||
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
|
||||
|
||||
TYPES_MMQ = [
|
||||
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
||||
@@ -57,12 +57,18 @@ for vkq_size in [16, 32]:
|
||||
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
|
||||
|
||||
for cols_per_block in [8, 16, 32, 64]:
|
||||
with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_MMA_START)
|
||||
for ncols in [8, 16, 32, 64, 128]:
|
||||
for ncols2 in [1, 2, 4, 8]:
|
||||
ncols1 = ncols // ncols2
|
||||
if ncols == 128:
|
||||
continue # Too much register pressure.
|
||||
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_MMA_START)
|
||||
|
||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
|
||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
||||
if ncols == 128 and head_size == 256:
|
||||
continue # Needs too much shared memory.
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
|
||||
|
||||
for type in TYPES_MMQ:
|
||||
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
||||
|
||||
@@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM)
|
||||
add_compile_definitions(GGML_HIP_NO_VMM)
|
||||
endif()
|
||||
|
||||
if (NOT GGML_CUDA_FA)
|
||||
add_compile_definitions(GGML_CUDA_NO_FA)
|
||||
endif()
|
||||
|
||||
if (CXX_IS_HIPCC)
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
|
||||
target_link_libraries(ggml-hip PRIVATE hip::device)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#include <arm_sve.h>
|
||||
#endif // __ARM_FEATURE_SVE
|
||||
|
||||
#if defined(__ARM_NEON) && !defined(__CUDACC__)
|
||||
#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
|
||||
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
|
||||
//
|
||||
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
|
||||
|
||||
@@ -49,7 +49,7 @@ if (MUSAToolkit_FOUND)
|
||||
|
||||
set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
|
||||
foreach(SOURCE ${GGML_SOURCES_MUSA})
|
||||
set(COMPILE_FLAGS "-x musa -mtgpu")
|
||||
set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu")
|
||||
foreach(ARCH ${MUSA_ARCHITECTURES})
|
||||
set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
|
||||
endforeach()
|
||||
@@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND)
|
||||
add_compile_definitions(GGML_CUDA_NO_VMM)
|
||||
endif()
|
||||
|
||||
if (NOT GGML_CUDA_FA)
|
||||
add_compile_definitions(GGML_CUDA_NO_FA)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
|
||||
add_compile_definitions(GGML_CUDA_F16)
|
||||
endif()
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
void* ggml_sycl_host_malloc(size_t size);
|
||||
void ggml_sycl_host_free(void* ptr);
|
||||
|
||||
static int g_ggml_sycl_debug = 0;
|
||||
extern int g_ggml_sycl_debug;
|
||||
#define GGML_SYCL_DEBUG(...) \
|
||||
do { \
|
||||
if (g_ggml_sycl_debug) \
|
||||
|
||||
@@ -41,6 +41,7 @@
|
||||
#include "ggml-sycl/gemm.hpp"
|
||||
|
||||
static bool g_sycl_loaded = false;
|
||||
int g_ggml_sycl_debug = 0;
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
ggml_sycl_device_info info = {};
|
||||
@@ -157,8 +158,8 @@ static void ggml_check_sycl() try {
|
||||
static bool initialized = false;
|
||||
|
||||
if (!initialized) {
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n");
|
||||
|
||||
@@ -249,13 +249,16 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
|
||||
const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
|
||||
GGML_SYCL_DEBUG("%s: F16 mask\n", __func__);
|
||||
soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
|
||||
main_stream, ctx.device);
|
||||
} else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
|
||||
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
|
||||
GGML_SYCL_DEBUG("%s: F32 mask\n", __func__);
|
||||
soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
} else {
|
||||
/* mask unavailable */
|
||||
GGML_SYCL_DEBUG("%s: No mask\n", __func__);
|
||||
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,7 +240,11 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi
|
||||
|
||||
|
||||
void * ggml_aligned_malloc(size_t size) {
|
||||
#if defined(__s390x__)
|
||||
const int alignment = 256;
|
||||
#else
|
||||
const int alignment = 64;
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
return _aligned_malloc(size, alignment);
|
||||
|
||||
@@ -1424,6 +1424,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
|
||||
}
|
||||
|
||||
// skip unused tensors
|
||||
if (info.op == GGML_OP_NONE) {
|
||||
LLAMA_LOG_WARN("model has unused tensor %s -- ignoring\n", tn.str().c_str());
|
||||
ml.n_created++;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// tensors with "bias" suffix are always used with GGML_OP_ADD
|
||||
ggml_op op;
|
||||
bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
|
||||
|
||||
+20
-15
@@ -3119,6 +3119,7 @@ struct test_leaky_relu : public test_case {
|
||||
struct test_flash_attn_ext : public test_case {
|
||||
const int64_t hs; // head size
|
||||
const int64_t nh; // num heads
|
||||
const int64_t nr; // repeat in Q, tests for grouped-query attention
|
||||
const int64_t kv; // kv size
|
||||
const int64_t nb; // batch size
|
||||
|
||||
@@ -3131,7 +3132,7 @@ struct test_flash_attn_ext : public test_case {
|
||||
std::array<int32_t, 4> permute;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
|
||||
return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
@@ -3142,13 +3143,13 @@ struct test_flash_attn_ext : public test_case {
|
||||
GGML_UNUSED(t);
|
||||
// Just counting matmul costs:
|
||||
// Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
|
||||
return 2 * 2 * nh * nb * hs * kv;
|
||||
return 2 * 2 * nh*nr * nb * hs * kv;
|
||||
}
|
||||
|
||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
|
||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
|
||||
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
|
||||
std::array<int32_t, 4> permute = {0, 1, 2, 3})
|
||||
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
|
||||
: hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
|
||||
@@ -3166,13 +3167,13 @@ struct test_flash_attn_ext : public test_case {
|
||||
return t;
|
||||
};
|
||||
|
||||
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
|
||||
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
|
||||
ggml_set_name(q, "q");
|
||||
|
||||
ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
|
||||
ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
|
||||
ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
|
||||
ggml_set_name(v, "v");
|
||||
|
||||
ggml_tensor * m = nullptr;
|
||||
@@ -4278,14 +4279,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
if (!mask && max_bias > 0.0f) continue;
|
||||
for (float logit_softcap : {0.0f, 10.0f}) {
|
||||
if (hs != 128 && logit_softcap != 0.0f) continue;
|
||||
for (int nh : { 32, }) {
|
||||
for (int kv : { 512, 1024, }) {
|
||||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
|
||||
// run fewer test cases permuted
|
||||
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
|
||||
for (int nh : { 4, }) {
|
||||
for (int nr : { 1, 4, 16 }) {
|
||||
if (nr == 16 && hs != 128) continue;
|
||||
for (int kv : { 512, 1024, }) {
|
||||
if (nr != 1 && kv != 512) continue;
|
||||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
|
||||
// run fewer test cases permuted
|
||||
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user