mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 10:07:44 +02:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 24d2ee0527 | |||
| 541bf37622 | |||
| d969e933e1 | |||
| 7f5ee54968 | |||
| 66199c9f03 | |||
| c99909dd0b | |||
| cb8f4fa3f8 | |||
| 54910bd4f3 | |||
| ecd99d6a9a | |||
| 137435ff15 | |||
| 24350fdf9b | |||
| 49a7564ac1 | |||
| 4d828bd1ab | |||
| 36a7a6589c | |||
| feefb92836 | |||
| ec88c3ceea | |||
| 2afcdb9777 | |||
| 319146247e | |||
| 66d65ec29b | |||
| 05728db18e | |||
| 4720819d45 | |||
| d979f2b176 | |||
| ecbcb7ea9d | |||
| 3e6ab244ad |
+1
-1
@@ -108,7 +108,7 @@ Building through oneAPI compilers will make avx_vnni instruction set available f
|
||||
- Using oneAPI docker image:
|
||||
If you do not want to source the environment vars and install oneAPI manually, you can also build the code using intel docker container: [oneAPI-basekit](https://hub.docker.com/r/intel/oneapi-basekit). Then, you can use the commands given above.
|
||||
|
||||
Check [Optimizing and Running LLaMA2 on Intel® CPU](https://www.intel.com/content/www/us/en/content-details/791610/optimizing-and-running-llama2-on-intel-cpu.html) for more information.
|
||||
Check [Optimizing and Running LLaMA2 on Intel® CPU](https://builders.intel.com/solutionslibrary/optimizing-and-running-llama2-on-intel-cpu) for more information.
|
||||
|
||||
### Other BLAS libraries
|
||||
|
||||
|
||||
+1
-1
@@ -24,7 +24,7 @@ Legend:
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
|
||||
+32
-32
@@ -9535,38 +9535,38 @@
|
||||
"WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=0,v=0,inplace=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","WebGPU"
|
||||
|
||||
|
Can't render this file because it is too large.
|
@@ -5,6 +5,7 @@
|
||||
#include "sampling.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@@ -16,6 +17,8 @@ static void print_usage(int, char ** argv) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
params.prompt = "Hello my name is";
|
||||
|
||||
@@ -5,14 +5,16 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <clocale>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
#include <climits>
|
||||
#include <cstring>
|
||||
#include <cstdarg>
|
||||
#include <cinttypes>
|
||||
#include <ctime>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <sstream>
|
||||
@@ -874,6 +876,8 @@ static std::string basename(const std::string &path) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_init();
|
||||
|
||||
struct train_params params = get_default_train_params();
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
// Warns users that this filename was deprecated, and provides a link for more information.
|
||||
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
// Main
|
||||
int main(int argc, char** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
std::string filename = "main";
|
||||
if (argc >= 1) {
|
||||
filename = argv[0];
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <limits.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
@@ -538,6 +539,8 @@ static std::string format_input_text(const std::string & prompt, const std::stri
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <ctime>
|
||||
#include <algorithm>
|
||||
|
||||
@@ -94,6 +95,8 @@ static void print_raw_embeddings(const float * emb,
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -29,6 +31,8 @@ static bool run(llama_context * ctx, const common_params & params) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
base_callback_data cb_data;
|
||||
|
||||
common_params params;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -100,6 +101,8 @@ static void write_help(std::ostringstream & ss, const md_file & md) {
|
||||
}
|
||||
|
||||
int main(int, char **) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
for (const auto & md : md_files) {
|
||||
std::ifstream infile(md.fname);
|
||||
if (!infile.is_open()) {
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <cstdlib> /* abort() */
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <algorithm>
|
||||
#include <cstdlib> /* abort() */
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
@@ -626,6 +627,8 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
|
||||
}
|
||||
|
||||
int main(int argc, const char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
hash_params params;
|
||||
manifest_check_params manifest_check;
|
||||
hash_params_parse(argc, argv, params);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
@@ -240,6 +241,8 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
if (argc < 3) {
|
||||
printf("usage: %s data.gguf r|w [n]\n", argv[0]);
|
||||
printf("r: read data.gguf file\n");
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
struct ngram_data {
|
||||
bool active = false;
|
||||
@@ -38,6 +39,8 @@ struct ngram_container {
|
||||
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
|
||||
@@ -3,10 +3,13 @@
|
||||
#include "ngram-cache.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
int main(int argc, char ** argv){
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "common.h"
|
||||
#include "ngram-cache.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
@@ -17,6 +18,8 @@ static void print_usage(char* argv0) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv){
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
if (argc < 3) {
|
||||
print_usage(argv[0]);
|
||||
exit(1);
|
||||
|
||||
@@ -5,14 +5,17 @@
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <clocale>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cinttypes>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
int main(int argc, char ** argv){
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
@@ -13,6 +14,8 @@
|
||||
#include <vector>
|
||||
|
||||
int main(int argc, char ** argv){
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
|
||||
|
||||
@@ -7,12 +7,13 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <ctime>
|
||||
#include <algorithm>
|
||||
|
||||
// trim whitespace from the beginning and end of a string
|
||||
static std::string trim(const std::string & str) {
|
||||
@@ -153,6 +154,8 @@ static std::vector<std::string> split_string(const std::string& input, char deli
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
srand(1234);
|
||||
|
||||
common_params params;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
@@ -16,6 +17,8 @@ static void print_usage(int, char ** argv) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
params.n_junk = 250;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <fstream>
|
||||
#include <iostream> // TODO: remove me
|
||||
|
||||
@@ -112,6 +113,8 @@ static void batch_process(llama_context * ctx, llama_batch & batch, float * outp
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_RETRIEVAL, print_usage)) {
|
||||
|
||||
@@ -2,11 +2,14 @@
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
params.prompt = "The quick brown fox";
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "llama.h"
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
@@ -12,6 +13,8 @@ static void print_usage(int, char ** argv) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
std::string model_path;
|
||||
int ngl = 99;
|
||||
int n_ctx = 2048;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "llama.h"
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
@@ -11,6 +12,8 @@ static void print_usage(int, char ** argv) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
// path to the model gguf file
|
||||
std::string model_path;
|
||||
// prompt to generate text from
|
||||
|
||||
@@ -5,12 +5,15 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <random>
|
||||
@@ -30,6 +31,8 @@ struct seq_draft {
|
||||
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
// needed to get candidate probs even for temp <= 0.0
|
||||
|
||||
@@ -6,8 +6,10 @@
|
||||
|
||||
|
||||
#include "ggml-sycl.h"
|
||||
#include <clocale>
|
||||
|
||||
int main() {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
ggml_backend_sycl_print_sycl_devices();
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
@@ -14,6 +15,8 @@
|
||||
#endif
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
params.escape = false;
|
||||
|
||||
|
||||
@@ -566,9 +566,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.16.0")
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.22.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "0a9e9008adb6031f9e8cf70dff4a3321")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1")
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
@@ -608,6 +608,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
|
||||
|
||||
set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
|
||||
@@ -648,7 +649,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
if (NOT SME_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
|
||||
@@ -656,10 +656,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_f16pmrx2_f32_neon.c
|
||||
${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
|
||||
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
|
||||
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2+sme2+fp16")
|
||||
endif()
|
||||
|
||||
if (NOT SVE_ENABLED MATCHES -1)
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
|
||||
#if defined(GGML_USE_OPENMP)
|
||||
#include <omp.h>
|
||||
#else
|
||||
#include <thread>
|
||||
#endif
|
||||
|
||||
#define TILE_M 16
|
||||
@@ -56,18 +58,40 @@ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_for(int n, const func_t& f) {
|
||||
inline void parallel_for(int n, const func_t & f) {
|
||||
if (n <= 0) {
|
||||
return;
|
||||
}
|
||||
#if defined(GGML_USE_OPENMP)
|
||||
#pragma omp parallel
|
||||
{
|
||||
int nth = omp_get_num_threads();
|
||||
int ith = omp_get_thread_num();
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
}
|
||||
#pragma omp parallel
|
||||
{
|
||||
int nth = omp_get_num_threads();
|
||||
int ith = omp_get_thread_num();
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
}
|
||||
#else
|
||||
f(0, n);
|
||||
int nth = std::thread::hardware_concurrency();
|
||||
if (nth <= 1) {
|
||||
f(0, n);
|
||||
return;
|
||||
}
|
||||
if (nth > n) {
|
||||
nth = n;
|
||||
}
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(nth);
|
||||
for (int ith = 0; ith < nth; ++ith) {
|
||||
threads.emplace_back([&f, n, ith, nth] {
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
});
|
||||
}
|
||||
for (auto & t : threads) {
|
||||
t.join();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -181,11 +181,11 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
|
||||
|
||||
const int16x8_t v_xylso = vec_mulo(v_xls, v_yl);
|
||||
const int16x8_t v_xylse = vec_mule(v_xls, v_yl);
|
||||
const int16x8_t v_xyl = vec_meadd(v_xls, v_yl, v_xylso);
|
||||
const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh);
|
||||
const int16x8_t v_xyhse = vec_mule(v_xhs, v_yh);
|
||||
const int16x8_t v_xyh = vec_meadd(v_xhs, v_yh, v_xyhso);
|
||||
|
||||
int16x8_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
|
||||
int16x8_t v_xy_ = v_xyl + v_xyh; v_xy_ += vec_reve(v_xy_);
|
||||
|
||||
const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_));
|
||||
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
|
||||
@@ -890,8 +890,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
|
||||
|
||||
const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
|
||||
const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
|
||||
const int32x4_t v_mins = v_minso + v_minse;
|
||||
const int32x4_t v_mins = vec_meadd(v_ysums, v_minsh, v_minso);
|
||||
sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
|
||||
|
||||
const uint8_t * scales = (const uint8_t *)utmp;
|
||||
@@ -1004,8 +1003,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);
|
||||
|
||||
const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
|
||||
const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
|
||||
const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
|
||||
const int32x4_t v_mins = vec_meadd(v_ysums, v_minsh, v_minsho);
|
||||
const int32_t mins = vec_hsum_i32x4(v_mins);
|
||||
|
||||
const uint8_t * scales = (const uint8_t *)utmp;
|
||||
@@ -1110,10 +1108,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
const int16x8_t v_scaleh = vec_unpackl(v_scale);
|
||||
|
||||
const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
|
||||
const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
|
||||
const int32x4_t v_minsl = vec_meadd(v_ysumsl, v_scalel, v_minslo);
|
||||
const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
|
||||
const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
|
||||
const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
|
||||
const int32x4_t v_minsh = vec_meadd(v_ysumsh, v_scaleh, v_minsho);
|
||||
const int32x4_t v_mins = vec_add(v_minsl, v_minsh);
|
||||
|
||||
const int32_t mins = vec_hsum_i32x4(v_mins);
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
||||
// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
||||
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
|
||||
@@ -20,6 +19,7 @@
|
||||
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.h"
|
||||
|
||||
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||
@@ -31,6 +31,7 @@
|
||||
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
||||
#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
|
||||
#include "kai_lhs_pack_f16pmrx2_f32_neon.h"
|
||||
|
||||
#include "kai_common.h"
|
||||
|
||||
@@ -309,24 +310,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
{
|
||||
/* SME GEMM */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,
|
||||
},
|
||||
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_f16pmrx2_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_pack_f16pmrx2_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_pack_f16pmrx2_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_void_fn10<kai_run_lhs_pack_f16pmrx2_f32_neon>,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
|
||||
@@ -16,27 +16,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i01 = blockIdx.y;
|
||||
for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
||||
|
||||
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
||||
const int64_t ib = ibx0 + i00/qk; // block index
|
||||
const int64_t iqs = (i00%qk)/qr; // quant index
|
||||
const int64_t iybs = i00 - i00%qk; // y block start index
|
||||
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
const int64_t ib = ibx0 + i00/qk; // block index
|
||||
const int64_t iqs = (i00%qk)/qr; // quant index
|
||||
const int64_t iybs = i00 - i00%qk; // y block start index
|
||||
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||
// dequantize
|
||||
float2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
|
||||
// dequantize
|
||||
float2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
|
||||
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
|
||||
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
|
||||
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -492,7 +492,7 @@ static void dequantize_block_cuda(const void * vx, dst_t * y,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
|
||||
const int64_t ne0203 = ne02*ne03;
|
||||
const uint3 ne02_fdv = init_fastdiv_values(ne02);
|
||||
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
|
||||
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
|
||||
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
|
||||
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
|
||||
}
|
||||
@@ -628,18 +628,18 @@ static __global__ void convert_unary(
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i01 = blockIdx.y;
|
||||
|
||||
const src_t * x = (const src_t *) vx;
|
||||
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
|
||||
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
|
||||
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
|
||||
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
|
||||
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
|
||||
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
|
||||
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -649,7 +649,7 @@ static void convert_unary_cuda(const void * vx, dst_t * y,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
|
||||
const int64_t ne0203 = ne02*ne03;
|
||||
const uint3 ne02_fdv = init_fastdiv_values(ne02);
|
||||
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
|
||||
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
|
||||
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
|
||||
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
|
||||
}
|
||||
|
||||
@@ -111,6 +111,44 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
|
||||
// Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true);
|
||||
|
||||
// Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
|
||||
// compile-time static_asserts even though the kernel guard prevents runtime execution.
|
||||
// nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
|
||||
return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
|
||||
}
|
||||
|
||||
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
||||
if (ampere_mma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
@@ -118,6 +156,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
|
||||
if (turing_mma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
||||
}
|
||||
if (amd_mfma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
||||
}
|
||||
if (amd_wmma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
|
||||
}
|
||||
@@ -130,6 +171,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
#elif defined(TURING_MMA_AVAILABLE)
|
||||
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
||||
#elif defined(VOLTA_MMA_AVAILABLE)
|
||||
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
@@ -205,15 +248,15 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_cols_per_thread() {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
return 1; // RDNA has a single column.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
return 1; // AMD has a single column per thread.
|
||||
#else
|
||||
return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
static __host__ int get_cols_per_warp(const int cc) {
|
||||
if (turing_mma_available(cc) || amd_wmma_available(cc)) {
|
||||
if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
|
||||
return 16;
|
||||
} else {
|
||||
// Volta
|
||||
@@ -241,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
|
||||
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
||||
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
|
||||
if constexpr (use_cp_async) {
|
||||
@@ -252,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
||||
|
||||
auto load = [&] __device__ (auto n) {
|
||||
const int stride_k = WARP_SIZE >> n;
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
const int stride_k = warp_size >> n;
|
||||
const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||
const int stride_i = WARP_SIZE / stride_k;
|
||||
const int stride_i = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
return;
|
||||
@@ -263,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
||||
break;
|
||||
@@ -271,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
|
||||
}
|
||||
@@ -287,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
} else {
|
||||
// TODO use ggml_cuda_memcpy_1
|
||||
auto load = [&] __device__ (const int n) {
|
||||
const int stride_k = WARP_SIZE >> n;
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
|
||||
const int stride_k = warp_size >> n;
|
||||
const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
|
||||
const int k0_stop = D2 - D2 % (1*stride_k);
|
||||
const int stride_i = WARP_SIZE / stride_k;
|
||||
const int stride_i = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
return;
|
||||
@@ -298,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
||||
break;
|
||||
@@ -306,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
|
||||
}
|
||||
@@ -324,18 +368,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_chec
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
|
||||
const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
if constexpr (use_cp_async) {
|
||||
static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
|
||||
static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
|
||||
static_assert(!oob_check, "OOB check incompatible with cp_async");
|
||||
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
|
||||
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
|
||||
constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
|
||||
constexpr int stride_j = nwarps * cols_per_warp;
|
||||
|
||||
const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
|
||||
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
||||
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
||||
|
||||
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
||||
@@ -357,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
|
||||
}
|
||||
}
|
||||
} else if constexpr (nbatch_fa < 2*WARP_SIZE) {
|
||||
constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
|
||||
} else if constexpr (nbatch_fa < 2*warp_size) {
|
||||
constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
|
||||
constexpr int stride_j = nwarps * cols_per_warp;
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
||||
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
||||
|
||||
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
|
||||
const int i = threadIdx.x % (warp_size/cols_per_warp);
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
|
||||
}
|
||||
@@ -390,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
|
||||
const int i = i0 + 2*threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
|
||||
@@ -428,7 +473,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const int jt,
|
||||
const int kb0,
|
||||
const int k_VKQ_sup) {
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = get_cols_per_thread();
|
||||
@@ -447,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const int k_VKQ_0 = kb0 * nbatch_fa;
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
||||
#else // Volta
|
||||
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
||||
@@ -500,13 +546,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -526,13 +572,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -585,12 +631,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = l % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
@@ -601,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset >= 4; offset >>= 1) {
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -611,12 +657,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = l % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
|
||||
KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
|
||||
} else {
|
||||
@@ -649,12 +695,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = (l/2) % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
@@ -666,6 +712,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
// Values per KQ column are spread across 4 threads:
|
||||
constexpr int offset_first = 2;
|
||||
constexpr int offset_last = 1;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
|
||||
constexpr int offset_first = 32;
|
||||
constexpr int offset_last = 16;
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
// Values per KQ column are spread across 2 threads:
|
||||
constexpr int offset_first = 16;
|
||||
@@ -677,7 +727,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
#pragma unroll
|
||||
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -687,12 +737,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = (l/2) % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
|
||||
KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
|
||||
} else {
|
||||
@@ -739,7 +789,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const half2 KQ_max_scale_h2 = make_half2(
|
||||
KQ_max_scale[0], KQ_max_scale[0]);
|
||||
#pragma unroll
|
||||
@@ -818,7 +868,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
|
||||
@@ -830,24 +880,38 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
|
||||
#if defined(LDMATRIX_TRANS_AVAILABLE)
|
||||
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
|
||||
// Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
|
||||
// Load with transposed addressing: 4 strided half loads.
|
||||
{
|
||||
const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
|
||||
const half * xs0_h = (const half *) xs0;
|
||||
const int stride_h = stride_tile_V * 2; // stride in half units
|
||||
half * A_h = (half *) A.x;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
|
||||
}
|
||||
}
|
||||
#else
|
||||
// TODO: Try to transpose tile_V when loading gmem to smem.
|
||||
// Use mma to transpose T_A_VKQ for RDNA.
|
||||
T_A_VKQ A_trans;
|
||||
load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
mma(A, A_trans, A_identity);
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
#endif // defined(LDMATRIX_TRANS_AVAILABLE)
|
||||
if constexpr (T_B_KQ::I == 8) {
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
} else {
|
||||
// Wide version of VKQ_C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -866,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
|
||||
}
|
||||
}
|
||||
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
|
||||
if constexpr (nstages <= 1) {
|
||||
__syncthreads(); // Only needed if tile_K == tile_V.
|
||||
@@ -879,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
tile_Q, tile_K, tile_V, tile_mask,
|
||||
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
@@ -899,7 +963,7 @@ template<> struct mma_tile_sizes<8> {
|
||||
using T_B_VKQ = tile< 8, 8, half2>; // column-major
|
||||
using T_C_VKQ = tile<16, 4, half2>; // row-major
|
||||
};
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
template<int ncols> struct mma_tile_sizes {
|
||||
using T_A_KQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2>; // column-major
|
||||
@@ -944,9 +1008,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int zt_gqa,
|
||||
const int kb0_start,
|
||||
const int kb0_stop) {
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
|
||||
using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
|
||||
@@ -986,7 +1051,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
||||
#else // Volta
|
||||
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
||||
@@ -1004,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
||||
const half2 scale_h2 = make_half2(scale, scale);
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
|
||||
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
||||
const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
|
||||
const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
|
||||
const int stride_jc = WARP_SIZE / stride_k;
|
||||
const int stride_jc = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
continue;
|
||||
@@ -1015,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
|
||||
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
|
||||
break;
|
||||
@@ -1027,7 +1092,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
|
||||
tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
|
||||
@@ -1035,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
@@ -1127,6 +1192,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// The partial sums are spread across 8/4 threads.
|
||||
constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
|
||||
constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// The partial sums are spread across 4 threads (wavefront64, 16 cols).
|
||||
constexpr int offset_first = 32;
|
||||
constexpr int offset_last = 16;
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
// The partial sums are spread across 2 threads.
|
||||
constexpr int offset_first = 16;
|
||||
@@ -1140,7 +1209,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
#pragma unroll
|
||||
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
||||
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
|
||||
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1189,7 +1258,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
||||
@@ -1249,7 +1318,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
|
||||
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
|
||||
const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
|
||||
const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
|
||||
const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
|
||||
@@ -1283,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// Warps with threadIdx.y % np != 0 must NOT return early.
|
||||
// All threads must return simultaneously to avoid race conditions with work on the next tile.
|
||||
|
||||
constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
|
||||
constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
|
||||
|
||||
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
||||
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
||||
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
|
||||
float2 meta[nmeta];
|
||||
#pragma unroll
|
||||
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
||||
meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
|
||||
meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
|
||||
}
|
||||
|
||||
float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
|
||||
@@ -1300,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
||||
if (offset < WARP_SIZE) {
|
||||
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
|
||||
if (offset < warp_size) {
|
||||
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1318,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
||||
if (offset < WARP_SIZE) {
|
||||
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
|
||||
if (offset < warp_size) {
|
||||
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1328,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// Write back combined meta data:
|
||||
#pragma unroll
|
||||
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
||||
if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
|
||||
if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
|
||||
// Combined KQ max scale + rowsum.
|
||||
meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
||||
meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
||||
}
|
||||
}
|
||||
|
||||
// Combined KQ max + rowsum.
|
||||
static_assert(cols_per_warp <= WARP_SIZE);
|
||||
if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
|
||||
static_assert(cols_per_warp <= warp_size);
|
||||
if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
||||
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
||||
}
|
||||
if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
|
||||
if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
||||
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
||||
}
|
||||
@@ -1388,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
|
||||
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
|
||||
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
||||
const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
|
||||
const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
|
||||
const int stride_jc = WARP_SIZE / stride_k;
|
||||
const int stride_jc = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
continue;
|
||||
@@ -1399,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
|
||||
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
|
||||
break;
|
||||
@@ -1417,7 +1486,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
float2 dstk_val = make_float2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
@@ -1453,7 +1522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
|
||||
jt, kb0_start, kb0_stop);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
|
||||
@@ -1480,7 +1549,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
||||
@@ -1508,10 +1577,18 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // defined(AMD_MFMA_AVAILABLE)
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
|
||||
constexpr int nwarps = nthreads / WARP_SIZE;
|
||||
constexpr int nwarps = nthreads / warp_size;
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
|
||||
@@ -1624,7 +1701,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
||||
}
|
||||
|
||||
template <int DKQ, int DV, int ncols1, int ncols2>
|
||||
@@ -1644,7 +1721,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
|
||||
|
||||
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
|
||||
const int nwarps = nthreads / warp_size_host;
|
||||
|
||||
constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
|
||||
|
||||
@@ -1694,7 +1772,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
}
|
||||
|
||||
launch_fattn<DV, ncols1, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -440,6 +440,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
// Use MFMA flash attention for CDNA (MI100+):
|
||||
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
|
||||
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
|
||||
// MMA vs tile crossover benchmarked on MI300X @ d32768:
|
||||
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)
|
||||
// hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%)
|
||||
if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) {
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
// Fall through to tile kernel for small effective batch sizes.
|
||||
}
|
||||
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
|
||||
@@ -668,7 +668,7 @@ namespace ggml_cuda_mma {
|
||||
|
||||
return ret;
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
tile<I, J/2, half2> ret;
|
||||
@@ -964,6 +964,34 @@ namespace ggml_cuda_mma {
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(RDNA4)
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// MFMA: FP16 input, FP32 accumulate, convert back to half2.
|
||||
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
|
||||
// Convert existing half2 accumulator to float for MFMA:
|
||||
floatx4_t acc_f32;
|
||||
{
|
||||
const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
acc_f32[i] = (float)acc_h[i];
|
||||
}
|
||||
}
|
||||
|
||||
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
|
||||
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
|
||||
acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);
|
||||
|
||||
// Convert back to half2:
|
||||
{
|
||||
halfx4_t result_h;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
result_h[i] = (_Float16)acc_f32[i];
|
||||
}
|
||||
reinterpret_cast<halfx4_t&>(D.x[0]) = result_h;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
|
||||
@@ -108,6 +108,8 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_q8_0_f32_l4_lm
|
||||
mul_mm_q6_k_f32_l4_lm
|
||||
mul_mm_q8_0_f32_8x4
|
||||
gemv_noshuffle_q4_1_f32
|
||||
gemm_noshuffle_q4_1_f32
|
||||
gemv_noshuffle_general_q8_0_f32
|
||||
mul
|
||||
norm
|
||||
|
||||
@@ -531,6 +531,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
|
||||
cl_kernel kernel_convert_block_q4_0_noshuffle;
|
||||
cl_kernel kernel_restore_block_q4_0_noshuffle;
|
||||
cl_kernel kernel_convert_block_q4_1_noshuffle;
|
||||
cl_kernel kernel_restore_block_q4_1_noshuffle;
|
||||
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
|
||||
cl_kernel kernel_mul_mv_q4_1_f32;
|
||||
@@ -683,7 +685,9 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_transpose_32;
|
||||
cl_kernel kernel_transpose_32_16;
|
||||
cl_kernel kernel_transpose_16;
|
||||
cl_kernel kernel_transpose_8_buf;
|
||||
cl_kernel kernel_transpose_16_buf;
|
||||
cl_kernel kernel_transpose_32_buf;
|
||||
cl_kernel kernel_transpose_16_4x1;
|
||||
|
||||
// Gemm and Gemv related programs, kernels, etc
|
||||
@@ -699,6 +703,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096;
|
||||
cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096;
|
||||
cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096;
|
||||
cl_kernel kernel_gemv_noshuffle_q4_1_f32;
|
||||
cl_kernel kernel_gemm_noshuffle_q4_1_f32;
|
||||
cl_kernel kernel_mul_mm_q8_0_f32_8x4;
|
||||
cl_kernel CL_mul_mat_vec_q8_0_f32;
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
@@ -893,6 +899,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
|
||||
@@ -2258,7 +2266,9 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_transpose_8_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_8_buf", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_transpose_16_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_buf", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_transpose_32_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_buf", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
@@ -2378,6 +2388,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_noshuffle_q4_1_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemm_noshuffle_q4_1_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemm_noshuffle_q4_1_f32.cl");
|
||||
#endif
|
||||
cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_1_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemv_noshuffle_q4_1_f32
|
||||
{
|
||||
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable ";
|
||||
if (backend_ctx->has_vector_subgroup_broadcast) {
|
||||
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
|
||||
}
|
||||
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemv_noshuffle_q4_1_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemv_noshuffle_q4_1_f32.cl");
|
||||
#endif
|
||||
|
||||
cl_program prog = build_program_from_source(
|
||||
backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_1_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q8_0_f32_8x4
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -2413,7 +2462,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
cl_program prog = build_program_from_source(
|
||||
backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q8_0_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
@@ -2923,6 +2972,82 @@ static void ggml_cl2_free(ggml_backend_t backend) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
// Runs a transpose kernel on `src` and stores the result in `dst`.
//
// The kernel writes into a function-static scratch buffer (arg 1) and the
// scratch contents are then copied into `dst`, so `src` and `dst` may refer to
// the same cl_mem (callers in this file transpose buffers in place this way).
//
//   backend_ctx - provides the OpenCL context and command queue
//   kernel      - transpose kernel taking (src, dst, stride, rows) as args 0..3
//   src, dst    - source and destination buffers
//   size        - number of bytes allocated for the scratch region and copied to dst
//   stride/rows - forwarded to the kernel and used as the 2D global work size
//   blocking    - when true, wait for the final copy to finish before returning
//
// NOTE(review): the scratch buffer is `static`, so it is shared by all calls;
// this presumably relies on callers using a single in-order queue - confirm.
static void transpose_2d(
        ggml_backend_opencl_context * backend_ctx,
        cl_kernel kernel,
        cl_mem src, cl_mem dst, size_t size,
        cl_int stride, cl_int rows,
        bool blocking = true
) {
    static ggml_cl_buffer buf;

    cl_int err;

    buf.allocate(backend_ctx->context, size);

    // Sub-buffer covering exactly the first `size` bytes of the scratch buffer.
    cl_mem trans;
    cl_buffer_region region;

    region.origin = 0;
    region.size = size;
    CL_CHECK((trans = clCreateSubBuffer(
        buf.buffer, CL_MEM_READ_WRITE,
        CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));

    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &src));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &trans));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &stride));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &rows));

    // NOTE(review): the local size is fixed at 64 in dimension 0; presumably
    // `stride` is always a multiple of 64 or the device accepts non-uniform
    // work-groups - confirm against the callers and the kernel.
    size_t local_size[3] = {64, 1, 1};
    size_t global_size[3] = {(size_t)stride, (size_t)rows, 1};
    CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL,
        global_size, local_size, 0, NULL, NULL));

    if (blocking) {
        // Wait on the copy's event so `dst` is complete when we return.
        cl_event evt;
        CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, &evt));
        CL_CHECK(clWaitForEvents(1, &evt));
        CL_CHECK(clReleaseEvent(evt));
    } else {
        CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, NULL));
    }

    // Drop our handle on the sub-buffer; the parent scratch buffer stays allocated.
    CL_CHECK(clReleaseMemObject(trans));
}
|
||||
|
||||
// Transposes `src` into `dst` using the context's `kernel_transpose_8_buf`
// kernel (presumably treating the data as 8-bit elements, per the kernel name).
// All other arguments are forwarded unchanged to transpose_2d().
static void transpose_2d_as_8b(
        ggml_backend_opencl_context * backend_ctx,
        cl_mem src, cl_mem dst, size_t size,
        cl_int stride, cl_int rows,
        bool blocking = true
) {
    cl_kernel transpose_kernel = backend_ctx->kernel_transpose_8_buf;
    transpose_2d(backend_ctx, transpose_kernel, src, dst, size, stride, rows, blocking);
}
|
||||
|
||||
// Transposes `src` into `dst` using the context's `kernel_transpose_16_buf`
// kernel (presumably treating the data as 16-bit elements, per the kernel name).
// All other arguments are forwarded unchanged to transpose_2d().
static void transpose_2d_as_16b(
        ggml_backend_opencl_context * backend_ctx,
        cl_mem src, cl_mem dst, size_t size,
        cl_int stride, cl_int rows,
        bool blocking = true
) {
    cl_kernel transpose_kernel = backend_ctx->kernel_transpose_16_buf;
    transpose_2d(backend_ctx, transpose_kernel, src, dst, size, stride, rows, blocking);
}
|
||||
|
||||
// Transposes `src` into `dst` using the context's `kernel_transpose_32_buf`
// kernel (presumably treating the data as 32-bit elements, per the kernel name).
// All other arguments are forwarded unchanged to transpose_2d().
static void transpose_2d_as_32b(
        ggml_backend_opencl_context * backend_ctx,
        cl_mem src, cl_mem dst, size_t size,
        cl_int stride, cl_int rows,
        bool blocking = true
) {
    cl_kernel transpose_kernel = backend_ctx->kernel_transpose_32_buf;
    transpose_2d(backend_ctx, transpose_kernel, src, dst, size, stride, rows, blocking);
}
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Tensor extra management
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -4271,7 +4396,15 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;
|
||||
|
||||
if (use_adreno_kernels(backend_ctx, tensor)) {
|
||||
kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle;
|
||||
}
|
||||
#else
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
|
||||
@@ -4287,6 +4420,22 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
|
||||
tensor->extra = extra;
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_kernels(backend_ctx, tensor)) {
|
||||
|
||||
int M = tensor->ne[1];
|
||||
int K = tensor->ne[0];
|
||||
|
||||
GGML_ASSERT(K % 32 == 0);
|
||||
|
||||
// Transpose q as ushort
|
||||
transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M);
|
||||
// Transpose d as ushort
|
||||
transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M);
|
||||
// Transpose m as ushort
|
||||
transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M);
|
||||
}
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
@@ -4795,6 +4944,53 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
if (tensor->type == GGML_TYPE_Q4_1) {
|
||||
ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra;
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_kernels(backend_ctx, tensor)) {
|
||||
static ggml_cl_buffer buf_trans_q;
|
||||
static ggml_cl_buffer buf_trans_m;
|
||||
static ggml_cl_buffer buf_trans_d;
|
||||
static ggml_cl_buffer buf_unpacked;
|
||||
|
||||
cl_int M = tensor->ne[1];
|
||||
cl_int K = tensor->ne[0];
|
||||
|
||||
GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0);
|
||||
|
||||
size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2;
|
||||
size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t);
|
||||
size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t);
|
||||
GGML_ASSERT(size_d + size_q + size_m == ggml_nbytes(tensor) && "Incorrect tensor size");
|
||||
|
||||
buf_trans_q.allocate(backend_ctx->context, size_q);
|
||||
buf_trans_m.allocate(backend_ctx->context, size_m);
|
||||
buf_trans_d.allocate(backend_ctx->context, size_d);
|
||||
buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor));
|
||||
|
||||
// transpose q, d, m back
|
||||
transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4);
|
||||
transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32);
|
||||
transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32);
|
||||
|
||||
cl_uchar mask_0F = 0x0F;
|
||||
cl_uchar mask_F0 = 0xF0;
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_noshuffle;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_m.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_unpacked.buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0));
|
||||
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
|
||||
CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL));
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
@@ -4886,8 +5082,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
|
||||
int ne00 = tensor->ne[0];
|
||||
int ne01 = tensor->ne[1];
|
||||
GGML_ASSERT(tensor->ne[2] == 1); // ???
|
||||
GGML_ASSERT(tensor->ne[3] == 1); // ???
|
||||
GGML_ASSERT(tensor->ne[2] == 1);
|
||||
GGML_ASSERT(tensor->ne[3] == 1);
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
|
||||
@@ -8371,6 +8567,180 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten
|
||||
CL_CHECK(clReleaseMemObject(D_sub_buffer));
|
||||
}
|
||||
|
||||
// Adreno path for dst = src0 (q4_1) x src1 (f32).
//
// src0->extra is interpreted as ggml_tensor_extra_cl_q4_1, i.e. the weights are
// expected as three separate device buffers: packed quants (q), scales (d) and
// mins (m). src1 and dst are regular ggml_tensor_extra_cl tensors addressed
// through (data_device, offset + view_offs).
//
// Two code paths, selected on the number of activation columns (dst->ne[1]):
//   * ne1 == 1 : GEMV. q is wrapped in a 1D image buffer of 32-bit texels and
//                the activations in a 1D RGBA-float image; runs
//                kernel_gemv_noshuffle_q4_1_f32.
//   * ne1  > 1 : GEMM. The activations are first converted/transposed into the
//                preallocated half-precision scratch buffer (N padded up to a
//                multiple of 8) by kernel_transpose_32_16, then
//                kernel_gemm_noshuffle_q4_1_f32 consumes the transposed image.
//
// All temporary sub-buffers/images created here are released before returning.
// When GGML_OPENCL_USE_ADRENO_KERNELS is not defined the function is a no-op.
static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(src1);
    GGML_ASSERT(src1->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
    ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;

    cl_ulong offset1 = extra1->offset + src1->view_offs;
    cl_ulong offsetd = extrad->offset + dst->view_offs;

    const int ne00 = src0->ne[0];
    const int ne01 = src0->ne[1];

    const int ne1 = dst->ne[1];

    GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);

    cl_context context = backend_ctx->context;
    cl_kernel kernel;

    cl_int err;
    cl_image_format img_fmt;
    cl_image_desc img_desc;
    cl_buffer_region region;

    int M = ne01;
    int N = ne1;
    int K = ne00;

    if (ne1 == 1) {
        cl_mem q_img = nullptr;
        cl_mem b_sub_buf = nullptr;
        cl_mem b_img = nullptr;

        // image for q
        // One 32-bit texel covers 8 nibbles. The product is formed in size_t so
        // large weight matrices cannot overflow int before the division.
        img_fmt = { CL_R, CL_UNSIGNED_INT32};
        memset(&img_desc, 0, sizeof(img_desc));
        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
        img_desc.image_width = (size_t)M * K / 2 / 4;
        img_desc.buffer = extra0_q4_1->q;
        CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));

        // subbuffer for activations
        region.origin = offset1;
        region.size = (size_t)K * N * sizeof(float);
        CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));

        // image for activations
        img_fmt = {CL_RGBA, CL_FLOAT};
        memset(&img_desc, 0, sizeof(img_desc));
        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
        img_desc.image_width = (size_t)K * N / 4;
        img_desc.buffer = b_sub_buf;
        CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));

        kernel = backend_ctx->kernel_gemv_noshuffle_q4_1_f32;

        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &q_img));
        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),   &extra0_q4_1->d));
        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra0_q4_1->m));
        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &b_img));
        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int),   &ne00));
        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int),   &ne01));

        // Each work-item along dim 0 produces two output rows; dim 1 carries the
        // 4 groups that split the K loop inside the kernel.
        size_t local_work_size[3] = {64, 4, 1};
        size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1};

        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);

        CL_CHECK(clReleaseMemObject(q_img));
        CL_CHECK(clReleaseMemObject(b_sub_buf));
        CL_CHECK(clReleaseMemObject(b_img));
    } else {
        cl_mem b_sub_buf = nullptr;
        cl_mem b_sub_buf_trans = nullptr;
        cl_mem b_img = nullptr;
        cl_mem b_img_trans = nullptr;

        // subbuffer for activations
        region.origin = offset1;
        region.size = (size_t)K * N * sizeof(float);
        CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));

        // image for activations
        img_fmt = {CL_RGBA, CL_FLOAT};
        memset(&img_desc, 0, sizeof(img_desc));
        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
        img_desc.image_width = (size_t)K * N / 4;
        img_desc.buffer = b_sub_buf;
        CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));

        // pad N to multiple of 8
        int extra_elements = N % 8;
        int padding = 0;
        if (extra_elements > 0){
            padding = 8 - extra_elements;
        }

        // subbuffer for transposed activations
        // (half precision, hence sizeof(float)/2 bytes per element)
        region.origin = 0;
        region.size = (size_t)K * (N + padding) * sizeof(float)/2;
        backend_ctx->prealloc_act_trans.allocate(context, region.size);
        CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));

        // image for transposed activations
        img_fmt = {CL_RGBA, CL_HALF_FLOAT};
        memset(&img_desc, 0, sizeof(img_desc));
        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
        img_desc.image_width = (size_t)K * (N + padding) / 4;
        img_desc.buffer = b_sub_buf_trans;
        CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err));

        // transpose activations
        int height_B = N/4;
        if (height_B == 0) {
            height_B = 1;
        }
        int width_B = K/4;
        int padded_height_B = (N + padding)/4;

        kernel = backend_ctx->kernel_transpose_32_16;
        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img));
        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans));
        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_B));
        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_B));
        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &padded_height_B));

        size_t local_work_size_t[2] = { 1, 16 };
        size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B };
        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst);

        // gemm
        kernel = backend_ctx->kernel_gemm_noshuffle_q4_1_f32;
        int padded_N = N + padding;

        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0_q4_1->q));
        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),   &extra0_q4_1->d));
        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra0_q4_1->m));
        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &b_img_trans));
        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int),   &ne01));
        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int),   &padded_N));
        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int),   &ne00));
        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int),   &ne1));

        // One work-item per 8 output columns (dim 0) x 4 output rows (dim 1).
        size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1};
        size_t local_work_size[3] = {1, 128, 1};

        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);

        CL_CHECK(clReleaseMemObject(b_sub_buf));
        CL_CHECK(clReleaseMemObject(b_sub_buf_trans));
        CL_CHECK(clReleaseMemObject(b_img));
        CL_CHECK(clReleaseMemObject(b_img_trans));
    }
