Compare commits

...

12 Commits

Author SHA1 Message Date
Jared Van Bortel 15499eb942 mpt : do not duplicate token_embd.weight on disk (#5670) 2024-02-22 17:05:23 -05:00
Georgi Gerganov 96633eeca1 gemma : use more bits for the token_embd.weight tensor (#5650)
* gemma : use Q8_0 for the token_embd.weight tensor

* llama : quantize token_embd.weight using output type
2024-02-22 23:23:46 +02:00
Georgi Gerganov 847eedbdb2 py : add Gemma conversion from HF models (#5647)
* py : add gemma conversion from HF models

* Update convert-hf-to-gguf.py

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update convert-hf-to-gguf.py

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update convert-hf-to-gguf.py

Co-authored-by: Jared Van Bortel <jared@nomic.ai>

---------

Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
2024-02-22 23:22:48 +02:00
Georgi Gerganov 7e4f339c40 ggml : always define ggml_fp16_t as uint16_t (#5666)
* ggml : always define ggml_fp16_t as uint16_t

ggml-ci

* ggml : cont

ggml-ci

* ggml : cont

* ggml : cont

ggml-ci

* ggml : cont

ggml-ci

* cuda : no longer ggml headers last

ggml-ci

* ggml : fix q6_K FP16 -> FP32 conversion

ggml-ci

* ggml : more FP16 -> FP32 conversion fixes

ggml-ci
2024-02-22 23:21:39 +02:00
Georgi Gerganov 334f76fa38 sync : ggml 2024-02-22 23:21:05 +02:00
Georgi Gerganov efd56b1c21 ggml : 32-bit arm compat (whisper/1891)
* ggml : 32-bit arm compat

* ggml : add ggml_vqtbl1q_s8 impl

* ggml : cont
2024-02-22 23:20:50 +02:00
Someone 201294ae17 nix: init singularity and docker images (#5056)
Exposes a few attributes demonstrating how to build [singularity](https://docs.sylabs.io/guides/latest/user-guide/)/[apptainer](https://apptainer.org/) and Docker images re-using llama.cpp's Nix expression.

Built locally on `x86_64-linux` with `nix build github:someoneserge/llama.cpp/feat/nix/images#llamaPackages.{docker,docker-min,sif,llama-cpp}` and it's fast and effective.
2024-02-22 11:44:10 -08:00
Georgi Gerganov 5a9e2f60ba py : minor fixes (#5668) 2024-02-22 20:13:25 +02:00
Xuan Son Nguyen 373ee3fbba Add Gemma chat template (#5665)
* add gemma chat template

* gemma: only apply system_prompt on non-model message
2024-02-22 19:10:21 +01:00
Someone 4cb4d8b22d workflows: nix: hardcode cachix ids, build unconditionally (#5663)
GitHub does not expose environment and repository variables to PRs coming from forks implies that we've been disabling the Nix CI actions for most PRs. 

The `if:` also didn't make much sense, because we can always pull from cachix, and there's no point (albeit no risk either) in pushing cache for the untrusted code.
2024-02-22 08:32:09 -08:00
Georgi Gerganov 3a03541ced minor : fix trailing whitespace (#5638) 2024-02-22 13:54:03 +02:00
Georgi Gerganov 56d03d92be readme : update hot topics 2024-02-22 10:35:54 +02:00
15 changed files with 249 additions and 62 deletions
+37
View File
@@ -0,0 +1,37 @@
{
lib,
dockerTools,
buildEnv,
llama-cpp,
interactive ? true,
coreutils,
}:
# A tar that can be fed into `docker load`:
#
# $ nix build .#llamaPackages.docker
# $ docker load < result
# For details and variations cf.
# - https://nixos.org/manual/nixpkgs/unstable/#ssec-pkgs-dockerTools-buildLayeredImage
# - https://discourse.nixos.org/t/a-faster-dockertools-buildimage-prototype/16922
# - https://nixery.dev/
# Approximate (compressed) sizes, at the time of writing, are:
#
# .#llamaPackages.docker: 125M;
# .#llamaPackagesCuda.docker: 537M;
# .#legacyPackages.aarch64-linux.llamaPackagesXavier.docker: 415M.
dockerTools.buildLayeredImage {
name = llama-cpp.pname;
tag = "latest";
contents =
[ llama-cpp ]
++ lib.optionals interactive [
coreutils
dockerTools.binSh
dockerTools.caCertificates
];
}
+3
View File
@@ -12,5 +12,8 @@ lib.makeScope newScope (
self: {
inherit llamaVersion;
llama-cpp = self.callPackage ./package.nix { };
docker = self.callPackage ./docker.nix { };
docker-min = self.callPackage ./docker.nix { interactive = false; };
sif = self.callPackage ./sif.nix { };
}
)
+27
View File
@@ -0,0 +1,27 @@
{
lib,
singularity-tools,
llama-cpp,
bashInteractive,
interactive ? false,
}:
let
optionalInt = cond: x: if cond then x else 0;
in
singularity-tools.buildImage rec {
inherit (llama-cpp) name;
contents = [ llama-cpp ] ++ lib.optionals interactive [ bashInteractive ];
# These are excessive (but safe) for most variants. Building singularity
# images requires superuser privileges, so we build them inside a VM in a
# writable image of pre-determined size.
#
# ROCm is currently affected by https://github.com/NixOS/nixpkgs/issues/276846
#
# Expected image sizes:
# - cpu/blas: 150M,
# - cuda, all gencodes: 560M,
diskSize = 4096 + optionalInt llama-cpp.useRocm 16384;
memSize = diskSize;
}
+3 -4
View File
@@ -19,7 +19,6 @@ on:
jobs:
nix-build-aarch64:
if: ${{ vars.CACHIX_NAME != '' }}
runs-on: ubuntu-latest
steps:
- name: Checkout repository
@@ -37,8 +36,8 @@ jobs:
extra-conf: |
extra-platforms = aarch64-linux
extra-system-features = nixos-test kvm
extra-substituters = https://${{ vars.CACHIX_NAME }}.cachix.org https://cuda-maintainers.cachix.org
extra-trusted-public-keys = ${{ vars.CACHIX_PUBLIC_KEY }} cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=
extra-substituters = https://llama-cpp.cachix.org https://cuda-maintainers.cachix.org
extra-trusted-public-keys = llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc= cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=
- uses: DeterminateSystems/magic-nix-cache-action@v2
with:
upstream-cache: https://${{ matrix.cachixName }}.cachix.org
@@ -46,7 +45,7 @@ jobs:
uses: cachix/cachix-action@v13
with:
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
name: ${{ vars.CACHIX_NAME }}
name: llama-cpp
- name: Show all output paths
run: >
nix run github:nix-community/nix-eval-jobs
+5 -6
View File
@@ -23,8 +23,8 @@ jobs:
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
extra-conf: |
extra-substituters = https://${{ vars.CACHIX_NAME }}.cachix.org https://cuda-maintainers.cachix.org
extra-trusted-public-keys = ${{ vars.CACHIX_PUBLIC_KEY }} cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=
extra-substituters = https://llama-cpp.cachix.org https://cuda-maintainers.cachix.org
extra-trusted-public-keys = llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc= cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=
- uses: DeterminateSystems/magic-nix-cache-action@v2
with:
upstream-cache: https://${{ matrix.cachixName }}.cachix.org
@@ -37,7 +37,6 @@ jobs:
--flake
".#packages.$(nix eval --raw --impure --expr builtins.currentSystem)"
nix-build:
if: ${{ vars.CACHIX_NAME != '' }}
strategy:
fail-fast: false
matrix:
@@ -51,8 +50,8 @@ jobs:
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
extra-conf: |
extra-substituters = https://${{ vars.CACHIX_NAME }}.cachix.org https://cuda-maintainers.cachix.org
extra-trusted-public-keys = ${{ vars.CACHIX_PUBLIC_KEY }} cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=
extra-substituters = https://llama-cpp.cachix.org https://cuda-maintainers.cachix.org
extra-trusted-public-keys = llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc= cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=
- uses: DeterminateSystems/magic-nix-cache-action@v2
with:
upstream-cache: https://${{ matrix.cachixName }}.cachix.org
@@ -60,7 +59,7 @@ jobs:
uses: cachix/cachix-action@v13
with:
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
name: ${{ vars.CACHIX_NAME }}
name: llama-cpp
- name: Build
run: >
nix run github:Mic92/nix-fast-build
+1
View File
@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
### Hot topics
- Support for chat templates: [Wiki (contributions welcome)](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
- Support for Gemma models: https://github.com/ggerganov/llama.cpp/pull/5631
- Non-linear quantization IQ4_NL: https://github.com/ggerganov/llama.cpp/pull/5590
- Looking for contributions to improve and maintain the `server` example: https://github.com/ggerganov/llama.cpp/issues/4216
+62 -6
View File
@@ -218,6 +218,8 @@ class Model:
return BertModel
if model_architecture == "NomicBertModel":
return NomicBertModel
if model_architecture == "GemmaForCausalLM":
return GemmaModel
return Model
def _is_model_safetensors(self) -> bool:
@@ -277,6 +279,8 @@ class Model:
return gguf.MODEL_ARCH.BERT
if arch == "NomicBertModel":
return gguf.MODEL_ARCH.NOMIC_BERT
if arch == "GemmaForCausalLM":
return gguf.MODEL_ARCH.GEMMA
raise NotImplementedError(f'Architecture "{arch}" not supported!')
@@ -618,11 +622,6 @@ class MPTModel(Model):
self.gguf_writer.add_tensor(new_name, data)
# note: MPT output is tied to (same as) wte in original model;
# for easier implementation in llama.cpp it's duplicated in GGUF, though :/
if new_name == "token_embd.weight":
self.gguf_writer.add_tensor("output.weight", data)
class OrionModel(Model):
def set_vocab(self):
@@ -655,6 +654,8 @@ class OrionModel(Model):
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv)
# note: config provides rms norm but it is actually layer norm
# ref: https://huggingface.co/OrionStarAI/Orion-14B-Chat/blob/276a17221ce42beb45f66fac657a41540e71f4f5/modeling_orion.py#L570-L571
self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"])
def write_tensors(self):
@@ -1031,7 +1032,6 @@ class PersimmonModel(Model):
self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
def set_vocab(self):
self._set_vocab_sentencepiece()
@@ -1785,6 +1785,62 @@ class NomicBertModel(BertModel):
yield name, data
class GemmaModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_key_length(hparams["head_dim"])
self.gguf_writer.add_value_length(hparams["head_dim"])
def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
for name, data_torch in self.get_tensors():
# ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
if name.endswith("norm.weight"):
data_torch = data_torch + 1
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
data = data_torch.squeeze().numpy()
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
n_dims = len(data.shape)
data_dtype = data.dtype
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
###### CONVERSION LOGIC ######
+4 -5
View File
@@ -1,3 +1,7 @@
#include "ggml-cuda.h"
#include "ggml.h"
#include "ggml-backend-impl.h"
#include <algorithm>
#include <assert.h>
#include <atomic>
@@ -121,11 +125,6 @@
#endif // defined(GGML_USE_HIPBLAS)
// ggml-cuda need half type so keep ggml headers include at last
#include "ggml-cuda.h"
#include "ggml.h"
#include "ggml-backend-impl.h"
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CC_PASCAL 600
+20 -7
View File
@@ -53,11 +53,23 @@ extern "C" {
//
#include <arm_neon.h>
#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
#define GGML_FP16_TO_FP32(x) ((float) (x))
#define GGML_FP32_TO_FP16(x) (x)
#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
__fp16 tmp;
memcpy(&tmp, &h, sizeof(ggml_fp16_t));
return (float)tmp;
}
static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
ggml_fp16_t res;
__fp16 tmp = f;
memcpy(&res, &tmp, sizeof(ggml_fp16_t));
return res;
}
#else
@@ -214,8 +226,7 @@ extern float ggml_table_f32_f16[1 << 16];
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9.
#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
#if !defined(GGML_FP16_TO_FP32)
inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
uint16_t s;
memcpy(&s, &f, sizeof(uint16_t));
@@ -223,8 +234,10 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
}
#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
#endif
#if !defined(GGML_FP32_TO_FP16)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
#endif
#define GGML_HASHTABLE_FULL ((size_t)-1)
+45 -20
View File
@@ -438,6 +438,30 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
return res;
}
// NOTE: not tested
inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
int8x16_t res;
res[ 0] = a[b[ 0]];
res[ 1] = a[b[ 1]];
res[ 2] = a[b[ 2]];
res[ 3] = a[b[ 3]];
res[ 4] = a[b[ 4]];
res[ 5] = a[b[ 5]];
res[ 6] = a[b[ 6]];
res[ 7] = a[b[ 7]];
res[ 8] = a[b[ 8]];
res[ 9] = a[b[ 9]];
res[10] = a[b[10]];
res[11] = a[b[11]];
res[12] = a[b[12]];
res[13] = a[b[13]];
res[14] = a[b[14]];
res[15] = a[b[15]];
return res;
}
#else
#define ggml_int16x8x2_t int16x8x2_t
@@ -451,6 +475,7 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
#define ggml_vld1q_u8_x4 vld1q_u8_x4
#define ggml_vld1q_s8_x2 vld1q_s8_x2
#define ggml_vld1q_s8_x4 vld1q_s8_x4
#define ggml_vqtbl1q_s8 vqtbl1q_s8
#endif
@@ -5629,8 +5654,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * (float)x[i].d;
const float dmin = -y[i].d * (float)x[i].dmin;
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
@@ -5779,8 +5804,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * (float)x[i].d;
const float dmin = -y[i].d * (float)x[i].dmin;
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
@@ -6433,7 +6458,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
const float d = y[i].d * (float)x[i].d;
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
@@ -6635,7 +6660,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
const float d = y[i].d * (float)x[i].d;
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
@@ -7138,9 +7163,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
aux16[1] = (a[0] >> 4) & 0x0f0f;
const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
sum_mins += y[i].d * (float)x[i].d[1] * summi;
sum_mins += y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * summi;
const float d = y[i].d * (float)x[i].d[0];
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
@@ -7798,7 +7823,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * (float)x[i].d;
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const int8_t * sc = x[i].scales;
const uint8_t * restrict q5 = x[i].qs;
@@ -7940,7 +7965,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * (float)x[i].d;
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const int8_t * sc = x[i].scales;
const uint8_t * restrict q5 = x[i].qs;
@@ -8508,7 +8533,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
for (int i = 0; i < nb; ++i) {
const float d_all = (float)x[i].d;
const float d_all = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q6 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
@@ -8679,7 +8704,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
for (int i = 0; i < nb; ++i) {
const float d_all = (float)x[i].d;
const float d_all = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q6 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
@@ -9333,7 +9358,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
uint16_t gindex[8];
uint16x8x2_t vindex;
int8x16x4_t q1b;
int8x16x4_t q8b;
ggml_int8x16x4_t q8b;
uint16x8x4_t scales;
int32x4x2_t sumi;
int32x4x2_t dotq;
@@ -9498,7 +9523,6 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
float sumf = 0;
for (int ib = 0; ib < nb; ib += 2) {
q4bits.val[0] = vld1q_u8(x[ib+0].qs);
q4bits.val[1] = vld1q_u8(x[ib+1].qs);
q8b.val[0] = vld1q_s8(y[ib+0].qs);
@@ -9506,16 +9530,17 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
q8b.val[2] = vld1q_s8(y[ib+1].qs);
q8b.val[3] = vld1q_s8(y[ib+1].qs + 16);
q4b.val[0] = vqtbl1q_s8(values, vandq_u8(q4bits.val[0], m4b));
q4b.val[1] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
q4b.val[2] = vqtbl1q_s8(values, vandq_u8(q4bits.val[1], m4b));
q4b.val[3] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
sumf += (float)x[ib+0].d * (float)y[ib+0].d * vaddvq_s32(prod_1) + (float)x[ib+1].d * (float)y[ib+1].d * vaddvq_s32(prod_2);
sumf +=
GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) +
GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2);
}
*s = sumf;
+3 -3
View File
@@ -323,7 +323,7 @@ float ggml_table_f32_f16[1 << 16];
// note: do not use these inside ggml.c
// these are meant to be used via the ggml.h API
float ggml_fp16_to_fp32(ggml_fp16_t x) {
return (float) GGML_FP16_TO_FP32(x);
return GGML_FP16_TO_FP32(x);
}
ggml_fp16_t ggml_fp32_to_fp16(float x) {
@@ -798,7 +798,7 @@ inline static float vaddvq_f32(float32x4_t v) {
#define GGML_F16x8 float16x8_t
#define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
#define GGML_F16x8_SET1(x) vdupq_n_f16(x)
#define GGML_F16x8_LOAD vld1q_f16
#define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x))
#define GGML_F16x8_STORE vst1q_f16
#define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
#define GGML_F16x8_ADD vaddq_f16
@@ -841,7 +841,7 @@ inline static float vaddvq_f32(float32x4_t v) {
#define GGML_F32Cx4 float32x4_t
#define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
#define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
#define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x))
#define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x)))
#define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
#define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
#define GGML_F32Cx4_ADD vaddq_f32
-6
View File
@@ -315,13 +315,7 @@
extern "C" {
#endif
#if defined(__ARM_NEON) && defined(__CUDACC__)
typedef half ggml_fp16_t;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
typedef __fp16 ggml_fp16_t;
#else
typedef uint16_t ggml_fp16_t;
#endif
// convert FP16 <-> FP32
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
+34 -4
View File
@@ -509,7 +509,6 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
@@ -4056,7 +4055,10 @@ static bool llm_load_tensors(
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, false);
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
// same as tok_embd, duplicated to allow offloading
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
ml.n_created--; // artificial tensor
ml.size_data += ggml_nbytes(model.output);
}
for (int i = 0; i < n_layer; ++i) {
@@ -6192,7 +6194,7 @@ struct llm_build_context {
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
if (model.layers[il].bqkv){
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
@@ -7450,6 +7452,7 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
cb(inpL, "inp_embd", -1);
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
cb(inpL, "inp_scaled", -1);
@@ -7491,6 +7494,7 @@ struct llm_build_context {
n_embd_head_k, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
cb(Qcur, "Qcur_scaled", il);
@@ -7505,6 +7509,7 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
cb(cur, "kqv_out", il);
}
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
cb(sa_out, "sa_out", il);
@@ -10495,7 +10500,10 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
return std::make_pair(i_layer, n_layer);
};
if (name == tn(LLM_TENSOR_OUTPUT, "weight")) {
// for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
// with the quantization of the output tensor
if (name == tn(LLM_TENSOR_OUTPUT, "weight") ||
(LLM_TENSOR_NAMES.at(arch).find(LLM_TENSOR_OUTPUT) == LLM_TENSOR_NAMES.at(arch).end() && name == "token_embd.weight")) {
int nx = tensor->ne[0];
if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
new_type = GGML_TYPE_Q8_0;
@@ -12782,6 +12790,28 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<s>assistant\n";
}
} else if (tmpl.find("<start_of_turn>") != std::string::npos) {
// google/gemma-7b-it
std::string system_prompt = "";
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
system_prompt = trim(message->content);
continue;
}
// in gemma, "assistant" is "model"
role = role == "assistant" ? "model" : message->role;
ss << "<start_of_turn>" << role << "\n";
if (!system_prompt.empty() && role != "model") {
ss << system_prompt << "\n\n";
system_prompt = "";
}
ss << trim(message->content) << "<end_of_turn>\n";
}
if (add_ass) {
ss << "<start_of_turn>model\n";
}
} else {
// template not supported
return -1;
+1 -1
View File
@@ -1 +1 @@
30805514e1bf389a59d30a54a0525cbdc30d5bd1
8cdf783f288a98eddf521b0ab1b4d405be9e18ba
+4
View File
@@ -29,6 +29,8 @@ int main(void) {
"{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
// mlabonne/AlphaMonarch-7B
"{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
// google/gemma-7b-it
"{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
};
std::vector<std::string> expected_output = {
// teknium/OpenHermes-2.5-Mistral-7B
@@ -41,6 +43,8 @@ int main(void) {
"[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
// mlabonne/AlphaMonarch-7B
"system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
// google/gemma-7b-it
"<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
};
std::vector<char> formatted_chat(1024);
int32_t res;