Compare commits

...

5 Commits

Author SHA1 Message Date
DAN™ bcebd7dbf6 llama : add support for GritLM (#5959)
* add gritlm example

* gritlm results match

* tabs to spaces

* comment out debug printing

* rebase to new embed

* gritlm embeddings are back babeee

* add to gitignore

* allow to toggle embedding mode

* Clean-up GritLM sample code.

* Fix types.

* Flush stdout and output ending newline if streaming.

* mostly style fixes; correct KQ_mask comment

* add causal_attn flag to llama_cparams

* gritml : minor

* llama : minor

---------

Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-10 17:56:30 +02:00
Clint Herron 2960eae847 grammar : verify parsed state (#5950) 2024-03-10 17:17:43 +02:00
Georgi Gerganov c78541479c nix: update flake.lock (#5969)
Flake lock file updates:

• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29)
  → 'github:NixOS/nixpkgs/9df3e30ce24fd28c7b3e2de0d986769db5d6225d' (2024-03-06)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-03-10 16:43:08 +02:00
Pierrick Hymbert 621e86b331 server: benchmark: chat/completions scenario and other llm servers comparison (#5941)
* server: bench: Init a bench scenario with K6
See #5827

* server: bench: EOL EOF

* server: bench: PR feedback and improved k6 script configuration

* server: bench: remove llamacpp_completions_tokens_seconds as it include prompt processing time and it's misleading

server: bench: add max_tokens from SERVER_BENCH_MAX_TOKENS

server: bench: increase truncated rate to 80% before failing

* server: bench: fix doc

* server: bench: change gauge custom metrics to trend

* server: bench: change gauge custom metrics to trend
server: bench: add trend custom metrics for total tokens per second average

* server: bench: doc add an option to debug http request

* server: bench: filter dataset too short and too long sequences

* server: bench: allow to filter out conversation in the dataset based on env variable

* server: bench: fix assistant message sent instead of user message

* server: bench: fix assistant message sent instead of user message

* server : add defrag thold parameter

* server: bench: select prompts based on the current iteration id not randomly to make the bench more reproducible

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-09 23:41:49 +01:00
Georgi Gerganov 77d1ac7e00 server : print chat template info 2024-03-09 22:04:00 +02:00
12 changed files with 522 additions and 9 deletions
+1
View File
@@ -45,6 +45,7 @@ models-mnt
/embedding
/gguf
/gguf-llama-simple
/gritlm
/imatrix
/infill
/libllama.so
+5 -1
View File
@@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
# Binaries only useful for tests
TEST_TARGETS = \
@@ -724,6 +724,10 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
gritlm: examples/gritlm/gritlm.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+16
View File
@@ -278,6 +278,22 @@ namespace grammar_parser {
while (*pos) {
pos = parse_rule(state, pos);
}
// Validate the state to ensure that all rules are defined
for (const auto & rule : state.rules) {
for (const auto & elem : rule) {
if (elem.type == LLAMA_GRETYPE_RULE_REF) {
// Ensure that the rule at that location exists
if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
// Get the name of the rule that is missing
for (const auto & kv : state.symbol_ids) {
if (kv.second == elem.value) {
throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
}
}
}
}
}
}
return state;
} catch (const std::exception & err) {
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+1
View File
@@ -20,6 +20,7 @@ else()
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(embedding)
add_subdirectory(finetune)
add_subdirectory(gritlm)
add_subdirectory(infill)
add_subdirectory(llama-bench)
add_subdirectory(llava)
+5
View File
@@ -0,0 +1,5 @@
set(TARGET gritlm)
add_executable(${TARGET} gritlm.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
+229
View File
@@ -0,0 +1,229 @@
#include "common.h"
#include "llama.h"
#include <string>
#include <vector>
// #define GRIT_DEBUG
static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
float dot = 0.0f;
for (uint64_t i = 0; i < v1.size(); ++i) {
dot += v1[i] * v2[i];
}
return dot;
}
static float norm(const std::vector<float> & v) {
return std::sqrt(dot_product(v, v));
}
static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
return dot_product(v1, v2) / (norm(v1) * norm(v2));
}
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
std::vector<std::vector<float>> result;
const llama_model * mdl = llama_get_model(ctx);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
for (uint64_t i = 0; i < sentences.size(); i++) {
llama_batch_clear(batch);
const std::string input_string = instruction + sentences[i];
std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
const int32_t n_toks = inputs.size();
// GritLM seems to have EOS = ""
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
// inputs.push_back(llama_token_eos(mdl));
// we want to ignore instruction tokens for mean pooling
const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
#ifdef GRIT_DEBUG
// debug tokens - should be matching as referenced in the GritLM sample
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
});
std::printf("\n");
#endif
// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
}
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_set_causal_attn(ctx, false);
// run model
llama_decode(ctx, batch);
// get embedding dimensions
uint64_t n_embd = llama_n_embd(mdl);
// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);
// sum up all token embeddings
for (int32_t k = n_inst; k < n_toks; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j];
}
}
// divide by number of tokens (mean pooling)
{
const uint64_t n_sent = n_toks - n_inst;
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] /= n_sent;
}
}
std::vector<float> emb_norm(emb_unorm.size());
llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
result.push_back(emb_norm);
#ifdef GRIT_DEBUG
// print out emb_norm
std::printf("embedding %ld: ", i);
for (uint64_t j = 0; j < n_embd; j++) {
std::printf("%.5f ", emb_norm[j]);
}
std::printf("\n\n");
#endif
}
llama_batch_free(batch);
return result;
}
static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
std::string result;
const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl);
llama_kv_cache_clear(ctx);
llama_set_causal_attn(ctx, true);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
int32_t i_current_token = 0;
while (true) {
llama_batch_clear(bat);
auto n_inputs = (int32_t)inputs.size();
for (int32_t i = 0; i < n_inputs; i++) {
llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
}
inputs.clear();
llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
auto n_candidates = (int32_t)candidates.size();
for (int32_t token = 0; token < n_candidates; token++) {
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
}
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
if (token == eos_token) {
break;
}
std::string piece = llama_token_to_piece(ctx, token);
if (stream) {
std::printf("%s", piece.c_str());
std::fflush(stdout);
}
inputs.push_back(token);
result += piece;
}
if (stream) {
std::printf("\n");
}
llama_batch_free(bat);
return result;
}
static std::string gritlm_instruction(const std::string & instruction) {
return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n";
}
int main(int argc, char * argv[]) {
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
llama_model_params mparams = llama_model_params_from_gpt_params(params);
llama_context_params cparams = llama_context_params_from_gpt_params(params);
llama_backend_init();
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
// create new context - set to embedding mode
cparams.embeddings = true;
llama_context * ctx = llama_new_context_with_model(mdl, cparams);
// ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic
{
const std::string instruction = "Given a scientific paper title, retrieve the paper's abstract";
const std::vector<std::string> queries = {
"Bitcoin: A Peer-to-Peer Electronic Cash System",
"Generative Representational Instruction Tuning",
};
const std::vector<std::string> documents = {
"A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
"All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
};
// No need to add instruction for retrieval documents
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
const float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
const float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
const float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
const float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[0].c_str(), cosine_sim_q1_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}
// ### Generation ###
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
std::string response = generate(ctx, prompt, true);
}
llama_free(ctx);
llama_free_model(mdl);
llama_backend_free();
return 0;
}
+88
View File
@@ -0,0 +1,88 @@
### Server benchmark tools
Benchmark is using [k6](https://k6.io/).
##### Install k6
Follow instruction from: https://k6.io/docs/get-started/installation/
Example for ubuntu:
```shell
snap install k6
```
#### Download a dataset
This dataset was originally proposed in [vLLM benchmarks](https://github.com/vllm-project/vllm/blob/main/benchmarks/README.md).
```shell
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
#### Download a model
Example for PHI-2
```shell
../../../scripts/hf.sh --repo ggml-org/models --file phi-2/ggml-model-q4_0.gguf
```
#### Start the server
The server must answer OAI Chat completion requests on `http://localhost:8080/v1` or according to the environment variable `SERVER_BENCH_URL`.
Example:
```shell
server --host localhost --port 8080 \
--model ggml-model-q4_0.gguf \
--cont-batching \
--metrics \
--parallel 8 \
--batch-size 512 \
--ctx-size 4096 \
--log-format text \
-ngl 33
```
#### Run the benchmark
For 500 chat completions request with 8 concurrent users during maximum 10 minutes, run:
```shell
k6 run script.js --duration 10m --iterations 500 --vus 8
```
The benchmark values can be overridden with:
- `SERVER_BENCH_URL` server url prefix for chat completions, default `http://localhost:8080/v1`
- `SERVER_BENCH_N_PROMPTS` total prompts to randomly select in the benchmark, default `480`
- `SERVER_BENCH_MODEL_ALIAS` model alias to pass in the completion request, default `my-model`
- `SERVER_BENCH_MAX_TOKENS` max tokens to predict, default: `512`
- `SERVER_BENCH_DATASET` path to the benchmark dataset file
- `SERVER_BENCH_MAX_PROMPT_TOKENS` maximum prompt tokens to filter out in the dataset: default `1024`
- `SERVER_BENCH_MAX_CONTEXT` maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens, default `2048`
Note: the local tokenizer is just a string space split, real number of tokens will differ.
Or with [k6 options](https://k6.io/docs/using-k6/k6-options/reference/):
```shell
SERVER_BENCH_N_PROMPTS=500 k6 run script.js --duration 10m --iterations 500 --vus 8
```
To [debug http request](https://k6.io/docs/using-k6/http-debugging/) use `--http-debug="full"`.
#### Metrics
Following metrics are available computed from the OAI chat completions response `usage`:
- `llamacpp_tokens_second` Trend of `usage.total_tokens / request duration`
- `llamacpp_prompt_tokens` Trend of `usage.prompt_tokens`
- `llamacpp_prompt_tokens_total_counter` Counter of `usage.prompt_tokens`
- `llamacpp_completion_tokens` Trend of `usage.completion_tokens`
- `llamacpp_completion_tokens_total_counter` Counter of `usage.completion_tokens`
- `llamacpp_completions_truncated_rate` Rate of completions truncated, i.e. if `finish_reason === 'length'`
- `llamacpp_completions_stop_rate` Rate of completions stopped by the model, i.e. if `finish_reason === 'stop'`
The script will fail if too many completions are truncated, see `llamacpp_completions_truncated_rate`.
K6 metrics might be compared against [server metrics](../README.md), with:
```shell
curl http://localhost:8080/metrics
```
+120
View File
@@ -0,0 +1,120 @@
import http from 'k6/http'
import {check, sleep} from 'k6'
import {SharedArray} from 'k6/data'
import {Counter, Rate, Trend} from 'k6/metrics'
import exec from 'k6/execution';
// Server chat completions prefix
const server_url = __ENV.SERVER_BENCH_URL ? __ENV.SERVER_BENCH_URL : 'http://localhost:8080/v1'
// Number of total prompts in the dataset - default 10m / 10 seconds/request * number of users
const n_prompt = __ENV.SERVER_BENCH_N_PROMPTS ? parseInt(__ENV.SERVER_BENCH_N_PROMPTS) : 600 / 10 * 8
// Model name to request
const model = __ENV.SERVER_BENCH_MODEL_ALIAS ? __ENV.SERVER_BENCH_MODEL_ALIAS : 'my-model'
// Dataset path
const dataset_path = __ENV.SERVER_BENCH_DATASET ? __ENV.SERVER_BENCH_DATASET : './ShareGPT_V3_unfiltered_cleaned_split.json'
// Max tokens to predict
const max_tokens = __ENV.SERVER_BENCH_MAX_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_TOKENS) : 512
// Max prompt tokens
const n_prompt_tokens = __ENV.SERVER_BENCH_MAX_PROMPT_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_PROMPT_TOKENS) : 1024
// Max slot context
const n_ctx_slot = __ENV.SERVER_BENCH_MAX_CONTEXT ? parseInt(__ENV.SERVER_BENCH_MAX_CONTEXT) : 2048
export function setup() {
console.info(`Benchmark config: server_url=${server_url} n_prompt=${n_prompt} model=${model} dataset_path=${dataset_path} max_tokens=${max_tokens}`)
}
const data = new SharedArray('conversations', function () {
const tokenizer = (message) => message.split(/[\s,'".?]/)
return JSON.parse(open(dataset_path))
// Filter out the conversations with less than 2 turns.
.filter(data => data["conversations"].length >= 2)
.filter(data => data["conversations"][0]["from"] === "human")
.map(data => {
return {
prompt: data["conversations"][0]["value"],
n_prompt_tokens: tokenizer(data["conversations"][0]["value"]).length,
n_completion_tokens: tokenizer(data["conversations"][1]["value"]).length,
}
})
// Filter out too short sequences
.filter(conv => conv.n_prompt_tokens >= 4 && conv.n_completion_tokens >= 4)
// Filter out too long sequences.
.filter(conv => conv.n_prompt_tokens <= n_prompt_tokens && conv.n_prompt_tokens + conv.n_completion_tokens <= n_ctx_slot)
// Keep only first n prompts
.slice(0, n_prompt)
})
const llamacpp_prompt_tokens = new Trend('llamacpp_prompt_tokens')
const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens')
const llamacpp_tokens_second = new Trend('llamacpp_tokens_second')
const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter')
const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter')
const llamacpp_completions_truncated_rate = new Rate('llamacpp_completions_truncated_rate')
const llamacpp_completions_stop_rate = new Rate('llamacpp_completions_stop_rate')
export const options = {
thresholds: {
llamacpp_completions_truncated_rate: [
// more than 80% of truncated input will abort the test
{threshold: 'rate < 0.8', abortOnFail: true, delayAbortEval: '1m'},
],
},
duration: '10m',
vus: 8,
}
export default function () {
const conversation = data[exec.scenario.iterationInInstance % data.length]
const payload = {
"messages": [
{
"role": "system",
"content": "You are ChatGPT, an AI assistant.",
},
{
"role": "user",
"content": conversation.prompt,
}
],
"model": model,
"stream": false,
"max_tokens": max_tokens
}
const body = JSON.stringify(payload)
let res = http.post(`${server_url}/chat/completions`, body, {
headers: {'Content-Type': 'application/json'},
timeout: '300s'
})
check(res, {'success completion': (r) => r.status === 200})
if (res.status === 200) {
const completions = res.json()
llamacpp_prompt_tokens.add(completions.usage.prompt_tokens)
llamacpp_prompt_tokens_total_counter.add(completions.usage.prompt_tokens)
llamacpp_completion_tokens.add(completions.usage.completion_tokens)
llamacpp_completion_tokens_total_counter.add(completions.usage.completion_tokens)
llamacpp_completions_truncated_rate.add(completions.choices[0].finish_reason === 'length')
llamacpp_completions_stop_rate.add(completions.choices[0].finish_reason === 'stop')
llamacpp_tokens_second.add(completions.usage.total_tokens / res.timings.duration * 1.e3)
} else {
console.error(`response: ${res.body} request=${payload}`)
}
sleep(0.3)
}
+28 -2
View File
@@ -2133,6 +2133,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --pooling {none,mean,cls} pooling type for embeddings, use model default if unspecified\n");
printf(" -dt N, --defrag-thold N\n");
printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2197,7 +2199,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`\n");
printf(" --chat-template JINJA_TEMPLATE\n");
printf(" set custom jinja chat template (default: template taken from model's metadata)\n");
printf(" Note: only commonly used templates are accepted, since we don't have jinja parser\n");
printf(" only commonly used templates are accepted:\n");
printf(" https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template\n");
printf("\n");
}
@@ -2354,6 +2357,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else { invalid_param = true; break; }
} else if (arg == "--defrag-thold" || arg == "-dt") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.defrag_thold = std::stof(argv[i]);
} else if (arg == "--threads" || arg == "-t") {
if (++i >= argc)
{
@@ -2798,13 +2807,30 @@ int main(int argc, char ** argv) {
const auto model_meta = ctx_server.model_meta();
if (sparams.chat_template.empty()) { // custom chat template is not supplied
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (sparams.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "chatml";
}
}
// print sample chat example to make it clear which template is used
{
json chat;
chat.push_back({{"role", "system"}, {"content", "You are a helpful assistant"}});
chat.push_back({{"role", "user"}, {"content", "Hello"}});
chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
chat.push_back({{"role", "user"}, {"content", "How are you?"}});
const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
LOG_INFO("chat template", {
{"chat_example", chat_example},
{"built_in", sparams.chat_template.empty()},
});
}
//
// Middlewares
//
Generated
+3 -3
View File
@@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1709237383,
"narHash": "sha256-cy6ArO4k5qTx+l5o+0mL9f5fa86tYUX3ozE1S+Txlds=",
"lastModified": 1709703039,
"narHash": "sha256-6hqgQ8OK6gsMu1VtcGKBxKQInRLHtzulDo9Z5jxHEFY=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "1536926ef5621b09bba54035ae2bb6d806d72ac8",
"rev": "9df3e30ce24fd28c7b3e2de0d986769db5d6225d",
"type": "github"
},
"original": {
+22 -3
View File
@@ -1744,6 +1744,7 @@ struct llama_cparams {
float defrag_thold;
bool embeddings;
bool causal_attn;
bool offload_kqv;
enum llama_pooling_type pooling_type;
@@ -3939,6 +3940,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
LLAMA_LOG_INFO("%s: causal attm = %d\n", __func__, hparams.causal_attn);
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
@@ -8532,7 +8534,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
if (hparams.causal_attn) {
GGML_ASSERT(
(hparams.causal_attn || !cparams.causal_attn) &&
"non-causal attention with generative models is not supported"
);
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;
@@ -8560,8 +8568,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
} else {
// non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
// when using kv cache, the mask needs to match the kv cache size
const int64_t n_tokens = batch.n_tokens;
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
@@ -8580,7 +8589,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
}
}
}
@@ -12733,6 +12746,8 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}
cparams.causal_attn = hparams.causal_attn;
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -13767,6 +13782,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
ctx->abort_callback_data = abort_callback_data;
}
void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
ctx->cparams.causal_attn = causal_attn;
}
struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens,
+4
View File
@@ -643,6 +643,10 @@ extern "C" {
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
// Set whether to use causal attention or not
// If set to true, the model will only attend to the past tokens
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);