#else
    GGML_UNUSED(backend);
    GGML_UNUSED(src0);
    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
#endif
}
|
||||
|
||||
static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
GGML_ASSERT(src0);
|
||||
@@ -8736,6 +9106,16 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
int padding;
|
||||
// <--------------------------------------------> //
|
||||
|
||||
// NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require
|
||||
// a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that
|
||||
// limit, so the check is omitted.
|
||||
|
||||
// q4_1 x fp32
|
||||
if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) {
|
||||
ggml_cl_mul_mat_q4_1_f32_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
// q8_0 x fp32
|
||||
if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 &&
|
||||
enable_adreno_trans_weight(backend_ctx, src0)) {
|
||||
|
||||
@@ -199,6 +199,58 @@ kernel void kernel_restore_block_q4_1(
|
||||
}
|
||||
}
|
||||
|
||||
// Splits an array of block_q4_1 into three separate streams, one work-item
// (get_global_id(0)) per block:
//   dst_d[id] <- block scale (b->d)
//   dst_m[id] <- block min   (b->m)
//   dst_q     <- QK4_1/2 bytes of repacked quants for the block
// Repacking: for each byte pair (x0, x1) = (qs[2*i], qs[2*i+1]) the two low
// nibbles are packed into q[i] and the two high nibbles into q[i + QK4_1/4],
// so the first half of a block's bytes holds the low nibbles and the second
// half the high nibbles.
kernel void kernel_convert_block_q4_1_noshuffle(
|
||||
global struct block_q4_1 * src0,
|
||||
global uchar * dst_q,
|
||||
global half * dst_d,
|
||||
global half * dst_m
|
||||
) {
|
||||
global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
|
||||
global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
|
||||
global half * d = (global half *) dst_d + get_global_id(0);
|
||||
global half * m = (global half *) dst_m + get_global_id(0);
|
||||
|
||||
*d = b->d;
|
||||
*m = b->m;
|
||||
|
||||
for (int i = 0; i < QK4_1/4; ++i) {
|
||||
uchar x0 = b->qs[2*i + 0];
|
||||
uchar x1 = b->qs[2*i + 1];
|
||||
|
||||
// low nibbles of the pair -> first half, high nibbles -> second half
q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
|
||||
q[i + QK4_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
|
||||
|
||||
// NOTE(review): the guard compares the global id against 65536*4096, so this
// printf is not expected to fire for ordinary launch sizes. It reads like a
// compiler workaround rather than real diagnostics - presumably carried over
// from the q4_0 noshuffle conversion kernel. Confirm on Adreno hardware before
// removing it (or the convert_uchar() wrapping above).
#ifdef ADRENO_GPU
|
||||
if (get_global_id(0) == 65536*4096) {
|
||||
printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuilds block_q4_1 records from separate quant / scale / min streams, one
// work-item (get_global_id(0)) per block:
//   dst[id].d <- src_d[id], dst[id].m <- src_m[id]
//   dst[id].qs is reassembled from the QK4_1/2 bytes at src_q: byte i holds two
//   low nibbles and byte i + QK4_1/4 holds two high nibbles; they are
//   recombined into qs[2*i] and qs[2*i + 1].
// mask_0F / mask_F0 are the nibble masks; the visible host call site passes
// 0x0F and 0xF0.
// NOTE(review): passing the masks as kernel arguments instead of using literals
// is presumably deliberate (compiler workaround?) - confirm before inlining them.
kernel void kernel_restore_block_q4_1_noshuffle(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
global half * src_m,
|
||||
global struct block_q4_1 * dst,
|
||||
uchar mask_0F,
|
||||
uchar mask_F0
|
||||
) {
|
||||
global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
|
||||
global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
|
||||
global half * d = (global half *) src_d + get_global_id(0);
|
||||
global half * m = (global half *) src_m + get_global_id(0);
|
||||
|
||||
b->d = *d;
|
||||
b->m = *m;
|
||||
for (int i = 0; i < QK4_1/4; ++i) {
|
||||
// x0: packed low nibbles, x1: packed high nibbles of the original byte pair
uchar x0 = q[i + 0 ] ;
|
||||
uchar x1 = q[i + QK4_1/4];
|
||||
|
||||
b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
|
||||
b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_mxfp4
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
// half-precision arithmetic is used throughout the kernel below
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// The Qualcomm sub-group-size extension is vendor specific, so it is enabled
// only when the compiler advertises it. ADRENO_GPU and REQD_SUBGROUP_SIZE_128
// are defined only in that case.
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
|
||||
|
||||
// Affine q4_1 dequantization of the 4-bit field at bit offset `shift` of the
// packed 16-bit word `w`: nibble * scale + min.
#define Q4_1_GEMM_DEQ(w, shift, s, mn) \
    ((((w) & (0x000F << (shift))) >> (shift)) * (s) + (mn))

// One K step of the inner loop. Uses the kernel locals by name: loads the 8
// activations for k = i + j into `act`, dequantizes nibble j (bit offset
// `shift`) of the four packed words and accumulates into acc0..acc3.
#define Q4_1_GEMM_STEP(j, shift) \
    act.s0123 = read_imageh(src1, tile_n*2 + (i + (j))*(n_vec)); \
    act.s4567 = read_imageh(src1, tile_n*2 + (i + (j))*(n_vec) + 1); \
    w.s0 = Q4_1_GEMM_DEQ(packed.s0, shift, d4.s0, m4.s0); \
    w.s1 = Q4_1_GEMM_DEQ(packed.s1, shift, d4.s1, m4.s1); \
    w.s2 = Q4_1_GEMM_DEQ(packed.s2, shift, d4.s2, m4.s2); \
    w.s3 = Q4_1_GEMM_DEQ(packed.s3, shift, d4.s3, m4.s3); \
    acc0 += act * w.s0; \
    acc1 += act * w.s1; \
    acc2 += act * w.s2; \
    acc3 += act * w.s3;

// Stores component `col` of the four accumulators as one float4 at dst + idx
// and advances idx by one output column (m floats). Skipped, without
// advancing, once the store would run past the unpadded m * n_no_padding
// outputs.
#define Q4_1_GEMM_STORE(col) \
    if (idx + 3 < m*n_no_padding) { \
        vstore4((float4)(acc0.col, acc1.col, acc2.col, acc3.col), 0, dst + idx); \
        idx += m; \
    }

#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif
// GEMM for q4_1 weights (split q / d / m buffers) times half-precision
// activations read from a 1D image. Each work-item accumulates a tile of
// 4 consecutive output rows (get_global_id(1)*4 ...) by 8 consecutive output
// columns (get_global_id(0)*8 ...) in half8 accumulators and writes it to dst
// as float.
//
//   src0_q        packed nibbles; one ushort carries 4 consecutive k values
//   src0_d/src0_m per-block scale / min, one per 32 k values
//   src1          activations as half4 texels, n (padded) values per k
//   dst, offsetd  f32 output buffer and byte offset into it
//   m, n, k       rows, padded columns, reduction length
//   n_no_padding  real number of output columns (bounds the stores)
kernel void kernel_gemm_noshuffle_q4_1_f32(
        global const ushort * src0_q,
        global const half  * src0_d,
        global const half  * src0_m,
        read_only image1d_buffer_t src1,
        global float * dst,
        ulong offsetd,
        int m,
        int n,
        int k,
        int n_no_padding
) {
    dst = (global float *)((global char *)dst + offsetd);

    // texels (half4) per k row of the activation image
    int n_vec = n >> 2;

    int tile_n = get_global_id(0);   // 8-column tile index
    int tile_m = get_global_id(1);   // 4-row tile index
    int row0   = tile_m << 2;        // first output row of this tile

    half8 acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
    half8 act;
    half4 w;

    global const ushort * q_ptr = src0_q + row0;
    global const half   * d_ptr = src0_d + row0;
    global const half   * m_ptr = src0_m + row0;

    for (int i = 0; i < k; i += 4) {
        // 4 rows x 4 nibbles, plus the scale/min of the block containing i
        ushort4 packed = vload4(0, q_ptr + (i/4)*(m));
        half4   d4     = vload4(0, d_ptr + (i/32)*(m));
        half4   m4     = vload4(0, m_ptr + (i/32)*(m));

        Q4_1_GEMM_STEP(0,  0)
        Q4_1_GEMM_STEP(1,  4)
        Q4_1_GEMM_STEP(2,  8)
        Q4_1_GEMM_STEP(3, 12)
    }

    int idx = (tile_n<<3)*m + (tile_m<<2);

    Q4_1_GEMM_STORE(s0)
    Q4_1_GEMM_STORE(s1)
    Q4_1_GEMM_STORE(s2)
    Q4_1_GEMM_STORE(s3)
    Q4_1_GEMM_STORE(s4)
    Q4_1_GEMM_STORE(s5)
    Q4_1_GEMM_STORE(s6)
    Q4_1_GEMM_STORE(s7)
}

#undef Q4_1_GEMM_STORE
#undef Q4_1_GEMM_STEP
#undef Q4_1_GEMM_DEQ
|
||||
@@ -121,7 +121,7 @@
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
__kernel void kernel_gemv_noshuffle(
|
||||
__kernel void kernel_gemv_noshuffle_q8_0_f32(
|
||||
__read_only image1d_buffer_t src0_q, // quantized A
|
||||
global half * src0_d, // A scales
|
||||
__read_only image1d_buffer_t src1, // B
|
||||
|
||||
@@ -0,0 +1,283 @@
|
||||
// Extensions used by the q4_1 GEMV kernel in this file: half types and
// sub-group built-ins (sub_group_broadcast, get_sub_group_local_id).
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
// Vendor extension: enabled, and ADRENO_GPU / REQD_SUBGROUP_SIZE_64 defined,
// only when the compiler advertises cl_qcom_reqd_sub_group_size.
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#endif
|
||||
|
||||
// Quant block length used as the K-loop granularity in the kernel below.
// NOTE(review): named QK4_0 although the kernel handles q4_1 - presumably both
// block sizes are 32; confirm and consider renaming.
#define QK4_0 32
|
||||
// Number of groups along local dimension 1 that split the K loop (loop stride
// and BLOCK_STRIDE_A factor in the kernel).
#define NSUBGROUPS 4
|
||||
// Work-items per group along dimension 0; sizes the local reduction buffer.
#define SUBGROUP_SIZE 64
|
||||
|
||||
// ---------------------------------------------------------------------------
// q4_1 "noshuffle" dequantize-and-accumulate helpers for the GEMV kernel.
//
// bits4 is a ushort8 holding packed nibbles for two output rows: even
// components (s0, s2, s4, s6) feed total_sums.s0, odd components (s1, s3, s5,
// s7) feed total_sums.s1. Every ushort carries four 4-bit quants (bit offsets
// 0, 4, 8, 12) that are dequantized as nibble * scale + minv and multiplied by
// an activation obtained through sub_group_broadcast().
//
// The *_hi macros take activations from sub-group lanes 0 and 1 and also
// DECLARE the temporary `shared_y`; the *_lo macros take lanes 2 and 3 and
// reuse that declaration, so a *_lo macro must follow its *_hi counterpart in
// the same scope.
// ---------------------------------------------------------------------------

// nibble at bit offset `shift` of packed word q, dequantized with (s, m)
#define Q4_1_NS_DEQ(q, shift, s, m) \
    ((((q) & (0x000F << (shift))) >> (shift)) * (s) + (m))

// --- scalar broadcast variant: one float is broadcast per nibble position ---

// broadcast activation `yc` from `lane`, accumulate one nibble of both rows
#define Q4_1_NS_STEP_1(total_sums, q0, q1, shift, scale, minv, yc, lane) \
    shared_y = sub_group_broadcast(yc, lane); \
    total_sums.s0 += Q4_1_NS_DEQ(q0, shift, scale.s0, minv.s0) * shared_y; \
    total_sums.s1 += Q4_1_NS_DEQ(q1, shift, scale.s1, minv.s1) * shared_y;

// all four nibbles of one packed word per row against activations ya..yd
#define Q4_1_NS_WORD_1(total_sums, q0, q1, scale, minv, ya, yb, yc, yd, lane) \
    Q4_1_NS_STEP_1(total_sums, q0, q1,  0, scale, minv, ya, lane) \
    Q4_1_NS_STEP_1(total_sums, q0, q1,  4, scale, minv, yb, lane) \
    Q4_1_NS_STEP_1(total_sums, q0, q1,  8, scale, minv, yc, lane) \
    Q4_1_NS_STEP_1(total_sums, q0, q1, 12, scale, minv, yd, lane)

// 16 k values per row: words s0..s3 use lane_a, words s4..s7 use lane_b
#define Q4_1_NS_BODY_1(total_sums, bits4, scale, minv, y, lane_a, lane_b) \
    Q4_1_NS_WORD_1(total_sums, bits4.s0, bits4.s1, scale, minv, y.s0, y.s1, y.s2, y.s3, lane_a) \
    Q4_1_NS_WORD_1(total_sums, bits4.s2, bits4.s3, scale, minv, y.s4, y.s5, y.s6, y.s7, lane_a) \
    Q4_1_NS_WORD_1(total_sums, bits4.s4, bits4.s5, scale, minv, y.s0, y.s1, y.s2, y.s3, lane_b) \
    Q4_1_NS_WORD_1(total_sums, bits4.s6, bits4.s7, scale, minv, y.s4, y.s5, y.s6, y.s7, lane_b)

#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \
    float shared_y; \
    Q4_1_NS_BODY_1(total_sums, bits4, scale, minv, y, 0, 1)

#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \
    Q4_1_NS_BODY_1(total_sums, bits4, scale, minv, y, 2, 3)

// --- vector broadcast variant: a whole float8 is broadcast per lane ---

// all four nibbles of packed word q accumulated into `acc`
#define Q4_1_NS_WORD_8(acc, q, s, m, ya, yb, yc, yd) \
    acc += Q4_1_NS_DEQ(q,  0, s, m) * ya; \
    acc += Q4_1_NS_DEQ(q,  4, s, m) * yb; \
    acc += Q4_1_NS_DEQ(q,  8, s, m) * yc; \
    acc += Q4_1_NS_DEQ(q, 12, s, m) * yd;

// one lane's 8 activations: row 0 from words (qa0, qa1), row 1 from (qb0, qb1)
#define Q4_1_NS_LANE_8(total_sums, qa0, qa1, qb0, qb1, scale, minv, y, lane) \
    shared_y = sub_group_broadcast(y, lane); \
    Q4_1_NS_WORD_8(total_sums.s0, qa0, scale.s0, minv.s0, shared_y.s0, shared_y.s1, shared_y.s2, shared_y.s3) \
    Q4_1_NS_WORD_8(total_sums.s0, qa1, scale.s0, minv.s0, shared_y.s4, shared_y.s5, shared_y.s6, shared_y.s7) \
    Q4_1_NS_WORD_8(total_sums.s1, qb0, scale.s1, minv.s1, shared_y.s0, shared_y.s1, shared_y.s2, shared_y.s3) \
    Q4_1_NS_WORD_8(total_sums.s1, qb1, scale.s1, minv.s1, shared_y.s4, shared_y.s5, shared_y.s6, shared_y.s7)

#define Q4_1_NS_BODY_8(total_sums, bits4, scale, minv, y, lane_a, lane_b) \
    Q4_1_NS_LANE_8(total_sums, bits4.s0, bits4.s2, bits4.s1, bits4.s3, scale, minv, y, lane_a) \
    Q4_1_NS_LANE_8(total_sums, bits4.s4, bits4.s6, bits4.s5, bits4.s7, scale, minv, y, lane_b)

#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \
    float8 shared_y; \
    Q4_1_NS_BODY_8(total_sums, bits4, scale, minv, y, 0, 1)

#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \
    Q4_1_NS_BODY_8(total_sums, bits4, scale, minv, y, 2, 3)
|
||||
|
||||
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
// GEMV for q4_1 weights against a single f32 activation vector.
//
//   src0_q        packed quants as an image of 32-bit texels
//   src0_d/src0_m per-block scales / mins, read as half2 (two rows at once)
//   src1          activations as an image of float4 texels
//   dst, offsetd  f32 output buffer and byte offset into it
//   ne00, ne01    reduction length (K) and number of rows (M)
//
// Every work-item along dimension 0 produces two output values (dst[2*gid],
// dst[2*gid + 1]). The groups along local dimension 1 each process every
// NSUBGROUPS-th quant block of K; their partial sums are combined through
// local memory and written by the group with local id 0.
kernel void kernel_gemv_noshuffle_q4_1_f32(
        read_only image1d_buffer_t src0_q,
        global half2 * src0_d,
        global half2 * src0_m,
        read_only image1d_buffer_t src1,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01)
{
    const uint   wave = get_local_id(1);
    const uint   col  = get_global_id(0);
    const ushort lane = get_sub_group_local_id();

    const uint n_k = ne00;
    const uint n_m = ne01;

    const uint line_stride  = n_m / 2;
    const uint block_stride = NSUBGROUPS * n_m;

    uint4  packed;
    half2  d2;
    half2  m2;
    float8 act;

    float2 acc = (float2)(0.0f);

    // walk K one quant block at a time, NSUBGROUPS blocks apart per group
    for (uint k = wave; k < (n_k / QK4_0); k += NSUBGROUPS) {
        const uint sm_idx = col + k * line_stride;
        d2 = src0_d[sm_idx];   // scales of the two rows handled by this work-item
        m2 = src0_m[sm_idx];   // mins of the same two rows

        // only the first 4 lanes fetch activations (8 floats each); the
        // accumulate macros distribute them with sub_group_broadcast
        if (lane < 4) {
            const uint act_idx = lane * 2 + k * 8;
            act.s0123 = read_imagef(src1, act_idx);
            act.s4567 = read_imagef(src1, 1 + act_idx);
        }

        const uint q_idx = col + k * block_stride;

        // first half of the block's packed weights
        packed.s0 = read_imageui(src0_q, q_idx + line_stride * 0).x;
        packed.s1 = read_imageui(src0_q, q_idx + line_stride * 1).x;
        packed.s2 = read_imageui(src0_q, q_idx + line_stride * 2).x;
        packed.s3 = read_imageui(src0_q, q_idx + line_stride * 3).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
        dequantizeBlockAccum_ns_sgbroadcast_8_hi(acc, as_ushort8(packed), d2, m2, act);
#else
        dequantizeBlockAccum_ns_sgbroadcast_1_hi(acc, as_ushort8(packed), d2, m2, act);
#endif // VECTOR_SUB_GROUP_BROADCAT

        // second half of the block's packed weights
        packed.s0 = read_imageui(src0_q, q_idx + line_stride * 4).x;
        packed.s1 = read_imageui(src0_q, q_idx + line_stride * 5).x;
        packed.s2 = read_imageui(src0_q, q_idx + line_stride * 6).x;
        packed.s3 = read_imageui(src0_q, q_idx + line_stride * 7).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
        dequantizeBlockAccum_ns_sgbroadcast_8_lo(acc, as_ushort8(packed), d2, m2, act);
#else
        dequantizeBlockAccum_ns_sgbroadcast_1_lo(acc, as_ushort8(packed), d2, m2, act);
#endif // VECTOR_SUB_GROUP_BROADCAT
    }

    // Cross-group reduction through local memory; the buffer has room for the
    // partial sums of groups 1..3 only (group 0 keeps its sum in registers).
    local float2 partial[SUBGROUP_SIZE * 3];
    if (wave >= 1 && wave <= 3) {
        partial[SUBGROUP_SIZE * (wave - 1) + lane] = acc;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    if (wave == 0) {
        acc += partial[SUBGROUP_SIZE * 0 + lane];
        acc += partial[SUBGROUP_SIZE * 1 + lane];
        acc += partial[SUBGROUP_SIZE * 2 + lane];

        // two consecutive outputs per work-item
        dst = (global float*)((global char*)dst + offsetd);
        vstore2(acc, 0, &(dst[col * 2]));
    }
}
|
||||
@@ -44,6 +44,19 @@ kernel void kernel_transpose_16_4x1(
|
||||
write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3));
|
||||
}
|
||||
|
||||
// Buffer-based transpose of 8-bit elements.
// Work-item (x, y) copies input[y*ldi + x] to output[x*ldo + y]; ldi and ldo
// are the leading dimensions (in elements) of the input and output buffers.
kernel void kernel_transpose_8_buf(
    global const uchar * input,
    global uchar * output,
    const int ldi,
    const int ldo
) {
    const int col = get_global_id(0);
    const int row = get_global_id(1);

    const uchar value = input[row*ldi + col];
    output[col*ldo + row] = value;
}
|
||||
|
||||
// Transpose treating each element as 16-bit using buffer
|
||||
kernel void kernel_transpose_16_buf(
|
||||
global const ushort * input,
|
||||
@@ -57,6 +70,19 @@ kernel void kernel_transpose_16_buf(
|
||||
output[x*ldo + y] = input[y*ldi + x];
|
||||
}
|
||||
|
||||
// Buffer-based transpose of 32-bit elements.
// Work-item (x, y) copies input[y*ldi + x] to output[x*ldo + y]; ldi and ldo
// are the leading dimensions (in elements) of the input and output buffers.
kernel void kernel_transpose_32_buf(
    global const uint * input,
    global uint * output,
    const int ldi,
    const int ldo
) {
    const int col = get_global_id(0);
    const int row = get_global_id(1);

    const uint value = input[row*ldi + col];
    output[col*ldo + row] = value;
}
|
||||
|
||||
// 32-bit transpose, loading/storing a 4x4 tile of elements
|
||||
kernel void kernel_transpose_32(
|
||||
__read_only image1d_buffer_t input,
|
||||
|
||||
@@ -590,6 +590,7 @@ struct vk_device_struct {
|
||||
vk_queue transfer_queue;
|
||||
bool single_queue;
|
||||
bool support_async;
|
||||
bool async_use_transfer_queue;
|
||||
uint32_t subgroup_size;
|
||||
uint32_t subgroup_size_log2;
|
||||
uint32_t shader_core_count;
|
||||
@@ -1858,6 +1859,10 @@ struct ggml_backend_vk_context {
|
||||
|
||||
vk_context_ref compute_ctx;
|
||||
|
||||
vk_context_ref transfer_ctx;
|
||||
vk_semaphore transfer_semaphore;
|
||||
uint64_t transfer_semaphore_last_submitted {};
|
||||
|
||||
std::vector<vk_context_ref> tensor_ctxs;
|
||||
|
||||
std::vector<vk::DescriptorPool> descriptor_pools;
|
||||
@@ -1866,6 +1871,7 @@ struct ggml_backend_vk_context {
|
||||
uint32_t pipeline_descriptor_set_requirements {};
|
||||
|
||||
vk_command_pool compute_cmd_pool;
|
||||
vk_command_pool transfer_cmd_pool;
|
||||
|
||||
// number of additional consecutive nodes that are being fused with the
|
||||
// node currently being processed
|
||||
@@ -5391,13 +5397,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
ggml_vk_load_shaders(device);
|
||||
|
||||
const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN;
|
||||
|
||||
if (!device->single_queue) {
|
||||
const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
|
||||
ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
|
||||
|
||||
device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
|
||||
} else {
|
||||
// TODO: Use pointer or reference to avoid copy
|
||||
device->transfer_queue.copyFrom(device->compute_queue);
|
||||
device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
|
||||
|
||||
device->async_use_transfer_queue = false;
|
||||
}
|
||||
|
||||
device->buffer_type = {
|
||||
@@ -5871,6 +5883,15 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
|
||||
ctx->almost_ready_fence = ctx->device->device.createFence({});
|
||||
|
||||
ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
|
||||
if (ctx->device->async_use_transfer_queue) {
|
||||
vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
|
||||
vk::SemaphoreCreateInfo ci{};
|
||||
ci.setPNext(&tci);
|
||||
ctx->transfer_semaphore.s = ctx->device->device.createSemaphore(ci);
|
||||
ctx->transfer_semaphore.value = 0;
|
||||
|
||||
ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
|
||||
}
|
||||
|
||||
if (vk_perf_logger_enabled) {
|
||||
ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
|
||||
@@ -6419,6 +6440,47 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
|
||||
subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
|
||||
}
|
||||
|
||||
// Returns the compute context cached on `ctx`. When the cached reference has
// expired, a new context is created on the compute command pool, cached, and
// begun. A newly created context also waits on the transfer semaphore if the
// transfer queue is in use and its semaphore value has advanced past the value
// last recorded in transfer_semaphore_last_submitted.
static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
    if (ctx->compute_ctx.expired()) {
        vk_context fresh = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);

        ctx->compute_ctx = fresh;
        ggml_vk_ctx_begin(ctx->device, fresh);

        const bool transfer_pending =
            ctx->device->async_use_transfer_queue &&
            ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value;

        if (transfer_pending) {
            fresh->s->wait_semaphores.push_back(ctx->transfer_semaphore);
            // Remember the value so later contexts do not wait on it again.
            ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
        }

        return fresh;
    }

    return ctx->compute_ctx.lock();
}
|
||||
|
||||
// Submit any pending transfer queue work and signal the transfer semaphore.
|
||||
// The next compute context created via ggml_vk_get_compute_ctx will wait on this semaphore.
|
||||
// Returns true if work was submitted.
|
||||
static bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) {
|
||||
if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
vk_context cpy_ctx = ctx->transfer_ctx.lock();
|
||||
ggml_vk_ctx_end(cpy_ctx);
|
||||
|
||||
for (auto& cpy : cpy_ctx->in_memcpys) {
|
||||
memcpy(cpy.dst, cpy.src, cpy.n);
|
||||
}
|
||||
|
||||
ctx->transfer_semaphore.value++;
|
||||
cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore);
|
||||
|
||||
ggml_vk_submit(cpy_ctx, {});
|
||||
ctx->transfer_ctx.reset();
|
||||
return true;
|
||||
}
|
||||
|
||||
static size_t ggml_vk_align_size(size_t width, size_t align) {
|
||||
VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
|
||||
return CEIL_DIV(width, align) * align;
|
||||
@@ -7512,6 +7574,18 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
||||
return false;
|
||||
}
|
||||
|
||||
if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
|
||||
// Intel Windows proprietary driver tuning
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
switch (src0_type) {
|
||||
// From tests on A770 Linux, may need more tuning
|
||||
case GGML_TYPE_Q4_0:
|
||||
@@ -12529,15 +12603,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
}
|
||||
}
|
||||
|
||||
vk_context compute_ctx;
|
||||
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
|
||||
{
|
||||
// This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers
|
||||
@@ -13055,6 +13121,9 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
||||
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
|
||||
|
||||
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
|
||||
if (ctx->device->async_use_transfer_queue) {
|
||||
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
|
||||
ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
|
||||
@@ -13116,6 +13185,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
||||
ctx->descriptor_sets.clear();
|
||||
|
||||
ctx->compute_cmd_pool.destroy(ctx->device->device);
|
||||
if (ctx->device->async_use_transfer_queue) {
|
||||
ctx->device->device.destroySemaphore(ctx->transfer_semaphore.s);
|
||||
|
||||
ctx->transfer_cmd_pool.destroy(ctx->device->device);
|
||||
}
|
||||
if (vk_perf_logger_enabled) {
|
||||
ctx->perf_logger->print_timings(true);
|
||||
}
|
||||
@@ -13387,34 +13461,38 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
|
||||
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context compute_ctx;
|
||||
vk_context cpy_ctx;
|
||||
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
if (ctx->device->async_use_transfer_queue) {
|
||||
if (ctx->transfer_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
|
||||
ctx->transfer_ctx = cpy_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, cpy_ctx);
|
||||
} else {
|
||||
cpy_ctx = ctx->transfer_ctx.lock();
|
||||
}
|
||||
} else {
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
cpy_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
}
|
||||
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
|
||||
bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size);
|
||||
bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size);
|
||||
|
||||
if (!ret) {
|
||||
ggml_vk_ensure_sync_staging_buffer(ctx, size);
|
||||
ggml_vk_sync_buffers(nullptr, compute_ctx);
|
||||
ggml_vk_sync_buffers(nullptr, cpy_ctx);
|
||||
|
||||
vk::BufferCopy buffer_cpy;
|
||||
buffer_cpy.srcOffset = 0;
|
||||
buffer_cpy.dstOffset = dst_offset;
|
||||
buffer_cpy.size = size;
|
||||
|
||||
compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
|
||||
deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys);
|
||||
cpy_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
|
||||
deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys);
|
||||
ggml_vk_synchronize(ctx);
|
||||
}
|
||||
}
|
||||
@@ -13426,16 +13504,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
||||
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context compute_ctx;
|
||||
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
@@ -13458,31 +13527,60 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
||||
}
|
||||
}
|
||||
|
||||
// Attempts an asynchronous copy of `src` into `dst` on the destination backend.
//
// Returns false without recording anything when the copy cannot be done here:
//   - `dst` is not in the destination backend's default buffer type,
//   - `src` is a Vulkan buffer on a different device than `dst`,
//   - `src` is a host buffer for which ggml_vk_host_get finds no pinned buffer,
//   - `src` is neither a Vulkan buffer nor a host buffer.
// Otherwise the copy is recorded and the function returns true (Vulkan source)
// or the result of ggml_vk_buffer_write_async (host source).
//
// `backend_src` is not used.
static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
    VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;

    if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) {
        return false;
    }

    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
    vk_buffer dst_buf = dst_buf_ctx->dev_buffer;

    if (ggml_backend_buffer_is_vk(src->buffer)) {
        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;

        // Async copy only works within the same device
        if (src_buf_ctx->dev_buffer->device != dst_buf->device) {
            return false;
        }

        vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);

        ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs,
                                  src_buf_ctx->dev_buffer, vk_tensor_offset(src) + src->view_offs,
                                  ggml_nbytes(src));
        return true;
    }

    if (ggml_backend_buffer_is_host(src->buffer)) {
        // Only proceed when the source memory resolves to a pinned buffer.
        vk_buffer pinned_buf = nullptr;
        size_t pinned_offset = 0;
        ggml_vk_host_get(ctx->device, src->data, pinned_buf, pinned_offset);
        if (pinned_buf == nullptr) {
            return false;
        }

        // Record on the transfer context when the transfer queue is in use,
        // otherwise on the compute context.
        vk_context cpy_ctx;
        if (ctx->device->async_use_transfer_queue) {
            if (ctx->transfer_ctx.expired()) {
                cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
                ctx->transfer_ctx = cpy_ctx;
                ggml_vk_ctx_begin(ctx->device, cpy_ctx);
            } else {
                cpy_ctx = ctx->transfer_ctx.lock();
            }
        } else {
            cpy_ctx = ggml_vk_get_compute_ctx(ctx);
        }

        return ggml_vk_buffer_write_async(cpy_ctx, dst_buf,
                                          vk_tensor_offset(dst) + dst->view_offs,
                                          src->data, ggml_nbytes(src));
    }

    GGML_UNUSED(backend_src);
    return false;
}
|
||||
|
||||
@@ -13491,6 +13589,10 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
|
||||
|
||||
bool do_transfer = !ctx->compute_ctx.expired();
|
||||
|
||||
if (ggml_vk_submit_transfer_ctx(ctx)) {
|
||||
ctx->submit_pending = true;
|
||||
}
|
||||
|
||||
vk_context compute_ctx;
|
||||
if (do_transfer) {
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
@@ -13506,7 +13608,22 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
|
||||
}
|
||||
|
||||
if (ctx->submit_pending) {
|
||||
{
|
||||
if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
|
||||
vk::TimelineSemaphoreSubmitInfo tl_info{
|
||||
1, &ctx->transfer_semaphore.value,
|
||||
0, nullptr,
|
||||
};
|
||||
vk::PipelineStageFlags stage = ctx->device->transfer_queue.stage_flags;
|
||||
vk::SubmitInfo si{
|
||||
1, &ctx->transfer_semaphore.s, &stage,
|
||||
0, nullptr,
|
||||
0, nullptr,
|
||||
};
|
||||
si.setPNext(&tl_info);
|
||||
std::lock_guard<std::mutex> guard(queue_mutex);
|
||||
ctx->device->compute_queue.queue.submit({ si }, ctx->fence);
|
||||
ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
|
||||
} else {
|
||||
std::lock_guard<std::mutex> guard(queue_mutex);
|
||||
ctx->device->compute_queue.queue.submit({}, ctx->fence);
|
||||
}
|
||||
@@ -13972,6 +14089,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
bool first_node_in_batch = true; // true if next node will be first node in a batch
|
||||
int submit_node_idx = 0; // index to first node in a batch
|
||||
|
||||
ggml_vk_submit_transfer_ctx(ctx);
|
||||
|
||||
vk_context compute_ctx;
|
||||
if (vk_perf_logger_enabled) {
|
||||
// allocate/resize the query pool
|
||||
@@ -13997,9 +14116,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);
|
||||
|
||||
GGML_ASSERT(ctx->compute_ctx.expired());
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
ctx->query_idx = 0;
|
||||
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
||||
}
|
||||
@@ -14009,13 +14126,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
|
||||
if (ctx->prealloc_size_add_rms_partials) {
|
||||
ggml_vk_preallocate_buffers(ctx, nullptr);
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
// initialize partial sums to zero.
|
||||
ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
|
||||
ggml_vk_sync_buffers(ctx, compute_ctx);
|
||||
@@ -14238,13 +14349,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
|
||||
|
||||
if (vk_perf_logger_enabled && enqueued) {
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
if (!vk_perf_logger_concurrent) {
|
||||
// track a single node/fusion for the current query
|
||||
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
|
||||
@@ -14579,16 +14684,9 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
vk_event *vkev = (vk_event *)event->context;
|
||||
|
||||
vk_context compute_ctx;
|
||||
ggml_vk_submit_transfer_ctx(ctx);
|
||||
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
|
||||
// the backend interface doesn't have an explicit reset, so reset it here
|
||||
// before we record the command to set it
|
||||
@@ -14609,16 +14707,7 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
vk_event *vkev = (vk_event *)event->context;
|
||||
|
||||
vk_context compute_ctx;
|
||||
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
|
||||
ggml_vk_wait_events(compute_ctx, {vkev->event});
|
||||
ggml_vk_ctx_end(compute_ctx);
|
||||
@@ -14631,7 +14720,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
|
||||
/* .free = */ ggml_backend_vk_free,
|
||||
/* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
|
||||
/* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
|
||||
/* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
|
||||
/* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async,
|
||||
/* .synchronize = */ ggml_backend_vk_synchronize,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
/* .graph_plan_free = */ NULL,
|
||||
@@ -15367,11 +15456,25 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba
|
||||
return buft_ctx->device->idx == ctx->device;
|
||||
}
|
||||
|
||||
static int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op) {
|
||||
switch (op->op) {
|
||||
case GGML_OP_GET_ROWS:
|
||||
return 0;
|
||||
case GGML_OP_MUL_MAT:
|
||||
return op->ne[1];
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
return op->ne[2];
|
||||
default:
|
||||
return ggml_nrows(op);
|
||||
}
|
||||
}
|
||||
|
||||
// Decides whether `op` should be offloaded to this device: true when the op's
// batch size, as computed by ggml_vk_get_op_batch_size, reaches the device's
// op_offload_min_batch_size threshold.
static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;

    // The earlier ne[1]/ne[2] expression is superseded by the per-op batch-size
    // helper; leaving it in front made this return unreachable.
    return ggml_vk_get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
}
|
||||
|
||||
static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
|
||||
|
||||
@@ -68,6 +68,7 @@ struct ggml_webgpu_shader_lib_context {
|
||||
size_t wg_mem_limit_bytes = 0;
|
||||
bool inplace = false;
|
||||
bool overlap = false;
|
||||
bool src_overlap = false;
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t sg_mat_m = 0;
|
||||
uint32_t sg_mat_n = 0;
|
||||
@@ -172,6 +173,22 @@ struct ggml_webgpu_scale_pipeline_key_hash {
|
||||
}
|
||||
};
|
||||
|
||||
/** Concat **/
|
||||
|
||||
// Cache key for concat pipelines; a pipeline is identified by its type alone.
struct ggml_webgpu_concat_pipeline_key {
    int type;

