mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 16:17:40 +02:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e959d32b1c | |||
| 307bfa253d | |||
| 71e90e8813 | |||
| bc091a4dc5 | |||
| a4837577aa | |||
| e59ea539b8 | |||
| c94085df28 | |||
| e8a62631b3 |
+1
-1
@@ -830,7 +830,7 @@ std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#if defined(__linux__) || defined(__FreeBSD__)
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
|
||||
+20
-13
@@ -323,8 +323,8 @@ struct clip_ctx {
|
||||
std::vector<ggml_backend_t> backend_ptrs;
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
|
||||
ggml_backend_ptr backend;
|
||||
ggml_backend_ptr backend_cpu;
|
||||
ggml_backend_t backend;
|
||||
ggml_backend_t backend_cpu;
|
||||
ggml_backend_buffer_ptr buf;
|
||||
|
||||
ggml_backend_sched_ptr sched;
|
||||
@@ -332,27 +332,34 @@ struct clip_ctx {
|
||||
clip_image_size load_image_size;
|
||||
|
||||
clip_ctx(clip_context_params & ctx_params) {
|
||||
backend_cpu = ggml_backend_ptr(ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr));
|
||||
backend = ggml_backend_ptr(ctx_params.use_gpu
|
||||
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
backend = ctx_params.use_gpu
|
||||
? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr)
|
||||
: nullptr);
|
||||
: nullptr;
|
||||
|
||||
if (backend) {
|
||||
LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend.get()));
|
||||
backend_ptrs.push_back(backend.get());
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend.get()));
|
||||
LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
backend_ptrs.push_back(backend);
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
|
||||
} else {
|
||||
backend = std::move(backend_cpu);
|
||||
backend = backend_cpu;
|
||||
LOG_INF("%s: CLIP using CPU backend\n", __func__);
|
||||
}
|
||||
|
||||
backend_ptrs.push_back(backend_cpu.get());
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu.get()));
|
||||
backend_ptrs.push_back(backend_cpu);
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
|
||||
|
||||
sched.reset(
|
||||
ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false)
|
||||
);
|
||||
}
|
||||
|
||||
~clip_ctx() {
|
||||
ggml_backend_free(backend);
|
||||
if (backend != backend_cpu) {
|
||||
ggml_backend_free(backend_cpu);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
|
||||
@@ -1428,7 +1435,7 @@ struct clip_model_loader {
|
||||
}
|
||||
|
||||
// alloc memory and offload data
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend.get());
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
|
||||
ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
|
||||
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
for (auto & t : tensors_to_load) {
|
||||
@@ -2610,7 +2617,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_cpu_set_n_threads(ctx->backend_cpu.get(), n_threads);
|
||||
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
|
||||
|
||||
auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
|
||||
if (status != GGML_STATUS_SUCCESS) {
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <fstream>
|
||||
#include <cmath>
|
||||
#include <cctype>
|
||||
#include <algorithm>
|
||||
|
||||
struct quant_option {
|
||||
std::string name;
|
||||
@@ -16,7 +17,7 @@ struct quant_option {
|
||||
std::string desc;
|
||||
};
|
||||
|
||||
static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
||||
static const std::vector<quant_option> QUANT_OPTIONS = {
|
||||
{ "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
|
||||
{ "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 4.78G, +0.4511 ppl @ Llama-3-8B", },
|
||||
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 5.21G, +0.1316 ppl @ Llama-3-8B", },
|
||||
@@ -105,7 +106,8 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
|
||||
//
|
||||
[[noreturn]]
|
||||
static void usage(const char * executable) {
|
||||
printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
|
||||
printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
|
||||
printf(" [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
|
||||
printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
|
||||
printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
|
||||
printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
|
||||
@@ -114,6 +116,8 @@ static void usage(const char * executable) {
|
||||
printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
|
||||
printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
|
||||
printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
|
||||
printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
|
||||
printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n");
|
||||
printf(" --keep-split: will generate quantized model in the same shards as input\n");
|
||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||
printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
|
||||
@@ -244,6 +248,107 @@ static ggml_type parse_ggml_type(const char * arg) {
|
||||
return GGML_TYPE_COUNT;
|
||||
}
|
||||
|
||||
// Allowed tensors for arbitrary quantization with --tensor-type option
|
||||
static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
|
||||
"attn_k",
|
||||
"attn_kv_a_mqa",
|
||||
"attn_kv_b",
|
||||
"attn_o",
|
||||
"attn_output",
|
||||
"attn_q",
|
||||
"attn_q_a",
|
||||
"attn_q_b",
|
||||
"attn_qkv",
|
||||
"attn_v",
|
||||
"channel_mix_key",
|
||||
"channel_mix_receptance",
|
||||
"channel_mix_value",
|
||||
"cls",
|
||||
"cls.output",
|
||||
"cross_attn_k",
|
||||
"cross_attn_o",
|
||||
"cross_attn_q",
|
||||
"cross_attn_v",
|
||||
"ffn_act",
|
||||
"ffn_down",
|
||||
"ffn_down_exps",
|
||||
"ffn_down_shexp",
|
||||
"ffn_gate",
|
||||
"ffn_gate_exps",
|
||||
"ffn_gate_shexp",
|
||||
"ffn_up",
|
||||
"ffn_up_exps",
|
||||
"ffn_up_shexp",
|
||||
"ssm_in",
|
||||
"ssm_out",
|
||||
"time_mix_gate",
|
||||
"time_mix_key",
|
||||
"time_mix_output",
|
||||
"time_mix_receptance",
|
||||
"time_mix_value",
|
||||
};
|
||||
|
||||
// changes to this struct must be replicated in llama-quant.cpp
|
||||
struct tensor_quantization {
|
||||
std::string name;
|
||||
ggml_type quant = GGML_TYPE_COUNT;
|
||||
};
|
||||
|
||||
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
|
||||
const char * sep = strchr(data, '=');
|
||||
if (sep == nullptr) {
|
||||
printf("\n%s: malformed tensor type '%s'\n\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t tn_len = sep - data;
|
||||
if (tn_len == 0) {
|
||||
printf("\n%s: missing tensor name\n\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const size_t qt_len = strlen(sep); qt_len == 1) {
|
||||
printf("\n%s: missing quantization type\n\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string tn(data, tn_len);
|
||||
std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
|
||||
sep++;
|
||||
const std::string qt(sep);
|
||||
|
||||
bool found = false;
|
||||
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
|
||||
std::string tensor;
|
||||
tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
|
||||
// handle special case of cls.output
|
||||
std::string cls_output = "cls.output";
|
||||
if (tn.find(cls_output) != std::string::npos) {
|
||||
tensor = "cls.output";
|
||||
}
|
||||
// check if an allowed tensor exists and it's at the end of the kv string
|
||||
if (tensor == allowed) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
|
||||
printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_quantization tqz;
|
||||
tqz.name = tn;
|
||||
tqz.quant = parse_ggml_type(qt.c_str());
|
||||
tensor_type.emplace_back(std::move(tqz));
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
if (argc < 3) {
|
||||
usage(argv[0]);
|
||||
@@ -255,6 +360,7 @@ int main(int argc, char ** argv) {
|
||||
std::string imatrix_file;
|
||||
std::vector<std::string> included_weights, excluded_weights;
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
std::vector<tensor_quantization> tensor_types;
|
||||
|
||||
for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
|
||||
if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
|
||||
@@ -277,6 +383,10 @@ int main(int argc, char ** argv) {
|
||||
} else {
|
||||
usage(argv[0]);
|
||||
}
|
||||
} else if (strcmp(argv[arg_idx], "--tensor-type") == 0) {
|
||||
if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
|
||||
usage(argv[0]);
|
||||
}
|
||||
} else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
|
||||
if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
|
||||
usage(argv[0]);
|
||||
@@ -361,6 +471,9 @@ int main(int argc, char ** argv) {
|
||||
kv_overrides.back().key[0] = 0;
|
||||
params.kv_overrides = &kv_overrides;
|
||||
}
|
||||
if (!tensor_types.empty()) {
|
||||
params.tensor_types = &tensor_types;
|
||||
}
|
||||
|
||||
llama_backend_init();
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ static std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#ifdef __linux__
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
@@ -136,7 +136,9 @@ static std::string fs_get_cache_directory() {
|
||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||
#elif defined(_WIN32)
|
||||
cache_directory = std::getenv("LOCALAPPDATA");
|
||||
#endif // __linux__
|
||||
#else
|
||||
# error Unknown architecture
|
||||
#endif
|
||||
cache_directory = ensure_trailing_slash(cache_directory);
|
||||
cache_directory += "llama.cpp";
|
||||
}
|
||||
|
||||
@@ -3907,6 +3907,21 @@ int main(int argc, char ** argv) {
|
||||
res_ok(res, {{ "success", true }});
|
||||
};
|
||||
|
||||
const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
json data = {
|
||||
{
|
||||
"template", common_chat_templates_source(ctx_server.chat_templates.get()),
|
||||
},
|
||||
{
|
||||
"model_info", {
|
||||
{ "llama.context_length", ctx_server.slots.back().n_ctx, },
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
res_ok(res, data);
|
||||
};
|
||||
|
||||
// handle completion-like requests (completion, chat, infill)
|
||||
// we can optionally provide a custom format for partial results and final results
|
||||
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
|
||||
@@ -4471,6 +4486,7 @@ int main(int argc, char ** argv) {
|
||||
svr->Get ("/metrics", handle_metrics);
|
||||
svr->Get ("/props", handle_props);
|
||||
svr->Post("/props", handle_props_change);
|
||||
svr->Post("/api/show", handle_api_show);
|
||||
svr->Get ("/models", handle_models); // public endpoint (no API key check)
|
||||
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
|
||||
svr->Post("/completion", handle_completions); // legacy
|
||||
|
||||
@@ -183,67 +183,63 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang
|
||||
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
#if defined(__AVX512F__)
|
||||
// add int16_t pairwise and return as 512 bit int vector
|
||||
static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
|
||||
// add int16_t pairwise and return as 512 bit int vector, then add the accumulator
|
||||
static inline __m512i sum_i16_pairs_acc_int32x16(const __m512i acc, const __m512i x) {
|
||||
const __m512i ones = _mm512_set1_epi16(1);
|
||||
return _mm512_madd_epi16(ones, x);
|
||||
return _mm512_add_epi32(acc, _mm512_madd_epi16(ones, x));
|
||||
}
|
||||
|
||||
static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
|
||||
static inline __m512i mul_sum_us8_pairs_acc_int32x16(const __m512i acc, const __m512i ax, const __m512i sy) {
|
||||
#if defined(__AVX512VNNI__)
|
||||
const __m512i zero = _mm512_setzero_si512();
|
||||
return _mm512_dpbusd_epi32(zero, ax, sy);
|
||||
return _mm512_dpbusd_epi32(acc, ax, sy);
|
||||
#else
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m512i dot = _mm512_maddubs_epi16(ax, sy);
|
||||
return sum_i16_pairs_int_32x16(dot);
|
||||
return sum_i16_pairs_acc_int32x16(acc, dot);
|
||||
#endif
|
||||
}
|
||||
|
||||
// multiply int8_t, add results pairwise twice and return as 512 bit int vector
|
||||
static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) {
|
||||
// multiply int8_t, add results pairwise twice and return as 512 bit int vector,then add the accumulator
|
||||
static inline __m512i mul_sum_i8_pairs_acc_int32x16(const __m512i acc, const __m512i x, const __m512i y) {
|
||||
const __m512i zero = _mm512_setzero_si512();
|
||||
// Get absolute values of x vectors
|
||||
const __m512i ax = _mm512_abs_epi8(x);
|
||||
// Sign the values of the y vectors
|
||||
__mmask64 blt0 = _mm512_movepi8_mask(x);
|
||||
const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);
|
||||
return mul_sum_us8_pairs_int32x16(ax, sy);
|
||||
return mul_sum_us8_pairs_acc_int32x16(acc, ax, sy);
|
||||
}
|
||||
#endif
|
||||
|
||||
// add int16_t pairwise and return as 256 bit int vector
|
||||
static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
|
||||
// add int16_t pairwise and return as 256 bit int vector, then add the accumulator
|
||||
static inline __m256i sum_i16_pairs_acc_int32x8(const __m256i acc, const __m256i x) {
|
||||
const __m256i ones = _mm256_set1_epi16(1);
|
||||
return _mm256_madd_epi16(ones, x);
|
||||
return _mm256_add_epi32(acc, _mm256_madd_epi16(ones, x));
|
||||
}
|
||||
|
||||
static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
|
||||
static inline __m256i mul_sum_us8_pairs_acc_int32x8(const __m256i acc, const __m256i ax, const __m256i sy) {
|
||||
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
return _mm256_dpbusd_epi32(zero, ax, sy);
|
||||
return _mm256_dpbusd_epi32(acc, ax, sy);
|
||||
#elif defined(__AVXVNNI__)
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
return _mm256_dpbusd_avx_epi32(zero, ax, sy);
|
||||
return _mm256_dpbusd_avx_epi32(acc, ax, sy);
|
||||
#else
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||
return sum_i16_pairs_int32x8(dot);
|
||||
return sum_i16_pairs_acc_int32x8(acc, dot);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Integer variant of the function defined in ggml-quants.c
|
||||
// multiply int8_t, add results pairwise twice and return as 256 bit int vector
|
||||
static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) {
|
||||
#if __AVXVNNIINT8__
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
return _mm256_dpbssd_epi32(zero, x, y);
|
||||
// multiply int8_t, add results pairwise twice and return as 256 bit int vector, then add the accumulator
|
||||
static inline __m256i mul_sum_i8_pairs_acc_int32x8(const __m256i acc, const __m256i x, const __m256i y) {
|
||||
#if defined(__AVXVNNIINT8__)
|
||||
return _mm256_dpbssd_epi32(acc, x, y);
|
||||
#else
|
||||
// Get absolute values of x vectors
|
||||
const __m256i ax = _mm256_sign_epi8(x, x);
|
||||
// Sign the values of the y vectors
|
||||
const __m256i sy = _mm256_sign_epi8(y, x);
|
||||
return mul_sum_us8_pairs_int32x8(ax, sy);
|
||||
return mul_sum_us8_pairs_acc_int32x8(acc, ax, sy);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
@@ -1175,17 +1171,17 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
// ...........................................................................
|
||||
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
|
||||
|
||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
|
||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
|
||||
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0));
|
||||
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85));
|
||||
|
||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
|
||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
|
||||
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170));
|
||||
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255));
|
||||
|
||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
|
||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
|
||||
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0));
|
||||
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85));
|
||||
|
||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
|
||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
|
||||
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170));
|
||||
iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255));
|
||||
|
||||
// Accumulated values multipled with appropriate scales
|
||||
acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
|
||||
@@ -3239,22 +3235,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
|
||||
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
||||
// Resembles MMLAs into 2x2 matrices in ARM Version
|
||||
__m512i iacc_mat_00_sp1 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
|
||||
__m512i iacc_mat_01_sp1 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
|
||||
__m512i iacc_mat_10_sp1 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
|
||||
__m512i iacc_mat_11_sp1 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
|
||||
__m512i iacc_mat_00_sp2 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
|
||||
__m512i iacc_mat_01_sp2 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
||||
__m512i iacc_mat_10_sp2 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
|
||||
__m512i iacc_mat_11_sp2 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
||||
const __m512i zero = _mm512_setzero_epi32();
|
||||
__m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
|
||||
__m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
||||
__m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
|
||||
__m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
||||
__m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
|
||||
__m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
||||
__m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
|
||||
__m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
||||
|
||||
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
||||
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
||||
@@ -3430,22 +3419,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
|
||||
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
||||
// Resembles MMLAs into 2x2 matrices in ARM Version
|
||||
__m512i iacc_mat_00_sp1 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
|
||||
__m512i iacc_mat_01_sp1 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
|
||||
__m512i iacc_mat_10_sp1 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
|
||||
__m512i iacc_mat_11_sp1 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
|
||||
__m512i iacc_mat_00_sp2 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
|
||||
__m512i iacc_mat_01_sp2 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
||||
__m512i iacc_mat_10_sp2 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
|
||||
__m512i iacc_mat_11_sp2 =
|
||||
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
||||
const __m512i zero = _mm512_setzero_epi32();
|
||||
__m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
|
||||
__m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
||||
__m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
|
||||
__m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
||||
__m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
|
||||
__m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
||||
__m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
|
||||
__m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
||||
|
||||
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
||||
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
||||
@@ -3605,22 +3587,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
|
||||
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
||||
// Resembles MMLAs into 2x2 matrices in ARM Version
|
||||
__m256i iacc_mat_00_sp1 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
|
||||
__m256i iacc_mat_01_sp1 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
|
||||
__m256i iacc_mat_10_sp1 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
|
||||
__m256i iacc_mat_11_sp1 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
|
||||
__m256i iacc_mat_00_sp2 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
|
||||
__m256i iacc_mat_01_sp2 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
|
||||
__m256i iacc_mat_10_sp2 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
|
||||
__m256i iacc_mat_11_sp2 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
__m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
|
||||
__m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
|
||||
__m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
|
||||
__m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
|
||||
__m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
|
||||
__m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
|
||||
__m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
|
||||
__m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
|
||||
|
||||
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
||||
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
||||
@@ -3769,22 +3744,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
|
||||
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
||||
// Resembles MMLAs into 2x2 matrices in ARM Version
|
||||
__m256i iacc_mat_00_sp1 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
|
||||
__m256i iacc_mat_01_sp1 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
|
||||
__m256i iacc_mat_10_sp1 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
|
||||
__m256i iacc_mat_11_sp1 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
|
||||
__m256i iacc_mat_00_sp2 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
|
||||
__m256i iacc_mat_01_sp2 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
|
||||
__m256i iacc_mat_10_sp2 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
|
||||
__m256i iacc_mat_11_sp2 =
|
||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
__m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
|
||||
__m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
|
||||
__m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
|
||||
__m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
|
||||
__m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
|
||||
__m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
|
||||
__m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
|
||||
__m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
|
||||
|
||||
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
||||
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
||||
|
||||
@@ -2488,10 +2488,10 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_MUL_MAT_ID) {
|
||||
if (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_CONT || node->op == GGML_OP_DUP) {
|
||||
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -201,6 +201,11 @@ void main() {
|
||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
||||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||
{
|
||||
@@ -209,6 +214,7 @@ void main() {
|
||||
k_stride &= ~7;
|
||||
v_stride &= ~7;
|
||||
#endif
|
||||
m_stride &= ~7;
|
||||
}
|
||||
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
||||
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
||||
@@ -261,10 +267,7 @@ void main() {
|
||||
if (p.mask != 0) {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||
// When using grouped query attention, all rows use the same mask.
|
||||
if (p.gqa_ratio > 1) {
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
|
||||
}
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
|
||||
|
||||
+12
-11
@@ -367,17 +367,18 @@ extern "C" {
|
||||
|
||||
// model quantization parameters
|
||||
typedef struct llama_model_quantize_params {
|
||||
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
enum llama_ftype ftype; // quantize to this llama_ftype
|
||||
enum ggml_type output_tensor_type; // output tensor type
|
||||
enum ggml_type token_embedding_type; // token embeddings tensor type
|
||||
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||
bool quantize_output_tensor; // quantize output.weight
|
||||
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
bool pure; // quantize all tensors to the default type
|
||||
bool keep_split; // quantize to the same number of shards
|
||||
void * imatrix; // pointer to importance matrix data
|
||||
void * kv_overrides; // pointer to vector containing overrides
|
||||
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
enum llama_ftype ftype; // quantize to this llama_ftype
|
||||
enum ggml_type output_tensor_type; // output tensor type
|
||||
enum ggml_type token_embedding_type; // token embeddings tensor type
|
||||
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||
bool quantize_output_tensor; // quantize output.weight
|
||||
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
bool pure; // quantize all tensors to the default type
|
||||
bool keep_split; // quantize to the same number of shards
|
||||
void * imatrix; // pointer to importance matrix data
|
||||
void * kv_overrides; // pointer to vector containing overrides
|
||||
void * tensor_types; // pointer to vector containing tensor types
|
||||
} llama_model_quantize_params;
|
||||
|
||||
typedef struct llama_logit_bias {
|
||||
|
||||
+28
-7
@@ -10,6 +10,7 @@
|
||||
#include <cinttypes>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
||||
@@ -47,8 +48,14 @@ struct quantize_state_impl {
|
||||
{}
|
||||
};
|
||||
|
||||
// changes to this struct must be replicated in quantize.cpp
|
||||
struct tensor_quantization {
|
||||
std::string name;
|
||||
ggml_type quant = GGML_TYPE_COUNT;
|
||||
};
|
||||
|
||||
static void llama_tensor_dequantize_impl(
|
||||
struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
||||
ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
||||
const size_t nelements, const int nthread
|
||||
) {
|
||||
if (output.size() < nelements) {
|
||||
@@ -536,7 +543,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
model.load_hparams(ml);
|
||||
model.load_stats (ml);
|
||||
|
||||
struct quantize_state_impl qs(model, params);
|
||||
quantize_state_impl qs(model, params);
|
||||
|
||||
if (params->only_copy) {
|
||||
ftype = ml.ftype;
|
||||
@@ -661,7 +668,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// populate the original tensors so we get an initial meta data
|
||||
for (const auto * it : tensors) {
|
||||
uint16_t i_split = params->keep_split ? it->idx : 0;
|
||||
struct ggml_tensor * tensor = it->tensor;
|
||||
ggml_tensor * tensor = it->tensor;
|
||||
if (!ctx_outs[i_split]) {
|
||||
ctx_outs[i_split].reset(gguf_init_empty());
|
||||
}
|
||||
@@ -710,7 +717,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
new_ofstream(0);
|
||||
for (const auto * it : tensors) {
|
||||
const auto & weight = *it;
|
||||
struct ggml_tensor * tensor = weight.tensor;
|
||||
ggml_tensor * tensor = weight.tensor;
|
||||
if (weight.idx != cur_split && params->keep_split) {
|
||||
close_ofstream();
|
||||
new_ofstream(weight.idx);
|
||||
@@ -776,7 +783,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// do not quantize relative position bias (T5)
|
||||
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
|
||||
|
||||
enum ggml_type new_type;
|
||||
ggml_type new_type;
|
||||
void * new_data;
|
||||
size_t new_size;
|
||||
|
||||
@@ -786,6 +793,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// get more optimal quantization type based on the tensor shape, layer, etc.
|
||||
if (!params->pure && ggml_is_quantized(default_type)) {
|
||||
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
|
||||
// unless the user specifies a type
|
||||
if (params->tensor_types) {
|
||||
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
||||
for (const auto & [tname, qtype] : tensor_types) {
|
||||
if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
|
||||
if (qtype != new_type) {
|
||||
LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
|
||||
}
|
||||
new_type = qtype;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||
new_type = params->token_embedding_type;
|
||||
@@ -910,8 +930,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// interface implementation
|
||||
//
|
||||
|
||||
struct llama_model_quantize_params llama_model_quantize_default_params() {
|
||||
struct llama_model_quantize_params result = {
|
||||
llama_model_quantize_params llama_model_quantize_default_params() {
|
||||
llama_model_quantize_params result = {
|
||||
/*.nthread =*/ 0,
|
||||
/*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
|
||||
/*.output_tensor_type =*/ GGML_TYPE_COUNT,
|
||||
@@ -923,6 +943,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
|
||||
/*.keep_split =*/ false,
|
||||
/*.imatrix =*/ nullptr,
|
||||
/*.kv_overrides =*/ nullptr,
|
||||
/*.tensor_type =*/ nullptr,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
||||
Reference in New Issue
Block a user