    // Two keys are equal when their type fields match.
    bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
        return other.type == type;
    }
};
|
||||
|
||||
struct ggml_webgpu_concat_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Binary **/
|
||||
|
||||
struct ggml_webgpu_binary_pipeline_key {
|
||||
@@ -179,9 +196,10 @@ struct ggml_webgpu_binary_pipeline_key {
|
||||
int op;
|
||||
bool inplace;
|
||||
bool overlap;
|
||||
bool src_overlap;
|
||||
|
||||
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
|
||||
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
|
||||
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -192,6 +210,7 @@ struct ggml_webgpu_binary_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.op);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
ggml_webgpu_hash_combine(seed, key.overlap);
|
||||
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
@@ -400,6 +419,8 @@ class ggml_webgpu_shader_lib {
|
||||
pad_pipelines; // circular/non-circular
|
||||
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
|
||||
binary_pipelines; // type/op/inplace/overlap
|
||||
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
||||
concat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||
flash_attn_pipelines;
|
||||
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
||||
@@ -1044,6 +1065,7 @@ class ggml_webgpu_shader_lib {
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
.overlap = context.overlap,
|
||||
.src_overlap = context.src_overlap,
|
||||
};
|
||||
|
||||
auto it = binary_pipelines.find(key);
|
||||
@@ -1076,6 +1098,9 @@ class ggml_webgpu_shader_lib {
|
||||
} else if (key.overlap) {
|
||||
defines.push_back("OVERLAP");
|
||||
variant += "_overlap";
|
||||
} else if (key.src_overlap) {
|
||||
defines.push_back("SRC_OVERLAP");
|
||||
variant += "_src_overlap";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
@@ -1089,6 +1114,43 @@ class ggml_webgpu_shader_lib {
|
||||
return binary_pipelines[key];
|
||||
}
|
||||
|
||||
// Returns the concat pipeline for the destination tensor's type, building and
// caching it on first use. Supported types are F32 and I32; any other type
// aborts. The shader is preprocessed with a type define plus WG_SIZE set to
// context.max_wg_size, and the same workgroup size is recorded on the
// pipeline's decisions object.
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_concat_pipeline_key key = {
        .type = context.dst->type,
    };

    // Fast path: already built for this type.
    const auto cached = concat_pipelines.find(key);
    if (cached != concat_pipelines.end()) {
        return cached->second;
    }

    const char * type_define = nullptr;
    const char * type_suffix = nullptr;
    switch (key.type) {
        case GGML_TYPE_F32:
            type_define = "TYPE_F32";
            type_suffix = "_f32";
            break;
        case GGML_TYPE_I32:
            type_define = "TYPE_I32";
            type_suffix = "_i32";
            break;
        default:
            GGML_ABORT("Unsupported type for concat shader");
    }

    std::vector<std::string> defines;
    defines.push_back(type_define);
    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    std::string variant = "concat";
    variant += type_suffix;

    auto processed = preprocessor.preprocess(wgsl_concat, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;

    concat_pipelines[key] = pipeline;
    return pipeline;
}
|
||||
|
||||
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
const bool has_mask = context.src3 != nullptr;
|
||||
const bool has_sinks = context.src4 != nullptr;
|
||||
|
||||
@@ -31,6 +31,13 @@
|
||||
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
|
||||
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
|
||||
|
||||
// Splits `total_wg` workgroups into a wg_x * wg_y grid with as few surplus
// workgroups as possible.
// wg_y is the number of rows needed so that no row exceeds `max_per_dim`
// (at least 1); wg_x is then the row length that covers `total_wg`.
// Assumes total_wg does not exceed max_per_dim^2.
static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
    // Ceiling division, written out in the same form as CEIL_DIV.
    const uint32_t rows = (total_wg + max_per_dim - 1) / max_per_dim;

    wg_y = rows == 0 ? 1u : rows;
    wg_x = (total_wg + wg_y - 1) / wg_y;
}
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
|
||||
# define WEBGPU_DEBUG_BUF_ELEMS 512
|
||||
@@ -69,8 +76,8 @@
|
||||
|
||||
/* Constants */
|
||||
|
||||
#define WEBGPU_NUM_PARAM_BUFS 16u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
|
||||
#define WEBGPU_NUM_PARAM_BUFS 48u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
|
||||
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
|
||||
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
|
||||
// parameter buffer pool
|
||||
@@ -116,11 +123,6 @@ struct webgpu_pool_bufs {
|
||||
wgpu::Buffer dev_buf;
|
||||
};
|
||||
|
||||
// The futures to wait on for a single queue submission
|
||||
struct webgpu_submission_futures {
|
||||
std::vector<wgpu::FutureWaitInfo> futures;
|
||||
};
|
||||
|
||||
// Holds a pool of parameter buffers for WebGPU operations
|
||||
struct webgpu_buf_pool {
|
||||
std::vector<webgpu_pool_bufs> free;
|
||||
@@ -133,12 +135,28 @@ struct webgpu_buf_pool {
|
||||
// which can run on a different thread than the calling thread.
|
||||
std::mutex mutex;
|
||||
std::condition_variable cv;
|
||||
size_t cur_pool_size;
|
||||
size_t max_pool_size;
|
||||
wgpu::Device device;
|
||||
wgpu::BufferUsage host_buf_usage;
|
||||
wgpu::BufferUsage dev_buf_usage;
|
||||
size_t buf_size;
|
||||
bool should_grow;
|
||||
|
||||
void init(wgpu::Device device,
|
||||
int num_bufs,
|
||||
size_t buf_size,
|
||||
wgpu::BufferUsage dev_buf_usage,
|
||||
wgpu::BufferUsage host_buf_usage) {
|
||||
wgpu::BufferUsage host_buf_usage,
|
||||
bool should_grow = false,
|
||||
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
|
||||
this->max_pool_size = max_pool_size;
|
||||
this->cur_pool_size = num_bufs;
|
||||
this->device = device;
|
||||
this->host_buf_usage = host_buf_usage;
|
||||
this->dev_buf_usage = dev_buf_usage;
|
||||
this->buf_size = buf_size;
|
||||
this->should_grow = should_grow;
|
||||
for (int i = 0; i < num_bufs; i++) {
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
@@ -150,6 +168,25 @@ struct webgpu_buf_pool {
|
||||
|
||||
webgpu_pool_bufs alloc_bufs() {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
if (!free.empty()) {
|
||||
webgpu_pool_bufs bufs = free.back();
|
||||
free.pop_back();
|
||||
return bufs;
|
||||
}
|
||||
|
||||
// Try growing the pool if no free buffers
|
||||
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
|
||||
cur_pool_size++;
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
|
||||
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
||||
|
||||
if (!(host_buf && dev_buf)) {
|
||||
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
|
||||
}
|
||||
return webgpu_pool_bufs{ host_buf, dev_buf };
|
||||
}
|
||||
cv.wait(lock, [this] { return !free.empty(); });
|
||||
webgpu_pool_bufs bufs = free.back();
|
||||
free.pop_back();
|
||||
@@ -243,6 +280,7 @@ struct webgpu_gpu_profile_buf_pool {
|
||||
#endif
|
||||
|
||||
struct webgpu_command {
|
||||
uint32_t num_kernels;
|
||||
wgpu::CommandBuffer commands;
|
||||
std::vector<webgpu_pool_bufs> params_bufs;
|
||||
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
|
||||
@@ -280,7 +318,6 @@ struct webgpu_global_context_struct {
|
||||
|
||||
webgpu_buf_pool memset_buf_pool;
|
||||
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
|
||||
std::atomic_uint inflight_threads = 0;
|
||||
|
||||
#ifdef GGML_WEBGPU_CPU_PROFILE
|
||||
// Profiling: labeled CPU time in ms (total)
|
||||
@@ -421,30 +458,60 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
||||
/** End WebGPU object initializations */
|
||||
|
||||
/** WebGPU Actions */
|
||||
// Removes every entry whose `completed` flag is set from `futures`.
// The entries that remain keep their original relative order.
static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
    const auto is_done    = [](const wgpu::FutureWaitInfo & info) { return info.completed; };
    const auto first_done = std::remove_if(futures.begin(), futures.end(), is_done);
    futures.erase(first_done, futures.end());
}
|
||||
|
||||
// Wait for the queue to finish processing all submitted work
|
||||
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
||||
std::vector<webgpu_submission_futures> & futures,
|
||||
bool block = true) {
|
||||
// If we have too many in-flight submissions, wait on the oldest one first. If
|
||||
// there are many threads, inflight_max may be 0, meaning that we must wait on
|
||||
// all futures.
|
||||
uint64_t timeout_ms = block ? UINT64_MAX : 0;
|
||||
uint32_t inflight_threads = ctx->inflight_threads;
|
||||
uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
|
||||
while (futures.size() >= inflight_max && futures.size() > 0) {
|
||||
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
|
||||
futures.erase(futures.begin());
|
||||
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
||||
std::vector<wgpu::FutureWaitInfo> & futures,
|
||||
bool block = true) {
|
||||
// If we have too many in-flight submissions, wait on the oldest one first.
|
||||
if (futures.empty()) {
|
||||
return;
|
||||
}
|
||||
size_t i = 0;
|
||||
while (i < futures.size()) {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
|
||||
uint64_t timeout_ms = block ? UINT64_MAX : 0;
|
||||
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
|
||||
if (waitStatus == wgpu::WaitStatus::Error) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
}
|
||||
if (futures[0].completed) {
|
||||
futures.erase(futures.begin());
|
||||
}
|
||||
}
|
||||
|
||||
if (futures.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (block) {
|
||||
while (!futures.empty()) {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
switch (waitStatus) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
|
||||
erase_completed(futures);
|
||||
break;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
break;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Poll once and return
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
switch (waitStatus) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
futures.erase(futures.begin() + i);
|
||||
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
|
||||
erase_completed(futures);
|
||||
break;
|
||||
case wgpu::WaitStatus::TimedOut:
|
||||
i++;
|
||||
break;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
@@ -487,10 +554,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx,
|
||||
std::vector<webgpu_command> commands,
|
||||
webgpu_buf_pool & param_buf_pool,
|
||||
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
|
||||
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
webgpu_global_context ctx,
|
||||
std::vector<webgpu_command> commands,
|
||||
webgpu_buf_pool & param_buf_pool,
|
||||
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
|
||||
std::vector<wgpu::CommandBuffer> command_buffers;
|
||||
std::vector<webgpu_pool_bufs> params_bufs;
|
||||
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
|
||||
@@ -562,7 +630,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex
|
||||
futures.push_back({ f });
|
||||
}
|
||||
#endif
|
||||
return { futures };
|
||||
return futures;
|
||||
}
|
||||
|
||||
static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
@@ -651,6 +719,7 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
result.commands = commands;
|
||||
result.params_bufs = params_bufs_list;
|
||||
result.set_rows_error_bufs = set_rows_error_bufs;
|
||||
result.num_kernels = pipelines.size();
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
result.timestamp_query_bufs = ts_bufs;
|
||||
// TODO: handle multiple pipeline names
|
||||
@@ -688,8 +757,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
||||
|
||||
webgpu_command command =
|
||||
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
|
||||
std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
|
||||
ctx->memset_buf_pool) };
|
||||
auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
|
||||
ggml_backend_webgpu_wait(ctx, futures);
|
||||
}
|
||||
|
||||
@@ -788,6 +856,7 @@ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
|
||||
// Aliasing relationships between the operands of a binary op.
// Filled in by ggml_webgpu_detect_binary_overlap().
struct binary_overlap_flags {
|
||||
bool inplace; // src0 == dst (ggml_webgpu_tensor_equal)
|
||||
bool overlap; // src1 overlaps dst (ggml_webgpu_tensor_overlap)
|
||||
bool src_overlap; // src0 overlaps src1 (ggml_webgpu_tensor_overlap)
|
||||
};
|
||||
|
||||
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
|
||||
@@ -796,6 +865,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
|
||||
binary_overlap_flags flags = {};
|
||||
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
|
||||
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
|
||||
|
||||
return flags;
|
||||
}
|
||||
@@ -1112,8 +1182,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
};
|
||||
|
||||
// Calculate workgroup dimensions
|
||||
uint32_t wg_x = 1;
|
||||
uint32_t wg_y = 1;
|
||||
uint32_t wg_x = 1;
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (use_fast && is_vec) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
@@ -1121,9 +1192,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
|
||||
uint32_t total_wg = output_groups * batches;
|
||||
// TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups
|
||||
wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
|
||||
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
} else if (use_fast) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
@@ -1142,12 +1211,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
|
||||
wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
|
||||
}
|
||||
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
||||
uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
||||
} else { // legacy
|
||||
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_size = decisions->wg_size;
|
||||
wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
||||
wg_y = 1;
|
||||
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
@@ -1353,6 +1424,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.inplace = flags.inplace,
|
||||
.overlap = flags.overlap,
|
||||
.src_overlap = flags.src_overlap,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
|
||||
@@ -1361,11 +1433,28 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
|
||||
size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
|
||||
|
||||
uint32_t offset_merged_src0 = 0;
|
||||
uint32_t offset_merged_src1 = 0;
|
||||
if (flags.src_overlap) {
|
||||
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
|
||||
offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
|
||||
offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
offset_merged_src0,
|
||||
offset_merged_src1,
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
@@ -1381,31 +1470,111 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
|
||||
entries.push_back({
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
|
||||
});
|
||||
|
||||
entries.push_back({
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
|
||||
});
|
||||
|
||||
if (!flags.inplace && !flags.overlap) {
|
||||
entries.push_back({ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
if (flags.src_overlap) {
|
||||
size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
|
||||
size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
|
||||
src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
|
||||
entries.push_back({
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = merged_offset,
|
||||
.size = merged_end - merged_offset,
|
||||
});
|
||||
entries.push_back({
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst),
|
||||
});
|
||||
} else {
|
||||
entries.push_back({
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = src0_webgpu_tensor_align_offset,
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
|
||||
});
|
||||
entries.push_back({
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = src1_webgpu_tensor_align_offset,
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
|
||||
});
|
||||
if (!flags.inplace && !flags.overlap) {
|
||||
entries.push_back({
|
||||
.binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
// Builds the command for a GGML_OP_CONCAT node (src0, src1 -> dst).
//
// The uniform parameter list is laid out in this order:
//   - number of elements in dst
//   - misalignment of src0, src1 and dst, in elements
//   - the four strides of src0, then the four strides of src1, in elements
//   - the four extents of dst
//   - the concat dimension (dst->op_params[0]) and src0's extent along it
//
// src0, src1 and dst are bound at bindings 0, 1 and 2. The dispatch uses
// CEIL_DIV(element count, wg_size) workgroups along x.
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
                                         ggml_tensor *    src0,
                                         ggml_tensor *    src1,
                                         ggml_tensor *    dst) {
    const uint32_t n_elems    = (uint32_t) ggml_nelements(dst);
    const uint32_t concat_dim = (uint32_t) dst->op_params[0];

    // Distance of a tensor's data from the start of its aligned binding, in elements.
    const auto elem_misalignment = [&ctx](ggml_tensor * t) {
        return (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, t) / ggml_type_size(t->type));
    };

    std::vector<uint32_t> params;
    params.push_back(n_elems);
    for (ggml_tensor * t : { src0, src1, dst }) {
        params.push_back(elem_misalignment(t));
    }
    // Byte strides are converted to element strides; the shader takes exactly four per source.
    for (ggml_tensor * src : { src0, src1 }) {
        for (int d = 0; d < 4; d++) {
            params.push_back((uint32_t) (src->nb[d] / ggml_type_size(src->type)));
        }
    }
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) dst->ne[d]);
    }
    params.push_back(concat_dim);
    params.push_back((uint32_t) src0->ne[concat_dim]);

    // Bindings 0, 1, 2 = src0, src1, dst.
    std::vector<wgpu::BindGroupEntry> entries;
    uint32_t                          binding_idx = 0;
    for (ggml_tensor * t : { src0, src1, dst }) {
        entries.push_back({
            .binding = binding_idx++,
            .buffer  = ggml_webgpu_tensor_buf(t),
            .offset  = ggml_webgpu_tensor_align_offset(ctx, t),
            .size    = ggml_webgpu_tensor_binding_size(ctx, t),
        });
    }

    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0        = src0,
        .src1        = src1,
        .dst         = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };

    webgpu_pipeline pipeline  = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

    // One invocation per destination element.
    uint32_t wg_x = CEIL_DIV(n_elems, decisions->wg_size);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
|
||||
|
||||
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
int inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||
|
||||
@@ -1990,6 +2159,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
return ggml_webgpu_binary_op(ctx, src0, src1, node);
|
||||
case GGML_OP_CONCAT:
|
||||
return ggml_webgpu_concat(ctx, src0, src1, node);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ggml_webgpu_rms_norm(ctx, src0, node);
|
||||
case GGML_OP_ROPE:
|
||||
@@ -2043,21 +2214,20 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
|
||||
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
|
||||
|
||||
ctx->global_ctx->inflight_threads++;
|
||||
|
||||
std::vector<webgpu_command> commands;
|
||||
std::vector<webgpu_submission_futures> futures;
|
||||
std::vector<webgpu_command> commands;
|
||||
std::vector<wgpu::FutureWaitInfo> futures;
|
||||
uint32_t num_batched_kernels = 0;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
|
||||
commands.push_back(*cmd);
|
||||
num_batched_kernels += cmd.value().num_kernels;
|
||||
}
|
||||
// compute the batch size based on the number of inflight threads
|
||||
uint32_t inflight_threads = ctx->global_ctx->inflight_threads;
|
||||
uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
|
||||
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
|
||||
if (commands.size() >= batch_size) {
|
||||
futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
|
||||
&ctx->set_rows_error_buf_pool));
|
||||
|
||||
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
|
||||
num_batched_kernels = 0;
|
||||
std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
|
||||
ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
|
||||
futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
|
||||
// Process events and check for completed submissions
|
||||
ctx->global_ctx->instance.ProcessEvents();
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
|
||||
@@ -2065,13 +2235,12 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
}
|
||||
}
|
||||
if (!commands.empty()) {
|
||||
webgpu_submission_futures new_futures =
|
||||
auto new_futures =
|
||||
ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
|
||||
futures.push_back(new_futures);
|
||||
futures.insert(futures.end(), new_futures.begin(), new_futures.end());
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, futures);
|
||||
ctx->global_ctx->inflight_threads--;
|
||||
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
@@ -2689,7 +2858,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
||||
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
|
||||
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
|
||||
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
@@ -2816,10 +2985,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
// TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
|
||||
// see https://github.com/ggml-org/llama.cpp/pull/16857
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
|
||||
(src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||
(src1->type == op->type);
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
|
||||
@@ -7,6 +7,13 @@ struct Params {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
offset_merged_src0: u32,
|
||||
offset_merged_src1: u32,
|
||||
|
||||
stride_src0_0: u32,
|
||||
stride_src0_1: u32,
|
||||
stride_src0_2: u32,
|
||||
stride_src0_3: u32,
|
||||
|
||||
stride_src1_0: u32,
|
||||
stride_src1_1: u32,
|
||||
@@ -23,6 +30,21 @@ struct Params {
|
||||
b_ne3: u32,
|
||||
};
|
||||
|
||||
// Converts a flat element index into an element offset within src0.
// The index is unravelled into four coordinates using the a_ne0/a_ne1/a_ne2
// extents, and each coordinate is then weighted by the matching src0 stride.
fn src0_index(_i: u32) -> u32 {
    let plane_size = params.a_ne1 * params.a_ne0;
    let volume_size = params.a_ne2 * plane_size;

    let a_i3 = _i / volume_size;
    let in_volume = _i % volume_size;

    let a_i2 = in_volume / plane_size;
    let in_plane = in_volume % plane_size;

    let a_i1 = in_plane / params.a_ne0;
    let a_i0 = in_plane % params.a_ne0;

    let offset_01 = a_i0 * params.stride_src0_0 + a_i1 * params.stride_src0_1;
    let offset_012 = offset_01 + a_i2 * params.stride_src0_2;
    return offset_012 + a_i3 * params.stride_src0_3;
}
|
||||
|
||||
fn src1_index(_i: u32) -> u32 {
|
||||
var i = _i;
|
||||
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
|
||||
@@ -53,17 +75,22 @@ fn src1_index(_i: u32) -> u32 {
|
||||
#define DataType f16
|
||||
#endif
|
||||
|
||||
#ifdef SRC_OVERLAP
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> merged_src: array<DataType>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<DataType>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
#else
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<DataType>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1 : array<DataType>;
|
||||
|
||||
#ifdef INPLACE
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#elif defined(OVERLAP)
|
||||
#if defined(INPLACE) || defined(OVERLAP)
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
@@ -74,6 +101,7 @@ var<storage, read_write> dst: array<DataType>;
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
fn op(a: DataType, b: DataType) -> DataType {
|
||||
#ifdef OP_ADD
|
||||
@@ -87,13 +115,17 @@ fn op(a: DataType, b: DataType) -> DataType {
|
||||
#endif
|
||||
}
|
||||
|
||||
fn update(dst_i: u32, src0_i: u32, src1_i: u32){
|
||||
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
|
||||
#ifdef SRC_OVERLAP
|
||||
let result = op(merged_src[src0_i], merged_src[src1_i]);
|
||||
#else
|
||||
let result = op(src0[src0_i], src1[src1_i]);
|
||||
#endif
|
||||
|
||||
#ifdef INPLACE
|
||||
src0[dst_i] = result;
|
||||
src0[src0_i] = result;
|
||||
#elif defined(OVERLAP)
|
||||
src1[dst_i] = result;
|
||||
src1[src1_i] = result;
|
||||
#else
|
||||
dst[dst_i] = result;
|
||||
#endif
|
||||
@@ -102,6 +134,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32){
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
|
||||
let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
|
||||
let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
|
||||
update(params.offset_dst + gid.x, src0_i, src1_i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
// Uniform parameters read by the concat shader's `main`.
struct Params {
|
||||
ne: u32, // invocations with gid.x >= ne do nothing
|
||||
|
||||
// Added to every index into src0 / src1 / dst respectively.
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
// Multiplied by coordinates 0..3 to form the index into src0.
stride_src0_0: u32,
|
||||
stride_src0_1: u32,
|
||||
stride_src0_2: u32,
|
||||
stride_src0_3: u32,
|
||||
|
||||
// Multiplied by coordinates 0..3 to form the index into src1.
stride_src1_0: u32,
|
||||
stride_src1_1: u32,
|
||||
stride_src1_2: u32,
|
||||
stride_src1_3: u32,
|
||||
|
||||
// Extents used to unravel gid.x into four coordinates.
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
|
||||
dim: u32, // coordinate that selects between src0 and src1
|
||||
src0_nedim: u32 // coordinates below this along `dim` read src0, the rest read src1
|
||||
};
|
||||
|
||||
#ifdef TYPE_F32
|
||||
#define DataType f32
|
||||
#endif
|
||||
#ifdef TYPE_I32
|
||||
#define DataType i32
|
||||
#endif
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<DataType>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1 : array<DataType>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<DataType>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
// Concat kernel: each invocation copies one destination element.
// gid.x is unravelled into four coordinates using ne0..ne2. Coordinates whose
// `dim` component is below src0_nedim read from src0; the others read from
// src1 after that component is shifted down by src0_nedim.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // Invocations past the last element have nothing to write.
    if (gid.x >= params.ne) {
        return;
    }

    let plane_size = params.ne1 * params.ne0;
    let volume_size = params.ne2 * plane_size;

    let i3 = gid.x / volume_size;
    let in_volume = gid.x % volume_size;
    let i2 = in_volume / plane_size;
    let in_plane = in_volume % plane_size;
    let i1 = in_plane / params.ne0;
    let i0 = in_plane % params.ne0;

    var coord = array<u32, 4>(i0, i1, i2, i3);
    let dst_i = params.offset_dst + gid.x;

    if (coord[params.dim] < params.src0_nedim) {
        let src_i = coord[0] * params.stride_src0_0 +
                    coord[1] * params.stride_src0_1 +
                    coord[2] * params.stride_src0_2 +
                    coord[3] * params.stride_src0_3;
        dst[dst_i] = src0[params.offset_src0 + src_i];
    } else {
        // Re-base the concat coordinate so it indexes into src1.
        coord[params.dim] -= params.src0_nedim;
        let src_i = coord[0] * params.stride_src1_0 +
                    coord[1] * params.stride_src1_1 +
                    coord[2] * params.stride_src1_2 +
                    coord[3] * params.stride_src1_3;
        dst[dst_i] = src1[params.offset_src1 + src_i];
    }
}
|
||||
@@ -679,19 +679,24 @@ struct MulMatParams {
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
@compute @workgroup_size(256)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
let global_idx = wg_linear * 256u + local_id.x;
|
||||
|
||||
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
|
||||
if (global_id.x >= total) {
|
||||
if (global_idx >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
let dst2_stride = params.m * params.n;
|
||||
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
|
||||
|
||||
let dst3_idx = global_id.x / dst3_stride;
|
||||
let dst3_idx = global_idx / dst3_stride;
|
||||
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
|
||||
let src13_idx = dst3_idx; // src1 is not broadcast
|
||||
let dst3_rem = global_id.x % dst3_stride;
|
||||
let dst3_rem = global_idx % dst3_stride;
|
||||
|
||||
let dst2_idx = dst3_rem / dst2_stride;
|
||||
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
|
||||
|
||||
@@ -54,7 +54,8 @@ var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
|
||||
|
||||
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>) {
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
|
||||
let thread_id = local_id.x;
|
||||
let local_m = get_local_m(thread_id);
|
||||
@@ -64,9 +65,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
|
||||
let wg_per_matrix = wg_m_count * wg_n_count;
|
||||
|
||||
let batch_idx = wg_id.x / wg_per_matrix;
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
|
||||
let wg_in_batch = wg_id.x % wg_per_matrix;
|
||||
let batch_idx = wg_linear / wg_per_matrix;
|
||||
|
||||
let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
|
||||
if (batch_idx >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
let wg_in_batch = wg_linear % wg_per_matrix;
|
||||
let wg_m = wg_in_batch % wg_m_count;
|
||||
let wg_n = wg_in_batch / wg_m_count;
|
||||
|
||||
|
||||
@@ -69,7 +69,8 @@ var<workgroup> shmem: array<f16, SHMEM_SIZE>;
|
||||
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32) {
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
|
||||
let thread_id = local_id.x;
|
||||
let subgroup_m = subgroup_id % SUBGROUP_M;
|
||||
@@ -79,9 +80,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
|
||||
let wg_per_matrix = wg_m_count * wg_n_count;
|
||||
|
||||
let batch_idx = wg_id.x / wg_per_matrix;
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
|
||||
let wg_in_batch = wg_id.x % wg_per_matrix;
|
||||
let batch_idx = wg_linear / wg_per_matrix;
|
||||
|
||||
let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
|
||||
if (batch_idx >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
let wg_in_batch = wg_linear % wg_per_matrix;
|
||||
let wg_m = wg_in_batch % wg_m_count;
|
||||
let wg_n = wg_in_batch / wg_m_count;
|
||||
|
||||
|
||||
+7
-9
@@ -1410,16 +1410,14 @@ static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
|
||||
}
|
||||
next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
if (tensor->ne[i] != 1) {
|
||||
if (i > n) {
|
||||
if (tensor->nb[i] != next_nb) {
|
||||
return false;
|
||||
}
|
||||
next_nb *= tensor->ne[i];
|
||||
} else {
|
||||
// this dimension does not need to be contiguous
|
||||
next_nb = tensor->ne[i]*tensor->nb[i];
|
||||
if (i > n) {
|
||||
if (tensor->ne[i] != 1 && tensor->nb[i] != next_nb) {
|
||||
return false;
|
||||
}
|
||||
next_nb *= tensor->ne[i];
|
||||
} else {
|
||||
// this dimension does not need to be contiguous
|
||||
next_nb = tensor->ne[i]*tensor->nb[i];
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
@@ -1,11 +1,43 @@
|
||||
#!/usr/bin/env bash
|
||||
#!/bin/sh
|
||||
# vim: set ts=4 sw=4 et:
|
||||
|
||||
wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
unzip wikitext-2-raw-v1.zip
|
||||
ZIP="wikitext-2-raw-v1.zip"
|
||||
FILE="wikitext-2-raw/wiki.test.raw"
|
||||
URL="https://huggingface.co/datasets/ggml-org/ci/resolve/main/$ZIP"
|
||||
|
||||
echo "Usage:"
|
||||
echo ""
|
||||
echo " ./llama-perplexity -m model.gguf -f wikitext-2-raw/wiki.test.raw [other params]"
|
||||
echo ""
|
||||
# Print each argument on its own line to stderr, then exit with status 1.
die() {
|
||||
printf "%s\n" "$@" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
exit 0
|
||||
# Succeed only if every command named in the arguments is found by `command -v`.
# Returns at the first one that is not found, with that lookup's failure status.
have_cmd() {
|
||||
# `for cmd` with no `in` list iterates over the function's arguments.
for cmd; do
|
||||
command -v "$cmd" >/dev/null || return
|
||||
done
|
||||
}
|
||||
|
||||
# Download URL $1 to file $2 using wget or, if wget is missing, curl.
# Does nothing and succeeds when $2 already exists.
# On a failed transfer the partial output is removed and the downloader's
# status is returned, so a later run retries instead of reusing a broken file.
# Exits through die() when neither wget nor curl is installed.
dl() {
    [ -f "$2" ] && return
    if have_cmd wget; then
        wget "$1" -O "$2"
    elif have_cmd curl; then
        # --fail: treat HTTP errors as failures instead of saving the error page.
        curl -fL "$1" -o "$2"
    else
        die "Please install wget or curl"
    fi
    dl_status=$?
    # A failed transfer can leave an empty or truncated file behind; without
    # this cleanup the existence check above would accept it next time.
    [ "$dl_status" -eq 0 ] || rm -f -- "$2"
    return "$dl_status"
}
|
||||
|
||||
have_cmd unzip || die "Please install unzip"
|
||||
|
||||
if [ ! -f "$FILE" ]; then
|
||||
dl "$URL" "$ZIP" || exit
|
||||
unzip -o "$ZIP" || exit
|
||||
rm -f -- "$ZIP"
|
||||
fi
|
||||
|
||||
cat <<EOF
|
||||
Usage:
|
||||
|
||||
llama-perplexity -m model.gguf -f $FILE [other params]
|
||||
|
||||
EOF
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "refs/tags/v0.34.0"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.35.0"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
@@ -14,8 +14,8 @@ vendor = {
|
||||
"https://raw.githubusercontent.com/nothings/stb/refs/heads/master/stb_image.h": "vendor/stb/stb_image.h",
|
||||
|
||||
# not using latest tag to avoid this issue: https://github.com/ggml-org/llama.cpp/pull/17179#discussion_r2515877926
|
||||
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.24/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
"https://github.com/mackron/miniaudio/raw/13d161bc8d856ad61ae46b798bbeffc0f49808e8/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "httplib.h",
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/split.py": "split.py",
|
||||
|
||||
+2
-2
@@ -100,9 +100,9 @@ std::string format(const char * fmt, ...) {
|
||||
|
||||
std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
|
||||
char buf[256];
|
||||
snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
|
||||
snprintf(buf, sizeof(buf), "%6" PRId64, ne.at(0));
|
||||
for (size_t i = 1; i < ne.size(); i++) {
|
||||
snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
|
||||
snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, ne.at(i));
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
@@ -257,6 +257,21 @@ set(LLAMA_TEST_NAME test-mtmd-c-api)
|
||||
llama_build_and_test(test-mtmd-c-api.c)
|
||||
target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd)
|
||||
|
||||
# GGUF model data fetcher library for tests that need real model metadata
|
||||
# Only compile when cpp-httplib has SSL support (CPPHTTPLIB_OPENSSL_SUPPORT)
|
||||
if (TARGET cpp-httplib)
|
||||
get_target_property(_cpp_httplib_defs cpp-httplib INTERFACE_COMPILE_DEFINITIONS)
|
||||
if (_cpp_httplib_defs MATCHES "CPPHTTPLIB_OPENSSL_SUPPORT")
|
||||
add_library(gguf-model-data STATIC gguf-model-data.cpp)
|
||||
target_link_libraries(gguf-model-data PRIVATE common cpp-httplib)
|
||||
target_include_directories(gguf-model-data PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
add_executable(test-gguf-model-data test-gguf-model-data.cpp)
|
||||
target_link_libraries(test-gguf-model-data PRIVATE gguf-model-data common)
|
||||
llama_test(test-gguf-model-data LABEL "model")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# dummy executable - not installed
|
||||
get_filename_component(TEST_TARGET test-c.c NAME_WE)
|
||||
add_executable(${TEST_TARGET} test-c.c)
|
||||
|
||||
@@ -0,0 +1,613 @@
|
||||
// GGUF binary parser adapted from the huggingface/gguf package.
|
||||
// Reference: https://github.com/huggingface/huggingface.js
|
||||
|
||||
#include "gguf-model-data.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
|
||||
#include "http.h"
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
// Bounds-checked read cursor over an in-memory GGUF byte buffer
// (equivalent of RangeView in the huggingface/gguf package).
//
// The reader does not own the bytes: `data` points into the vector passed to the
// constructor, so that vector must outlive the reader and must not be resized.
// Every accessor returns false instead of reading past `size`.
struct gguf_buf_reader {
    const char * data;
    size_t       size;
    size_t       pos;

    gguf_buf_reader(const std::vector<char> & buf) : data(buf.data()), size(buf.size()), pos(0) {}

    // True when at least `n` unread bytes remain.
    // Expressed as a subtraction: `pos + n` wraps around for very large `n`
    // (e.g. a corrupt 64-bit length prefix) and would wrongly pass the check.
    bool has_n_bytes(size_t n) const {
        return pos <= size && n <= size - pos;
    }

    // Copy sizeof(T) bytes at the cursor into `out` (host byte order) and advance.
    template <typename T>
    bool read_val(T & out) {
        if (!has_n_bytes(sizeof(T))) {
            return false;
        }
        memcpy(&out, data + pos, sizeof(T));
        pos += sizeof(T);
        return true;
    }

    // Read a GGUF string: a uint64 byte length followed by that many bytes.
    // On failure the length prefix may already have been consumed.
    bool read_str(std::string & out) {
        uint64_t len;
        if (!read_val(len)) {
            return false;
        }
        // compare in 64 bits so the length is never truncated to size_t before the check
        if (len > (uint64_t)(size - pos)) {
            return false;
        }
        out.assign(data + pos, (size_t)len);
        pos += (size_t)len;
        return true;
    }

    // Advance the cursor by `n` bytes without reading them.
    bool skip(size_t n) {
        if (!has_n_bytes(n)) {
            return false;
        }
        pos += n;
        return true;
    }
};
|
||||
|
||||
static size_t gguf_val_type_size(int32_t vtype) {
|
||||
switch (vtype) {
|
||||
case GGUF_TYPE_UINT8: return 1;
|
||||
case GGUF_TYPE_INT8: return 1;
|
||||
case GGUF_TYPE_UINT16: return 2;
|
||||
case GGUF_TYPE_INT16: return 2;
|
||||
case GGUF_TYPE_UINT32: return 4;
|
||||
case GGUF_TYPE_INT32: return 4;
|
||||
case GGUF_TYPE_FLOAT32: return 4;
|
||||
case GGUF_TYPE_BOOL: return 1;
|
||||
case GGUF_TYPE_UINT64: return 8;
|
||||
case GGUF_TYPE_INT64: return 8;
|
||||
case GGUF_TYPE_FLOAT64: return 8;
|
||||
default: return 0; // string/array handled separately
|
||||
}
|
||||
}
|
||||
|
||||
// Equivalent of readMetadataValue(), skips unused values rather than storing
|
||||
static bool gguf_skip_value(gguf_buf_reader & r, int32_t vtype) {
|
||||
if (vtype == GGUF_TYPE_STRING) {
|
||||
std::string tmp;
|
||||
return r.read_str(tmp);
|
||||
}
|
||||
if (vtype == GGUF_TYPE_ARRAY) {
|
||||
int32_t elem_type;
|
||||
uint64_t count;
|
||||
if (!r.read_val(elem_type)) {
|
||||
return false;
|
||||
}
|
||||
if (!r.read_val(count)) {
|
||||
return false;
|
||||
}
|
||||
if (elem_type == GGUF_TYPE_STRING) {
|
||||
for (uint64_t i = 0; i < count; i++) {
|
||||
std::string tmp;
|
||||
if (!r.read_str(tmp)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (elem_type == GGUF_TYPE_ARRAY) {
|
||||
// nested arrays - recurse
|
||||
for (uint64_t i = 0; i < count; i++) {
|
||||
if (!gguf_skip_value(r, GGUF_TYPE_ARRAY)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
size_t elem_sz = gguf_val_type_size(elem_type);
|
||||
if (elem_sz == 0) {
|
||||
return false;
|
||||
}
|
||||
return r.skip((size_t)count * elem_sz);
|
||||
}
|
||||
size_t sz = gguf_val_type_size(vtype);
|
||||
if (sz == 0) {
|
||||
return false;
|
||||
}
|
||||
return r.skip(sz);
|
||||
}
|
||||
|
||||
static bool gguf_read_uint32_val(gguf_buf_reader & r, int32_t vtype, uint32_t & out) {
|
||||
if (vtype == GGUF_TYPE_UINT8) {
|
||||
uint8_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_INT8) {
|
||||
int8_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_UINT16) {
|
||||
uint16_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_INT16) {
|
||||
int16_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_UINT32) {
|
||||
uint32_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_INT32) {
|
||||
int32_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_UINT64) {
|
||||
uint64_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_INT64) {
|
||||
int64_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Follows the same header -> KV -> tensor parsing sequence as gguf() huggingface/gguf
|
||||
static std::optional<gguf_remote_model> gguf_parse_meta(const std::vector<char> & buf) {
|
||||
gguf_buf_reader r(buf);
|
||||
|
||||
// Header: magic(4) + version(4) + tensor_count(8) + kv_count(8) = 24 bytes minimum
|
||||
uint32_t magic_raw;
|
||||
if (!r.read_val(magic_raw)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (memcmp(&magic_raw, "GGUF", 4) != 0) {
|
||||
fprintf(stderr, "gguf_parse_meta: invalid magic\n");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
uint32_t version;
|
||||
if (!r.read_val(version)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (version < 2 || version > 3) {
|
||||
fprintf(stderr, "gguf_parse_meta: unsupported version %u\n", version);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
int64_t tensor_count_raw;
|
||||
int64_t kv_count_raw;
|
||||
if (!r.read_val(tensor_count_raw)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (!r.read_val(kv_count_raw)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
uint64_t tensor_count = (uint64_t)tensor_count_raw;
|
||||
uint64_t kv_count = (uint64_t)kv_count_raw;
|
||||
|
||||
gguf_remote_model model;
|
||||
|
||||
std::string arch_prefix;
|
||||
|
||||
// Parse KV pairs
|
||||
for (uint64_t i = 0; i < kv_count; i++) {
|
||||
std::string key;
|
||||
if (!r.read_str(key)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
int32_t vtype;
|
||||
if (!r.read_val(vtype)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (key == "general.architecture" && vtype == GGUF_TYPE_STRING) {
|
||||
if (!r.read_str(model.architecture)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
arch_prefix = model.architecture + ".";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract split.count for proper handling of split files
|
||||
if (key == "split.count") {
|
||||
uint32_t v;
|
||||
if (!gguf_read_uint32_val(r, vtype, v)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
model.n_split = (uint16_t)v;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract split.tensors.count so we can verify we have all tensors
|
||||
if (key == "split.tensors.count") {
|
||||
uint32_t v;
|
||||
if (!gguf_read_uint32_val(r, vtype, v)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
model.n_split_tensors = v;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!arch_prefix.empty()) {
|
||||
uint32_t * target = nullptr;
|
||||
|
||||
if (key == arch_prefix + "embedding_length") { target = &model.n_embd; }
|
||||
else if (key == arch_prefix + "feed_forward_length") { target = &model.n_ff; }
|
||||
else if (key == arch_prefix + "block_count") { target = &model.n_layer; }
|
||||
else if (key == arch_prefix + "attention.head_count") { target = &model.n_head; }
|
||||
else if (key == arch_prefix + "attention.head_count_kv") { target = &model.n_head_kv; }
|
||||
else if (key == arch_prefix + "expert_count") { target = &model.n_expert; }
|
||||
else if (key == arch_prefix + "attention.key_length") { target = &model.n_embd_head_k; }
|
||||
else if (key == arch_prefix + "attention.value_length") { target = &model.n_embd_head_v; }
|
||||
|
||||
if (target) {
|
||||
if (!gguf_read_uint32_val(r, vtype, *target)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (!gguf_skip_value(r, vtype)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tensor info entries
|
||||
model.tensors.reserve((size_t)tensor_count);
|
||||
for (uint64_t i = 0; i < tensor_count; i++) {
|
||||
gguf_remote_tensor t;
|
||||
|
||||
if (!r.read_str(t.name)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (!r.read_val(t.n_dims)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (t.n_dims > 4) {
|
||||
fprintf(stderr, "gguf_parse_meta: tensor '%s' has %u dims (max 4)\n", t.name.c_str(), t.n_dims);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
for (uint32_t d = 0; d < t.n_dims; d++) {
|
||||
if (!r.read_val(t.ne[d])) {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t type_raw;
|
||||
if (!r.read_val(type_raw)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
t.type = (ggml_type)type_raw;
|
||||
|
||||
uint64_t offset;
|
||||
if (!r.read_val(offset)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Infer n_vocab from token_embd.weight
|
||||
if (t.name == "token_embd.weight") {
|
||||
model.n_vocab = (uint32_t)t.ne[1];
|
||||
}
|
||||
|
||||
model.tensors.push_back(std::move(t));
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
// cache handling for local download
|
||||
static std::string get_default_cache_dir() {
|
||||
return fs_get_cache_directory() + "gguf-headers/";
|
||||
}
|
||||
|
||||
// Return a copy of `s` that can be used as a single file-name component:
// every '/', '\\' and ':' is replaced by '_', all other characters are kept.
static std::string sanitize_for_path(const std::string & s) {
    std::string cleaned;
    cleaned.reserve(s.size());

    for (const char ch : s) {
        switch (ch) {
            case '/':
            case '\\':
            case ':':
                cleaned.push_back('_');
                break;
            default:
                cleaned.push_back(ch);
                break;
        }
    }

    return cleaned;
}
|
||||
|
||||
// Load the entire file at `path` into `out`.
// Returns false when the file cannot be opened, when its reported size is
// zero or negative, or when the read does not complete; `out` is only
// resized once a positive size is known.
static bool read_file(const std::string & path, std::vector<char> & out) {
    // open positioned at the end so that tellg() reports the file size
    std::ifstream in(path, std::ios::binary | std::ios::ate);
    if (!in.good()) {
        return false;
    }

    auto n_bytes = in.tellg();
    if (n_bytes <= 0) {
        return false;
    }

    out.resize((size_t)n_bytes);

    in.seekg(0);
    in.read(out.data(), n_bytes);

    return in.good();
}
|
||||
|
||||
// Write `data` to `path`, replacing any existing file.
// Returns true only once the bytes have been handed to the OS: the stream is
// closed before its state is inspected, because std::ofstream buffers writes and
// an error such as a full disk is otherwise only detected in the destructor.
static bool write_file(const std::string & path, const std::vector<char> & data) {
    std::ofstream f(path, std::ios::binary | std::ios::trunc);
    if (!f.good()) {
        return false;
    }

    f.write(data.data(), (std::streamsize)data.size());
    f.close(); // flushes; sets failbit if the flush or the close fails

    return !f.fail();
}
|
||||
|
||||
// HuggingFace file auto-detection and HTTP download
|
||||
static std::pair<long, std::vector<char>> gguf_http_get(
|
||||
const std::string & url,
|
||||
const httplib::Headers & headers = {},
|
||||
int timeout_sec = 60) {
|
||||
try {
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
if (timeout_sec > 0) {
|
||||
cli.set_read_timeout(timeout_sec, 0);
|
||||
cli.set_write_timeout(timeout_sec, 0);
|
||||
}
|
||||
cli.set_connection_timeout(30, 0);
|
||||
|
||||
std::vector<char> body;
|
||||
auto res = cli.Get(parts.path, headers,
|
||||
[&](const char * data, size_t len) {
|
||||
body.insert(body.end(), data, data + len);
|
||||
return true;
|
||||
}, nullptr);
|
||||
|
||||
if (!res) {
|
||||
fprintf(stderr, "gguf_fetch: HTTP request failed for %s (error %d)\n",
|
||||
url.c_str(), (int)res.error());
|
||||
return {-1, {}};
|
||||
}
|
||||
return {res->status, std::move(body)};
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "gguf_fetch: HTTP error: %s\n", e.what());
|
||||
return {-1, {}};
|
||||
}
|
||||
}
|
||||
|
||||
// Find the filename for given repo/quant.
|
||||
// For split models, returns the first shard (the one containing "00001-of-")
|
||||
// split_prefix is set to the portion before "-00001-of-XXXXX.gguf" when a split file is found
|
||||
static std::string detect_gguf_filename(const std::string & repo, const std::string & quant,
|
||||
std::string & split_prefix) {
|
||||
split_prefix.clear();
|
||||
std::string api_url = "https://huggingface.co/api/models/" + repo;
|
||||
|
||||
auto [code, body] = gguf_http_get(api_url, {}, 30);
|
||||
if (code != 200 || body.empty()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to query HF API for %s (HTTP %ld)\n", repo.c_str(), code);
|
||||
return "";
|
||||
}
|
||||
|
||||
nlohmann::json j;
|
||||
try {
|
||||
j = nlohmann::json::parse(body.begin(), body.end());
|
||||
} catch (...) {
|
||||
fprintf(stderr, "gguf_fetch: failed to parse HF API response\n");
|
||||
return "";
|
||||
}
|
||||
|
||||
if (!j.contains("siblings") || !j["siblings"].is_array()) {
|
||||
fprintf(stderr, "gguf_fetch: unexpected HF API response format\n");
|
||||
return "";
|
||||
}
|
||||
|
||||
std::vector<std::string> matches;
|
||||
std::string quant_upper = quant;
|
||||
for (char & c : quant_upper) { c = (char)toupper(c); }
|
||||
|
||||
for (const auto & sibling : j["siblings"]) {
|
||||
if (!sibling.contains("rfilename")) { continue; }
|
||||
std::string fname = sibling["rfilename"].get<std::string>();
|
||||
if (fname.size() < 5 || fname.substr(fname.size() - 5) != ".gguf") {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string fname_upper = fname;
|
||||
for (char & c : fname_upper) { c = (char)toupper(c); }
|
||||
if (fname_upper.find(quant_upper) != std::string::npos) {
|
||||
matches.push_back(fname);
|
||||
}
|
||||
}
|
||||
|
||||
if (matches.empty()) {
|
||||
fprintf(stderr, "gguf_fetch: no .gguf files matching '%s' in %s\n", quant.c_str(), repo.c_str());
|
||||
return "";
|
||||
}
|
||||
|
||||
std::sort(matches.begin(), matches.end());
|
||||
|
||||
// Prefer non-split, non-supplementary file
|
||||
for (const auto & m : matches) {
|
||||
if (m.find("-of-") == std::string::npos && m.find("mmproj") == std::string::npos) {
|
||||
return m;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the first shard (00001-of-) and extract the prefix
|
||||
for (const auto & m : matches) {
|
||||
auto pos = m.find("-00001-of-");
|
||||
if (pos != std::string::npos) {
|
||||
split_prefix = m.substr(0, pos);
|
||||
return m;
|
||||
}
|
||||
}
|
||||
|
||||
return matches[0];
|
||||
}
|
||||
|
||||
static std::optional<gguf_remote_model> fetch_and_parse(
|
||||
const std::string & repo,
|
||||
const std::string & filename,
|
||||
const std::string & cache_path) {
|
||||
std::string url = "https://huggingface.co/" + repo + "/resolve/main/" + filename;
|
||||
|
||||
// Progressive download inspired by RangeView.fetchChunk()
|
||||
// Start at 2MB, double each time, cap at 64MB
|
||||
size_t chunk_size = 2 * 1024 * 1024;
|
||||
const size_t max_chunk = 64 * 1024 * 1024;
|
||||
|
||||
while (chunk_size <= max_chunk) {
|
||||
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
|
||||
|
||||
char range_buf[64];
|
||||
snprintf(range_buf, sizeof(range_buf), "bytes=0-%zu", chunk_size - 1);
|
||||
httplib::Headers headers = {{"Range", range_buf}};
|
||||
|
||||
auto [code, body] = gguf_http_get(url, headers, 120);
|
||||
if (code != 200 && code != 206) {
|
||||
fprintf(stderr, "gguf_fetch: HTTP %ld fetching %s\n", code, url.c_str());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (body.empty()) {
|
||||
fprintf(stderr, "gguf_fetch: empty response\n");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto result = gguf_parse_meta(body);
|
||||
if (result.has_value()) {
|
||||
write_file(cache_path, body);
|
||||
return result;
|
||||
}
|
||||
|
||||
if (code == 200) {
|
||||
fprintf(stderr, "gguf_fetch: server returned full response but metadata parse failed\n");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Parse failed, try larger chunk
|
||||
chunk_size *= 2;
|
||||
}
|
||||
|
||||
fprintf(stderr, "gguf_fetch: metadata exceeds 64MB, giving up\n");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Try cache first, then fetch and parse a single GGUF shard.
|
||||
static std::optional<gguf_remote_model> fetch_or_cached(
|
||||
const std::string & repo,
|
||||
const std::string & filename,
|
||||
const std::string & cdir,
|
||||
const std::string & repo_part) {
|
||||
std::string fname_part = sanitize_for_path(filename);
|
||||
std::string cache_path = cdir + "/" + repo_part + "--" + fname_part + ".partial";
|
||||
|
||||
{
|
||||
std::vector<char> cached;
|
||||
if (std::filesystem::exists(cache_path) && read_file(cache_path, cached)) {
|
||||
auto result = gguf_parse_meta(cached);
|
||||
if (result.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fs_create_directory_with_parents(cdir);
|
||||
return fetch_and_parse(repo, filename, cache_path);
|
||||
}
|
||||
|
||||
std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
const std::string & repo,
|
||||
const std::string & quant,
|
||||
const std::string & cache_dir) {
|
||||
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
|
||||
std::string repo_part = sanitize_for_path(repo);
|
||||
|
||||
std::string split_prefix;
|
||||
std::string filename = detect_gguf_filename(repo, quant, split_prefix);
|
||||
if (filename.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
|
||||
if (!model_opt.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto & model = model_opt.value();
|
||||
|
||||
// If the model is split across multiple files we need to fetch the remaining shards metadata
|
||||
if (model.n_split > 1) {
|
||||
if (split_prefix.empty()) {
|
||||
fprintf(stderr, "gguf_fetch: model reports %u splits but filename has no split pattern\n", model.n_split);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
|
||||
model.n_split, model.n_split - 1);
|
||||
|
||||
for (int i = 2; i <= model.n_split; i++) {
|
||||
char num_buf[6], total_buf[6];
|
||||
snprintf(num_buf, sizeof(num_buf), "%05d", i);
|
||||
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
|
||||
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
|
||||
|
||||
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
|
||||
if (!shard.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
model.tensors.insert(model.tensors.end(),
|
||||
std::make_move_iterator(shard->tensors.begin()),
|
||||
std::make_move_iterator(shard->tensors.end()));
|
||||
}
|
||||
|
||||
if (model.n_split_tensors > 0 && model.tensors.size() != model.n_split_tensors) {
|
||||
fprintf(stderr, "gguf_fetch: WARNING: expected %u tensors from split.tensors.count, got %zu\n",
|
||||
model.n_split_tensors, model.tensors.size());
|
||||
}
|
||||
}
|
||||
|
||||
return model_opt;
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Name, type and shape of one tensor as read from a GGUF tensor-info entry.
// Holds no tensor data.
struct gguf_remote_tensor {
|
||||
// tensor name as stored in the file
std::string name;
|
||||
// element type; holds the raw type id read from the file cast to ggml_type
ggml_type type = GGML_TYPE_F32;
|
||||
int64_t ne[4] = {1, 1, 1, 1}; // dimensions, unused dims = 1
|
||||
// number of dimensions stored in the file, i.e. how many entries of ne[] were read
uint32_t n_dims = 0;
|
||||
};
|
||||
|
||||
// Subset of a GGUF model's metadata: a few selected KV values plus the list of
// tensor infos. Numeric fields keep their default of 0 when the corresponding
// key is not present in the file.
struct gguf_remote_model {
|
||||
// Selected KV metadata
|
||||
std::string architecture; // general.architecture
|
||||
uint32_t n_embd = 0; // <arch>.embedding_length
|
||||
uint32_t n_ff = 0; // <arch>.feed_forward_length
|
||||
uint32_t n_vocab = 0; // inferred from token_embd.weight ne[1]
|
||||
uint32_t n_layer = 0; // <arch>.block_count
|
||||
uint32_t n_head = 0; // <arch>.attention.head_count
|
||||
uint32_t n_head_kv = 0; // <arch>.attention.head_count_kv
|
||||
uint32_t n_expert = 0; // <arch>.expert_count (0 if absent)
|
||||
uint32_t n_embd_head_k = 0; // <arch>.attention.key_length
|
||||
uint32_t n_embd_head_v = 0; // <arch>.attention.value_length
|
||||
uint16_t n_split = 0; // split.count (0 = not split)
|
||||
uint32_t n_split_tensors = 0; // split.tensors.count (0 if not split)
|
||||
|
||||
// tensor infos in file order; for split models the entries of all shards are concatenated
std::vector<gguf_remote_tensor> tensors;
|
||||
};
|
||||
|
||||
// Fetch model metadata from HuggingFace with local caching.
|
||||
// repo: e.g., "ggml-org/Qwen3-32B-GGUF"
|
||||
// quant: e.g., "Q8_0" -- auto-detects filename (including first shard of split models)
|
||||
// Returns nullopt if download fails or network is unavailable.
|
||||
std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
const std::string & repo,
|
||||
const std::string & quant = "Q8_0",
|
||||
const std::string & cache_dir = ""); // empty = default
|
||||
@@ -2977,6 +2977,7 @@ struct test_bin_bcast : public test_case {
|
||||
const std::array<int, 4> nr;
|
||||
int nf; // number of fused ops, nf == 1 -> single op (no fusion)
|
||||
bool perm1; // permute src1?
|
||||
bool src_overlap; // src0 and src1 are overlapping views of the same buffer
|
||||
|
||||
bool run_whole_graph() override { return nf > 1; }
|
||||
|
||||
@@ -2992,8 +2993,8 @@ struct test_bin_bcast : public test_case {
|
||||
std::array<int64_t, 4> ne = {10, 10, 1, 1},
|
||||
std::array<int, 4> nr = {1, 2, 1, 1},
|
||||
int nf = 1,
|
||||
bool perm1 = false)
|
||||
: op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {}
|
||||
bool perm1 = false, bool src_overlap = false)
|
||||
: op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1), src_overlap(src_overlap) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
GGML_ASSERT(nf <= 16);
|
||||
@@ -3008,6 +3009,8 @@ struct test_bin_bcast : public test_case {
|
||||
|
||||
b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]);
|
||||
b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]);
|
||||
} else if (src_overlap) {
|
||||
b[i] = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], (ne[3] / 3) * a->nb[3]);
|
||||
} else {
|
||||
b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
}
|
||||
@@ -3021,7 +3024,13 @@ struct test_bin_bcast : public test_case {
|
||||
ggml_set_param(b[0]);
|
||||
}
|
||||
|
||||
ggml_tensor * out = a;
|
||||
ggml_tensor *out;
|
||||
|
||||
if (src_overlap) {
|
||||
out = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], 0);
|
||||
} else {
|
||||
out = a;
|
||||
}
|
||||
|
||||
for (int i = 0; i < nf; ++i) {
|
||||
out = op(ctx, out, b[i]);
|
||||
@@ -7527,9 +7536,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false) {
|
||||
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false, bool src_overlap = false) {
|
||||
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
|
||||
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1));
|
||||
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1, src_overlap));
|
||||
}
|
||||
};
|
||||
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||
@@ -7549,6 +7558,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1);
|
||||
}
|
||||
|
||||
// src_overlap
|
||||
add_test_bin_bcast(type, {10, 5, 4, 6}, {1, 1, 1, 1}, false, true);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 5}, {1, 1, 1, 1}, false, true);
|
||||
add_test_bin_bcast(type, {1, 1, 120, 120}, {1, 1, 1, 1}, false, true);
|
||||
add_test_bin_bcast(type, {1, 1, 4, 320}, {1, 1, 1, 1}, false, true);
|
||||
|
||||
// test case for k_bin_bcast_unravel in CUDA backend
|
||||
add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});
|
||||
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
#include "gguf-model-data.h"
|
||||
|
||||
#include <cstdio>
|
||||
|
||||
#define TEST_ASSERT(cond, msg) \
|
||||
do { \
|
||||
if (!(cond)) { \
|
||||
fprintf(stderr, "FAIL: %s (line %d): %s\n", #cond, __LINE__, msg); \
|
||||
return 1; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
int main() {
|
||||
fprintf(stderr, "=== test-gguf-model-data ===\n");
|
||||
|
||||
// Fetch Qwen3-0.6B Q8_0 metadata
|
||||
auto result = gguf_fetch_model_meta("ggml-org/Qwen3-0.6B-GGUF", "Q8_0");
|
||||
|
||||
if (!result.has_value()) {
|
||||
fprintf(stderr, "SKIP: could not fetch model metadata (no network or HTTP disabled)\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
const auto & model = result.value();
|
||||
|
||||
fprintf(stderr, "Architecture: %s\n", model.architecture.c_str());
|
||||
fprintf(stderr, "n_embd: %u\n", model.n_embd);
|
||||
fprintf(stderr, "n_ff: %u\n", model.n_ff);
|
||||
fprintf(stderr, "n_vocab: %u\n", model.n_vocab);
|
||||
fprintf(stderr, "n_layer: %u\n", model.n_layer);
|
||||
fprintf(stderr, "n_head: %u\n", model.n_head);
|
||||
fprintf(stderr, "n_head_kv: %u\n", model.n_head_kv);
|
||||
fprintf(stderr, "n_expert: %u\n", model.n_expert);
|
||||
fprintf(stderr, "n_embd_head_k: %u\n", model.n_embd_head_k);
|
||||
fprintf(stderr, "n_embd_head_v: %u\n", model.n_embd_head_v);
|
||||
fprintf(stderr, "tensors: %zu\n", model.tensors.size());
|
||||
|
||||
// Verify architecture
|
||||
TEST_ASSERT(model.architecture == "qwen3", "expected architecture 'qwen3'");
|
||||
|
||||
// Verify key dimensions (Qwen3-0.6B)
|
||||
TEST_ASSERT(model.n_layer == 28, "expected n_layer == 28");
|
||||
TEST_ASSERT(model.n_embd == 1024, "expected n_embd == 1024");
|
||||
TEST_ASSERT(model.n_head == 16, "expected n_head == 16");
|
||||
TEST_ASSERT(model.n_head_kv == 8, "expected n_head_kv == 8");
|
||||
TEST_ASSERT(model.n_expert == 0, "expected n_expert == 0 (not MoE)");
|
||||
TEST_ASSERT(model.n_vocab == 151936, "expected n_vocab == 151936");
|
||||
|
||||
// Verify tensor count
|
||||
TEST_ASSERT(model.tensors.size() == 311, "expected tensor count == 311");
|
||||
|
||||
// Verify known tensor names exist
|
||||
bool found_attn_q = false;
|
||||
bool found_token_embd = false;
|
||||
bool found_output_norm = false;
|
||||
for (const auto & t : model.tensors) {
|
||||
if (t.name == "blk.0.attn_q.weight") {
|
||||
found_attn_q = true;
|
||||
}
|
||||
if (t.name == "token_embd.weight") {
|
||||
found_token_embd = true;
|
||||
}
|
||||
if (t.name == "output_norm.weight") {
|
||||
found_output_norm = true;
|
||||
}
|
||||
}
|
||||
TEST_ASSERT(found_attn_q, "expected tensor 'blk.0.attn_q.weight'");
|
||||
TEST_ASSERT(found_token_embd, "expected tensor 'token_embd.weight'");
|
||||
TEST_ASSERT(found_output_norm, "expected tensor 'output_norm.weight'");
|
||||
|
||||
// Verify token_embd.weight shape
|
||||
for (const auto & t : model.tensors) {
|
||||
if (t.name == "token_embd.weight") {
|
||||
TEST_ASSERT(t.ne[0] == 1024, "expected token_embd.weight ne[0] == 1024");
|
||||
TEST_ASSERT(t.n_dims == 2, "expected token_embd.weight to be 2D");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Test that second call uses cache (just call again, it should work)
|
||||
auto result2 = gguf_fetch_model_meta("ggml-org/Qwen3-0.6B-GGUF", "Q8_0");
|
||||
TEST_ASSERT(result2.has_value(), "cached fetch should succeed");
|
||||
TEST_ASSERT(result2->tensors.size() == model.tensors.size(), "cached result should match");
|
||||
|
||||
// Test a split MoE model without specifying quant (should default to Q8_0)
|
||||
auto result3 = gguf_fetch_model_meta("ggml-org/GLM-4.6V-GGUF");
|
||||
if (!result3.has_value()) {
|
||||
fprintf(stderr, "SKIP: could not fetch GLM-4.6V metadata (no network?)\n");
|
||||
return 0;
|
||||
}
|
||||
const auto & model3 = result3.value();
|
||||
|
||||
fprintf(stderr, "Architecture: %s\n", model3.architecture.c_str());
|
||||
fprintf(stderr, "n_embd: %u\n", model3.n_embd);
|
||||
fprintf(stderr, "n_ff: %u\n", model3.n_ff);
|
||||
fprintf(stderr, "n_vocab: %u\n", model3.n_vocab);
|
||||
fprintf(stderr, "n_layer: %u\n", model3.n_layer);
|
||||
fprintf(stderr, "n_head: %u\n", model3.n_head);
|
||||
fprintf(stderr, "n_head_kv: %u\n", model3.n_head_kv);
|
||||
fprintf(stderr, "n_expert: %u\n", model3.n_expert);
|
||||
fprintf(stderr, "n_embd_head_k: %u\n", model3.n_embd_head_k);
|
||||
fprintf(stderr, "n_embd_head_v: %u\n", model3.n_embd_head_v);
|
||||
fprintf(stderr, "tensors: %zu\n", model3.tensors.size());
|
||||
|
||||
// Verify architecture
|
||||
TEST_ASSERT(model3.architecture == "glm4moe", "expected architecture 'glm4moe'");
|
||||
|
||||
// Verify key dimensions (GLM-4.6V)
|
||||
TEST_ASSERT(model3.n_layer == 46, "expected n_layer == 46");
|
||||
TEST_ASSERT(model3.n_embd == 4096, "expected n_embd == 4096");
|
||||
TEST_ASSERT(model3.n_head == 96, "expected n_head == 96");
|
||||
TEST_ASSERT(model3.n_head_kv == 8, "expected n_head_kv == 8");
|
||||
TEST_ASSERT(model3.n_expert == 128, "expected n_expert == 128 (MoE)");
|
||||
TEST_ASSERT(model3.n_vocab == 151552, "expected n_vocab == 151552");
|
||||
|
||||
// Verify tensor count
|
||||
TEST_ASSERT(model3.tensors.size() == 780, "expected tensor count == 780");
|
||||
|
||||
fprintf(stderr, "=== ALL TESTS PASSED ===\n");
|
||||
return 0;
|
||||
}
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@@ -15,6 +16,8 @@ static void print_usage(int, char ** argv) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "llama.h"
|
||||
#include "chat.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
@@ -84,6 +85,8 @@ static void sigint_handler(int signo) {
|
||||
#endif
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
g_params = ¶ms;
|
||||
|
||||
@@ -376,7 +379,7 @@ int main(int argc, char ** argv) {
|
||||
// remove any "future" tokens that we might have inherited from the previous session
|
||||
if (session_tokens.size() > n_match) {
|
||||
if (!llama_memory_seq_rm(mem, -1, n_match, -1)) {
|
||||
LOG_WRN("%s: unable to resuse common prefix (for example, when the memory is recurrent)\n", __func__);
|
||||
LOG_WRN("%s: unable to reuse common prefix (for example, when the memory is recurrent)\n", __func__);
|
||||
llama_memory_clear(mem, true);
|
||||
session_tokens.clear();
|
||||
n_match = 0;
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "pca.hpp"
|
||||
#include "mean.hpp"
|
||||
|
||||
#include <clocale>
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
#include "ggml-cuda.h"
|
||||
#endif
|
||||
@@ -392,6 +394,8 @@ static int prepare_entries(common_params & params, train_context & ctx_train) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
params.out_file = "control_vector.gguf";
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
@@ -411,6 +412,8 @@ static void print_usage(int, char ** argv) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
params.out_file = "ggml-lora-merged-f16.gguf";
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <stdexcept>
|
||||
@@ -567,6 +568,8 @@ static void gguf_merge(const split_params & split_params) {
|
||||
}
|
||||
|
||||
int main(int argc, const char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
split_params params;
|
||||
split_params_parse(argc, argv, params);
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
@@ -1191,6 +1192,8 @@ static bool show_statistics(const common_params & params) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
params.out_file = "imatrix.gguf";
|
||||
|
||||
@@ -2034,8 +2034,9 @@ static std::unique_ptr<printer> create_printer(output_formats format) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
// try to set locale for unicode characters in markdown
|
||||
setlocale(LC_CTYPE, ".UTF-8");
|
||||
std::setlocale(LC_CTYPE, ".UTF-8");
|
||||
|
||||
#if !defined(NDEBUG)
|
||||
fprintf(stderr, "warning: asserts enabled, performance may be affected\n");
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
std::string filename = "main";
|
||||
if (argc >= 1) {
|
||||
filename = argv[0];
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <vector>
|
||||
#include <limits.h>
|
||||
#include <cinttypes>
|
||||
#include <clocale>
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
@@ -274,6 +275,8 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
@@ -2004,6 +2005,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
params.n_ctx = 512;
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
#include "llama.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
@@ -485,6 +489,8 @@ static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
if (argc < 3) {
|
||||
usage(argv[0]);
|
||||
}
|
||||
|
||||
@@ -10,12 +10,15 @@
|
||||
# include <unistd.h>
|
||||
# include <sys/stat.h>
|
||||
#endif
|
||||
#include <string>
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
#include <clocale>
|
||||
#include <codecvt>
|
||||
#include <filesystem>
|
||||
#include <regex>
|
||||
#include <stdio.h>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <sys/types.h>
|
||||
@@ -285,6 +288,8 @@ static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & par
|
||||
}
|
||||
|
||||
int main(int argc, char * argv[]) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
ggml_backend_load_all();
|
||||
|
||||
rpc_server_params params;
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-http.h"
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "log.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <clocale>
|
||||
#include <exception>
|
||||
#include <signal.h>
|
||||
#include <thread> // for std::thread::hardware_concurrency
|
||||
@@ -67,6 +68,8 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
// own arguments required by this example
|
||||
common_params params;
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
//#include "log.h" // TODO: start using log.h
|
||||
#include "llama.h"
|
||||
|
||||
#include <clocale>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
@@ -184,6 +185,8 @@ static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
|
||||
}
|
||||
|
||||
int main(int raw_argc, char ** raw_argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
const std::vector<std::string> argv = ingest_args(raw_argc, raw_argv);
|
||||
const int argc = argv.size();
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
@@ -536,6 +537,8 @@ static std::string audio_data_from_speaker(json speaker, const outetts_version t
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
|
||||
params.out_file = "output.wav";
|
||||
|
||||
Vendored
-1
@@ -171,7 +171,6 @@ endif()
|
||||
if (CPPHTTPLIB_OPENSSL_SUPPORT)
|
||||
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT) # used in server.cpp
|
||||
if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
||||
target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
|
||||
find_library(SECURITY_FRAMEWORK Security REQUIRED)
|
||||
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
|
||||
|
||||
Vendored
+112
-62
@@ -2571,10 +2571,46 @@ find_content_type(const std::string &path,
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
extract_media_type(const std::string &content_type,
|
||||
std::map<std::string, std::string> *params = nullptr) {
|
||||
// Extract type/subtype from Content-Type value (RFC 2045)
|
||||
// e.g. "application/json; charset=utf-8" -> "application/json"
|
||||
auto media_type = content_type;
|
||||
auto semicolon_pos = media_type.find(';');
|
||||
if (semicolon_pos != std::string::npos) {
|
||||
auto param_str = media_type.substr(semicolon_pos + 1);
|
||||
media_type = media_type.substr(0, semicolon_pos);
|
||||
|
||||
if (params) {
|
||||
// Parse parameters: key=value pairs separated by ';'
|
||||
split(param_str.data(), param_str.data() + param_str.size(), ';',
|
||||
[&](const char *b, const char *e) {
|
||||
std::string key;
|
||||
std::string val;
|
||||
split(b, e, '=', [&](const char *b2, const char *e2) {
|
||||
if (key.empty()) {
|
||||
key.assign(b2, e2);
|
||||
} else {
|
||||
val.assign(b2, e2);
|
||||
}
|
||||
});
|
||||
if (!key.empty()) {
|
||||
params->emplace(trim_copy(key), trim_double_quotes_copy(val));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Trim whitespace from media type
|
||||
return trim_copy(media_type);
|
||||
}
|
||||
|
||||
bool can_compress_content_type(const std::string &content_type) {
|
||||
using udl::operator""_t;
|
||||
|
||||
auto tag = str2tag(content_type);
|
||||
auto mime_type = extract_media_type(content_type);
|
||||
auto tag = str2tag(mime_type);
|
||||
|
||||
switch (tag) {
|
||||
case "image/svg+xml"_t:
|
||||
@@ -2586,7 +2622,7 @@ bool can_compress_content_type(const std::string &content_type) {
|
||||
|
||||
case "text/event-stream"_t: return false;
|
||||
|
||||
default: return !content_type.rfind("text/", 0);
|
||||
default: return !mime_type.rfind("text/", 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3141,7 +3177,8 @@ bool is_chunked_transfer_encoding(const Headers &headers) {
|
||||
template <typename T, typename U>
|
||||
bool prepare_content_receiver(T &x, int &status,
|
||||
ContentReceiverWithProgress receiver,
|
||||
bool decompress, U callback) {
|
||||
bool decompress, size_t payload_max_length,
|
||||
bool &exceed_payload_max_length, U callback) {
|
||||
if (decompress) {
|
||||
std::string encoding = x.get_header_value("Content-Encoding");
|
||||
std::unique_ptr<decompressor> decompressor;
|
||||
@@ -3157,12 +3194,22 @@ bool prepare_content_receiver(T &x, int &status,
|
||||
|
||||
if (decompressor) {
|
||||
if (decompressor->is_valid()) {
|
||||
size_t decompressed_size = 0;
|
||||
ContentReceiverWithProgress out = [&](const char *buf, size_t n,
|
||||
size_t off, size_t len) {
|
||||
return decompressor->decompress(buf, n,
|
||||
[&](const char *buf2, size_t n2) {
|
||||
return receiver(buf2, n2, off, len);
|
||||
});
|
||||
return decompressor->decompress(
|
||||
buf, n, [&](const char *buf2, size_t n2) {
|
||||
// Guard against zip-bomb: check
|
||||
// decompressed size against limit.
|
||||
if (payload_max_length > 0 &&
|
||||
(decompressed_size >= payload_max_length ||
|
||||
n2 > payload_max_length - decompressed_size)) {
|
||||
exceed_payload_max_length = true;
|
||||
return false;
|
||||
}
|
||||
decompressed_size += n2;
|
||||
return receiver(buf2, n2, off, len);
|
||||
});
|
||||
};
|
||||
return callback(std::move(out));
|
||||
} else {
|
||||
@@ -3183,11 +3230,14 @@ template <typename T>
|
||||
bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
|
||||
DownloadProgress progress,
|
||||
ContentReceiverWithProgress receiver, bool decompress) {
|
||||
bool exceed_payload_max_length = false;
|
||||
return prepare_content_receiver(
|
||||
x, status, std::move(receiver), decompress,
|
||||
[&](const ContentReceiverWithProgress &out) {
|
||||
x, status, std::move(receiver), decompress, payload_max_length,
|
||||
exceed_payload_max_length, [&](const ContentReceiverWithProgress &out) {
|
||||
auto ret = true;
|
||||
auto exceed_payload_max_length = false;
|
||||
// Note: exceed_payload_max_length may also be set by the decompressor
|
||||
// wrapper in prepare_content_receiver when the decompressed payload
|
||||
// size exceeds the limit.
|
||||
|
||||
if (is_chunked_transfer_encoding(x.headers)) {
|
||||
auto result = read_content_chunked(strm, x, payload_max_length, out);
|
||||
@@ -3603,12 +3653,11 @@ std::string normalize_query_string(const std::string &query) {
|
||||
|
||||
bool parse_multipart_boundary(const std::string &content_type,
|
||||
std::string &boundary) {
|
||||
auto boundary_keyword = "boundary=";
|
||||
auto pos = content_type.find(boundary_keyword);
|
||||
if (pos == std::string::npos) { return false; }
|
||||
auto end = content_type.find(';', pos);
|
||||
auto beg = pos + strlen(boundary_keyword);
|
||||
boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg));
|
||||
std::map<std::string, std::string> params;
|
||||
extract_media_type(content_type, ¶ms);
|
||||
auto it = params.find("boundary");
|
||||
if (it == params.end()) { return false; }
|
||||
boundary = it->second;
|
||||
return !boundary.empty();
|
||||
}
|
||||
|
||||
@@ -3776,11 +3825,7 @@ bool parse_accept_header(const std::string &s,
|
||||
}
|
||||
|
||||
// Remove additional parameters from media type
|
||||
auto param_pos = accept_entry.media_type.find(';');
|
||||
if (param_pos != std::string::npos) {
|
||||
accept_entry.media_type =
|
||||
trim_copy(accept_entry.media_type.substr(0, param_pos));
|
||||
}
|
||||
accept_entry.media_type = extract_media_type(accept_entry.media_type);
|
||||
|
||||
// Basic validation of media type format
|
||||
if (accept_entry.media_type.empty()) {
|
||||
@@ -5610,7 +5655,7 @@ size_t Request::get_param_value_count(const std::string &key) const {
|
||||
|
||||
bool Request::is_multipart_form_data() const {
|
||||
const auto &content_type = get_header_value("Content-Type");
|
||||
return !content_type.rfind("multipart/form-data", 0);
|
||||
return detail::extract_media_type(content_type) == "multipart/form-data";
|
||||
}
|
||||
|
||||
// Multipart FormData implementation
|
||||
@@ -7092,7 +7137,8 @@ bool Server::read_content(Stream &strm, Request &req, Response &res) {
|
||||
return true;
|
||||
})) {
|
||||
const auto &content_type = req.get_header_value("Content-Type");
|
||||
if (!content_type.find("application/x-www-form-urlencoded")) {
|
||||
if (detail::extract_media_type(content_type) ==
|
||||
"application/x-www-form-urlencoded") {
|
||||
if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) {
|
||||
res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414?
|
||||
output_error_log(Error::ExceedMaxPayloadSize, &req);
|
||||
@@ -7479,45 +7525,63 @@ bool Server::routing(Request &req, Response &res, Stream &strm) {
|
||||
if (detail::expect_content(req)) {
|
||||
// Content reader handler
|
||||
{
|
||||
// Track whether the ContentReader was aborted due to the decompressed
|
||||
// payload exceeding `payload_max_length_`.
|
||||
// The user handler runs after the lambda returns, so we must restore the
|
||||
// 413 status if the handler overwrites it.
|
||||
bool content_reader_payload_too_large = false;
|
||||
|
||||
ContentReader reader(
|
||||
[&](ContentReceiver receiver) {
|
||||
auto result = read_content_with_content_receiver(
|
||||
strm, req, res, std::move(receiver), nullptr, nullptr);
|
||||
if (!result) { output_error_log(Error::Read, &req); }
|
||||
if (!result) {
|
||||
output_error_log(Error::Read, &req);
|
||||
if (res.status == StatusCode::PayloadTooLarge_413) {
|
||||
content_reader_payload_too_large = true;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
},
|
||||
[&](FormDataHeader header, ContentReceiver receiver) {
|
||||
auto result = read_content_with_content_receiver(
|
||||
strm, req, res, nullptr, std::move(header),
|
||||
std::move(receiver));
|
||||
if (!result) { output_error_log(Error::Read, &req); }
|
||||
if (!result) {
|
||||
output_error_log(Error::Read, &req);
|
||||
if (res.status == StatusCode::PayloadTooLarge_413) {
|
||||
content_reader_payload_too_large = true;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
bool dispatched = false;
|
||||
if (req.method == "POST") {
|
||||
if (dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader),
|
||||
post_handlers_for_content_reader_)) {
|
||||
return true;
|
||||
}
|
||||
dispatched = dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader), post_handlers_for_content_reader_);
|
||||
} else if (req.method == "PUT") {
|
||||
if (dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader),
|
||||
put_handlers_for_content_reader_)) {
|
||||
return true;
|
||||
}
|
||||
dispatched = dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader), put_handlers_for_content_reader_);
|
||||
} else if (req.method == "PATCH") {
|
||||
if (dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader),
|
||||
patch_handlers_for_content_reader_)) {
|
||||
return true;
|
||||
}
|
||||
dispatched = dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader), patch_handlers_for_content_reader_);
|
||||
} else if (req.method == "DELETE") {
|
||||
if (dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader),
|
||||
delete_handlers_for_content_reader_)) {
|
||||
return true;
|
||||
dispatched = dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader), delete_handlers_for_content_reader_);
|
||||
}
|
||||
|
||||
if (dispatched) {
|
||||
if (content_reader_payload_too_large) {
|
||||
// Enforce the limit: override any status the handler may have set
|
||||
// and return false so the error path sends a plain 413 response.
|
||||
res.status = StatusCode::PayloadTooLarge_413;
|
||||
res.body.clear();
|
||||
res.content_length_ = 0;
|
||||
res.content_provider_ = nullptr;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7930,16 +7994,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
routed = true;
|
||||
} else {
|
||||
res.status = StatusCode::InternalServerError_500;
|
||||
std::string val;
|
||||
auto s = e.what();
|
||||
for (size_t i = 0; s[i]; i++) {
|
||||
switch (s[i]) {
|
||||
case '\r': val += "\\r"; break;
|
||||
case '\n': val += "\\n"; break;
|
||||
default: val += s[i]; break;
|
||||
}
|
||||
}
|
||||
res.set_header("EXCEPTION_WHAT", val);
|
||||
}
|
||||
} catch (...) {
|
||||
if (exception_handler_) {
|
||||
@@ -7948,7 +8002,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
routed = true;
|
||||
} else {
|
||||
res.status = StatusCode::InternalServerError_500;
|
||||
res.set_header("EXCEPTION_WHAT", "UNKNOWN");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -11629,8 +11682,7 @@ void SSLClient::set_session_verifier(
|
||||
session_verifier_ = std::move(verifier);
|
||||
}
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
void SSLClient::enable_windows_certificate_verification(bool enabled) {
|
||||
enable_windows_cert_verification_ = enabled;
|
||||
}
|
||||
@@ -11788,8 +11840,7 @@ bool SSLClient::initialize_ssl(Socket &socket, Error &error) {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
// Additional Windows Schannel verification.
|
||||
// This provides real-time certificate validation with Windows Update
|
||||
// integration, working with both OpenSSL and MbedTLS backends.
|
||||
@@ -11835,8 +11886,7 @@ void Client::enable_server_hostname_verification(bool enabled) {
|
||||
cli_->enable_server_hostname_verification(enabled);
|
||||
}
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
void Client::enable_windows_certificate_verification(bool enabled) {
|
||||
if (is_ssl_) {
|
||||
static_cast<SSLClient &>(*cli_).enable_windows_certificate_verification(
|
||||
@@ -11959,7 +12009,7 @@ bool enumerate_windows_system_certs(Callback cb) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
// Enumerate macOS Keychain certificates and call callback with DER data
|
||||
template <typename Callback>
|
||||
bool enumerate_macos_keychain_certs(Callback cb) {
|
||||
|
||||
Vendored
+31
-16
@@ -8,8 +8,8 @@
|
||||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.34.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002200"
|
||||
#define CPPHTTPLIB_VERSION "0.35.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002300"
|
||||
|
||||
/*
|
||||
* Platform compatibility check
|
||||
@@ -357,14 +357,32 @@ using socket_t = int;
|
||||
#include <any>
|
||||
#endif
|
||||
|
||||
// On macOS with a TLS backend, enable Keychain root certificates by default
|
||||
// unless the user explicitly opts out.
|
||||
#if defined(__APPLE__) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_MACOSX_AUTOMATIC_ROOT_CERTIFICATES) && \
|
||||
(defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \
|
||||
defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || \
|
||||
defined(CPPHTTPLIB_WOLFSSL_SUPPORT))
|
||||
#ifndef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#define CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// On Windows, enable Schannel certificate verification by default
|
||||
// unless the user explicitly opts out.
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#define CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
#endif
|
||||
|
||||
#if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \
|
||||
defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#if TARGET_OS_MAC
|
||||
#include <CFNetwork/CFHost.h>
|
||||
#include <CoreFoundation/CoreFoundation.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO or
|
||||
// CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
#ifdef _WIN32
|
||||
@@ -382,11 +400,11 @@ using socket_t = int;
|
||||
#endif
|
||||
#endif // _WIN32
|
||||
|
||||
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#if TARGET_OS_MAC
|
||||
#include <Security/Security.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO
|
||||
#endif
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
@@ -430,11 +448,11 @@ using socket_t = int;
|
||||
#pragma comment(lib, "crypt32.lib")
|
||||
#endif
|
||||
#endif // _WIN32
|
||||
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#if TARGET_OS_MAC
|
||||
#include <Security/Security.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif
|
||||
|
||||
// Mbed TLS 3.x API compatibility
|
||||
#if MBEDTLS_VERSION_MAJOR >= 3
|
||||
@@ -473,11 +491,11 @@ using socket_t = int;
|
||||
#pragma comment(lib, "crypt32.lib")
|
||||
#endif
|
||||
#endif // _WIN32
|
||||
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#if TARGET_OS_MAC
|
||||
#include <Security/Security.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_WOLFSSL_SUPPORT
|
||||
|
||||
// Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available
|
||||
@@ -2557,8 +2575,7 @@ public:
|
||||
|
||||
tls::ctx_t tls_context() const;
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
void enable_windows_certificate_verification(bool enabled);
|
||||
#endif
|
||||
|
||||
@@ -2679,8 +2696,7 @@ public:
|
||||
|
||||
tls::ctx_t tls_context() const { return ctx_; }
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
void enable_windows_certificate_verification(bool enabled);
|
||||
#endif
|
||||
|
||||
@@ -2712,8 +2728,7 @@ private:
|
||||
|
||||
std::function<SSLVerifierResponse(tls::session_t)> session_verifier_;
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
bool enable_windows_cert_verification_ = true;
|
||||
#endif
|
||||
|
||||
|
||||
Vendored
+347
-250
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user