mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 10:07:44 +02:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4a2751258a | |||
| cc5cafecf4 | |||
| aef22e7afc | |||
| 9ceb268ee1 | |||
| a4854f0349 | |||
| f2d988db55 | |||
| 91fd50be1b | |||
| 53eb9435da | |||
| d3435efc8a | |||
| 439c3b5021 | |||
| 59dda88aae | |||
| f5f8812f7c | |||
| 8ece3836b4 | |||
| 046d5fd44e | |||
| 480160d472 | |||
| d7c27d4964 | |||
| a9d7bcb7fc |
+116
-51
@@ -6,6 +6,7 @@
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "download.h"
|
||||
#include "preset.h"
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
#if defined(_WIN32)
|
||||
@@ -268,6 +269,46 @@ static void parse_tensor_buffer_overrides(const std::string & value, std::vector
|
||||
}
|
||||
}
|
||||
|
||||
static std::string clean_file_name(const std::string & fname) {
|
||||
std::string clean_fname = fname;
|
||||
string_replace_all(clean_fname, "\\", "_");
|
||||
string_replace_all(clean_fname, "/", "_");
|
||||
return clean_fname;
|
||||
}
|
||||
|
||||
static bool common_params_handle_remote_preset(common_params & params, llama_example ex) {
|
||||
GGML_ASSERT(!params.model.hf_repo.empty());
|
||||
|
||||
const bool offline = params.offline;
|
||||
std::string model_endpoint = get_model_endpoint();
|
||||
auto preset_url = model_endpoint + params.model.hf_repo + "/resolve/main/preset.ini";
|
||||
|
||||
// prepare local path for caching
|
||||
auto preset_fname = clean_file_name(params.model.hf_repo + "_preset.ini");
|
||||
auto preset_path = fs_get_cache_file(preset_fname);
|
||||
const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline);
|
||||
const bool has_preset = status >= 200 && status < 400;
|
||||
|
||||
// remote preset is optional, so we don't error out if not found
|
||||
if (has_preset) {
|
||||
LOG_INF("applying remote preset from %s\n", preset_url.c_str());
|
||||
common_preset_context ctx(ex, /* only_remote_allowed */ true);
|
||||
common_preset global; // unused for now
|
||||
auto remote_presets = ctx.load_from_ini(preset_path, global);
|
||||
if (remote_presets.find(COMMON_PRESET_DEFAULT_NAME) != remote_presets.end()) {
|
||||
common_preset & preset = remote_presets.at(COMMON_PRESET_DEFAULT_NAME);
|
||||
LOG_INF("\n%s", preset.to_ini().c_str()); // to_ini already added trailing newline
|
||||
preset.apply_to_params(params);
|
||||
} else {
|
||||
throw std::runtime_error("Remote preset.ini does not contain [" + std::string(COMMON_PRESET_DEFAULT_NAME) + "] section");
|
||||
}
|
||||
} else {
|
||||
LOG_INF("%s", "no remote preset found, skipping\n");
|
||||
}
|
||||
|
||||
return has_preset;
|
||||
}
|
||||
|
||||
struct handle_model_result {
|
||||
bool found_mmproj = false;
|
||||
common_params_model mmproj;
|
||||
@@ -309,9 +350,7 @@ static handle_model_result common_params_handle_model(
|
||||
// make sure model path is present (for caching purposes)
|
||||
if (model.path.empty()) {
|
||||
// this is to avoid different repo having same file name, or same file name in different subdirs
|
||||
std::string filename = model.hf_repo + "_" + model.hf_file;
|
||||
// to make sure we don't have any slashes in the filename
|
||||
string_replace_all(filename, "/", "_");
|
||||
std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file);
|
||||
model.path = fs_get_cache_file(filename);
|
||||
}
|
||||
|
||||
@@ -425,61 +464,87 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
}
|
||||
};
|
||||
|
||||
std::set<std::string> seen_args;
|
||||
auto parse_cli_args = [&]() {
|
||||
std::set<std::string> seen_args;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
const std::string arg_prefix = "--";
|
||||
for (int i = 1; i < argc; i++) {
|
||||
const std::string arg_prefix = "--";
|
||||
|
||||
std::string arg = argv[i];
|
||||
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
|
||||
std::replace(arg.begin(), arg.end(), '_', '-');
|
||||
}
|
||||
if (arg_to_options.find(arg) == arg_to_options.end()) {
|
||||
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
|
||||
}
|
||||
if (!seen_args.insert(arg).second) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
}
|
||||
auto & tmp = arg_to_options[arg];
|
||||
auto opt = *tmp.first;
|
||||
bool is_positive = tmp.second;
|
||||
if (opt.has_value_from_env()) {
|
||||
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
|
||||
}
|
||||
try {
|
||||
if (opt.handler_void) {
|
||||
opt.handler_void(params);
|
||||
continue;
|
||||
std::string arg = argv[i];
|
||||
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
|
||||
std::replace(arg.begin(), arg.end(), '_', '-');
|
||||
}
|
||||
if (opt.handler_bool) {
|
||||
opt.handler_bool(params, is_positive);
|
||||
continue;
|
||||
if (arg_to_options.find(arg) == arg_to_options.end()) {
|
||||
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
|
||||
}
|
||||
if (!seen_args.insert(arg).second) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
}
|
||||
auto & tmp = arg_to_options[arg];
|
||||
auto opt = *tmp.first;
|
||||
bool is_positive = tmp.second;
|
||||
if (opt.has_value_from_env()) {
|
||||
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
|
||||
}
|
||||
try {
|
||||
if (opt.handler_void) {
|
||||
opt.handler_void(params);
|
||||
continue;
|
||||
}
|
||||
if (opt.handler_bool) {
|
||||
opt.handler_bool(params, is_positive);
|
||||
continue;
|
||||
}
|
||||
|
||||
// arg with single value
|
||||
check_arg(i);
|
||||
std::string val = argv[++i];
|
||||
if (opt.handler_int) {
|
||||
opt.handler_int(params, std::stoi(val));
|
||||
continue;
|
||||
}
|
||||
if (opt.handler_string) {
|
||||
opt.handler_string(params, val);
|
||||
continue;
|
||||
}
|
||||
// arg with single value
|
||||
check_arg(i);
|
||||
std::string val = argv[++i];
|
||||
if (opt.handler_int) {
|
||||
opt.handler_int(params, std::stoi(val));
|
||||
continue;
|
||||
}
|
||||
if (opt.handler_string) {
|
||||
opt.handler_string(params, val);
|
||||
continue;
|
||||
}
|
||||
|
||||
// arg with 2 values
|
||||
check_arg(i);
|
||||
std::string val2 = argv[++i];
|
||||
if (opt.handler_str_str) {
|
||||
opt.handler_str_str(params, val, val2);
|
||||
continue;
|
||||
// arg with 2 values
|
||||
check_arg(i);
|
||||
std::string val2 = argv[++i];
|
||||
if (opt.handler_str_str) {
|
||||
opt.handler_str_str(params, val, val2);
|
||||
continue;
|
||||
}
|
||||
} catch (std::exception & e) {
|
||||
throw std::invalid_argument(string_format(
|
||||
"error while handling argument \"%s\": %s\n\n"
|
||||
"usage:\n%s\n\nto show complete usage, run with -h",
|
||||
arg.c_str(), e.what(), opt.to_string().c_str()));
|
||||
}
|
||||
} catch (std::exception & e) {
|
||||
throw std::invalid_argument(string_format(
|
||||
"error while handling argument \"%s\": %s\n\n"
|
||||
"usage:\n%s\n\nto show complete usage, run with -h",
|
||||
arg.c_str(), e.what(), opt.to_string().c_str()));
|
||||
}
|
||||
};
|
||||
|
||||
// parse the first time to get -hf option (used for remote preset)
|
||||
parse_cli_args();
|
||||
|
||||
// maybe handle remote preset
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
std::string cli_hf_repo = params.model.hf_repo;
|
||||
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
|
||||
|
||||
// special case: if hf_repo explicitly set by preset, we need to preserve it (ignore CLI value)
|
||||
// this is useful when we have one HF repo pointing to other HF repos (one model - multiple GGUFs)
|
||||
std::string preset_hf_repo = params.model.hf_repo;
|
||||
bool preset_has_hf_repo = preset_hf_repo != cli_hf_repo;
|
||||
|
||||
if (has_preset) {
|
||||
// re-parse CLI args to override preset values
|
||||
parse_cli_args();
|
||||
}
|
||||
|
||||
// preserve hf_repo from preset if needed
|
||||
if (preset_has_hf_repo) {
|
||||
params.model.hf_repo = preset_hf_repo;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+54
-29
@@ -157,6 +157,10 @@ static std::string read_etag(const std::string & path) {
|
||||
return none;
|
||||
}
|
||||
|
||||
static bool is_http_status_ok(int status) {
|
||||
return status >= 200 && status < 400;
|
||||
}
|
||||
|
||||
#ifdef LLAMA_USE_CURL
|
||||
|
||||
//
|
||||
@@ -306,12 +310,14 @@ static bool common_download_head(CURL * curl,
|
||||
}
|
||||
|
||||
// download one single file from remote URL to local path
|
||||
static bool common_download_file_single_online(const std::string & url,
|
||||
// returns status code or -1 on error
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
std::string etag;
|
||||
|
||||
@@ -371,7 +377,7 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -380,14 +386,14 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
if (std::filesystem::exists(path_temporary)) {
|
||||
if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (std::filesystem::exists(path)) {
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -414,23 +420,27 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
|
||||
long http_code = 0;
|
||||
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
||||
if (http_code < 200 || http_code >= 400) {
|
||||
|
||||
int status = static_cast<int>(http_code);
|
||||
if (!is_http_status_ok(http_code)) {
|
||||
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
|
||||
return false;
|
||||
return status; // TODO: maybe only return on certain codes
|
||||
}
|
||||
|
||||
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return static_cast<int>(http_code);
|
||||
} else {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
break;
|
||||
return 304; // Not Modified - fake cached response
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return -1; // max attempts reached
|
||||
}
|
||||
|
||||
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
|
||||
@@ -625,7 +635,8 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
}
|
||||
|
||||
// download one single file from remote URL to local path
|
||||
static bool common_download_file_single_online(const std::string & url,
|
||||
// returns status code or -1 on error
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
@@ -659,8 +670,10 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
|
||||
if (file_exists) {
|
||||
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
|
||||
return true;
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
return head->status; // cannot use cached file, return raw status code
|
||||
// TODO: maybe retry only on certain codes
|
||||
}
|
||||
|
||||
std::string etag;
|
||||
@@ -692,12 +705,12 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
if (file_exists) {
|
||||
if (!should_download_from_scratch) {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
return true;
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -709,7 +722,7 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
existing_size = std::filesystem::file_size(path_temporary);
|
||||
} else if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -730,15 +743,16 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
if (!etag.empty()) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
break;
|
||||
|
||||
return head->status; // TODO: use actual GET status?
|
||||
}
|
||||
|
||||
return true;
|
||||
return -1; // max attempts reached
|
||||
}
|
||||
|
||||
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
|
||||
@@ -777,22 +791,22 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
|
||||
|
||||
#if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
static bool common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers) {
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers) {
|
||||
if (!offline) {
|
||||
return common_download_file_single_online(url, path, bearer_token, headers);
|
||||
}
|
||||
|
||||
if (!std::filesystem::exists(path)) {
|
||||
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
|
||||
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
|
||||
return true;
|
||||
return 304; // Not Modified - fake cached response
|
||||
}
|
||||
|
||||
// download multiple files from remote URLs to local paths
|
||||
@@ -810,7 +824,8 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
|
||||
std::async(
|
||||
std::launch::async,
|
||||
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
|
||||
return common_download_file_single(it.first, it.second, bearer_token, offline, headers);
|
||||
const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
|
||||
return is_http_status_ok(http_status);
|
||||
},
|
||||
item
|
||||
)
|
||||
@@ -837,7 +852,8 @@ bool common_download_model(const common_params_model & model,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!common_download_file_single(model.url, model.path, bearer_token, offline, headers)) {
|
||||
const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
|
||||
if (!is_http_status_ok(http_status)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -975,7 +991,7 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
|
||||
} else if (res_code == 401) {
|
||||
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
|
||||
} else {
|
||||
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
|
||||
throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
|
||||
}
|
||||
|
||||
// check response
|
||||
@@ -1094,7 +1110,8 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
||||
std::string local_path = fs_get_cache_file(model_filename);
|
||||
|
||||
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
|
||||
if (!common_download_file_single(blob_url, local_path, token, false, {})) {
|
||||
const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
|
||||
if (!is_http_status_ok(http_status)) {
|
||||
throw std::runtime_error("Failed to download Docker Model");
|
||||
}
|
||||
|
||||
@@ -1120,6 +1137,14 @@ std::string common_docker_resolve_model(const std::string &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
int common_download_file_single(const std::string &,
|
||||
const std::string &,
|
||||
const std::string &,
|
||||
bool,
|
||||
const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
#endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
|
||||
|
||||
std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
|
||||
@@ -65,6 +65,14 @@ bool common_download_model(
|
||||
// returns list of cached models
|
||||
std::vector<common_cached_model_info> common_list_cached_models();
|
||||
|
||||
// download single file from url to local path
|
||||
// returns status code or -1 on error
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers = {});
|
||||
|
||||
// resolve and download model from Docker registry
|
||||
// return local path to downloaded model file
|
||||
std::string common_docker_resolve_model(const std::string & docker);
|
||||
|
||||
+76
-1
@@ -16,6 +16,46 @@ static std::string rm_leading_dashes(const std::string & str) {
|
||||
return str.substr(pos);
|
||||
}
|
||||
|
||||
// only allow a subset of args for remote presets for security reasons
|
||||
// do not add more args unless absolutely necessary
|
||||
// args that output to files are strictly prohibited
|
||||
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
|
||||
static const std::set<std::string> allowed_options = {
|
||||
"model-url",
|
||||
"hf-repo",
|
||||
"hf-repo-draft",
|
||||
"hf-repo-v", // vocoder
|
||||
"hf-file-v", // vocoder
|
||||
"mmproj-url",
|
||||
"pooling",
|
||||
"jinja",
|
||||
"batch-size",
|
||||
"ubatch-size",
|
||||
"cache-reuse",
|
||||
// note: sampling params are automatically allowed by default
|
||||
// negated args will be added automatically
|
||||
};
|
||||
|
||||
std::set<std::string> allowed_keys;
|
||||
|
||||
for (const auto & it : key_to_opt) {
|
||||
const std::string & key = it.first;
|
||||
const common_arg & opt = it.second;
|
||||
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
|
||||
allowed_keys.insert(key);
|
||||
// also add variant keys (args without leading dashes and env vars)
|
||||
for (const auto & arg : opt.get_args()) {
|
||||
allowed_keys.insert(rm_leading_dashes(arg));
|
||||
}
|
||||
for (const auto & env : opt.get_env()) {
|
||||
allowed_keys.insert(env);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allowed_keys;
|
||||
}
|
||||
|
||||
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
|
||||
std::vector<std::string> args;
|
||||
|
||||
@@ -121,6 +161,29 @@ void common_preset::merge(const common_preset & other) {
|
||||
}
|
||||
}
|
||||
|
||||
void common_preset::apply_to_params(common_params & params) const {
|
||||
for (const auto & [opt, val] : options) {
|
||||
// apply each option to params
|
||||
if (opt.handler_string) {
|
||||
opt.handler_string(params, val);
|
||||
} else if (opt.handler_int) {
|
||||
opt.handler_int(params, std::stoi(val));
|
||||
} else if (opt.handler_bool) {
|
||||
opt.handler_bool(params, common_arg_utils::is_truthy(val));
|
||||
} else if (opt.handler_str_str) {
|
||||
// not supported yet
|
||||
throw std::runtime_error(string_format(
|
||||
"%s: option with two values is not supported yet",
|
||||
__func__
|
||||
));
|
||||
} else if (opt.handler_void) {
|
||||
opt.handler_void(params);
|
||||
} else {
|
||||
GGML_ABORT("unknown handler type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
|
||||
std::map<std::string, std::map<std::string, std::string>> parsed;
|
||||
|
||||
@@ -230,10 +293,16 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
|
||||
return value;
|
||||
}
|
||||
|
||||
common_preset_context::common_preset_context(llama_example ex)
|
||||
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
|
||||
: ctx_params(common_params_parser_init(default_params, ex)) {
|
||||
common_params_add_preset_options(ctx_params.options);
|
||||
key_to_opt = get_map_key_opt(ctx_params);
|
||||
|
||||
// setup allowed keys if only_remote_allowed is true
|
||||
if (only_remote_allowed) {
|
||||
filter_allowed_keys = true;
|
||||
allowed_keys = get_remote_preset_whitelist(key_to_opt);
|
||||
}
|
||||
}
|
||||
|
||||
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
|
||||
@@ -250,6 +319,12 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
|
||||
LOG_DBG("loading preset: %s\n", preset.name.c_str());
|
||||
for (const auto & [key, value] : section.second) {
|
||||
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
|
||||
if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
|
||||
throw std::runtime_error(string_format(
|
||||
"option '%s' is not allowed in remote presets",
|
||||
key.c_str()
|
||||
));
|
||||
}
|
||||
if (key_to_opt.find(key) != key_to_opt.end()) {
|
||||
const auto & opt = key_to_opt.at(key);
|
||||
if (is_bool_arg(opt)) {
|
||||
|
||||
+10
-1
@@ -6,6 +6,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
//
|
||||
// INI preset parser and writer
|
||||
@@ -40,6 +41,9 @@ struct common_preset {
|
||||
|
||||
// merge another preset into this one, overwriting existing options
|
||||
void merge(const common_preset & other);
|
||||
|
||||
// apply preset options to common_params
|
||||
void apply_to_params(common_params & params) const;
|
||||
};
|
||||
|
||||
// interface for multiple presets in one file
|
||||
@@ -50,7 +54,12 @@ struct common_preset_context {
|
||||
common_params default_params; // unused for now
|
||||
common_params_context ctx_params;
|
||||
std::map<std::string, common_arg> key_to_opt;
|
||||
common_preset_context(llama_example ex);
|
||||
|
||||
bool filter_allowed_keys = false;
|
||||
std::set<std::string> allowed_keys;
|
||||
|
||||
// if only_remote_allowed is true, only accept whitelisted keys
|
||||
common_preset_context(llama_example ex, bool only_remote_allowed = false);
|
||||
|
||||
// load presets from INI file
|
||||
common_presets load_from_ini(const std::string & path, common_preset & global) const;
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
# llama.cpp INI Presets
|
||||
|
||||
## Introduction
|
||||
|
||||
The INI preset feature, introduced in [PR#17859](https://github.com/ggml-org/llama.cpp/pull/17859), allows users to create reusable and shareable parameter configurations for llama.cpp.
|
||||
|
||||
### Using Presets with the Server
|
||||
|
||||
When running multiple models on the server (router mode), INI preset files can be used to configure model-specific parameters. Please refer to the [server documentation](../tools/server/README.md) for more details.
|
||||
|
||||
### Using a Remote Preset
|
||||
|
||||
> [!NOTE]
|
||||
>
|
||||
> This feature is currently only supported via the `-hf` option.
|
||||
|
||||
For GGUF models hosted on Hugging Face, you can include a `preset.ini` file in the root directory of the repository to define specific configurations for that model.
|
||||
|
||||
Example:
|
||||
|
||||
```ini
|
||||
hf-repo-draft = username/my-draft-model-GGUF
|
||||
temp = 0.5
|
||||
top-k = 20
|
||||
top-p = 0.95
|
||||
```
|
||||
|
||||
For security reasons, only certain options are allowed. Please refer to [preset.cpp](../common/preset.cpp) for the complete list of permitted options.
|
||||
|
||||
Example usage:
|
||||
|
||||
Assuming your repository `username/my-model-with-preset` contains a `preset.ini` with the configuration above:
|
||||
|
||||
```sh
|
||||
llama-cli -hf username/my-model-with-preset
|
||||
|
||||
# This is equivalent to:
|
||||
llama-cli -hf username/my-model-with-preset \
|
||||
--hf-repo-draft username/my-draft-model-GGUF \
|
||||
--temp 0.5 \
|
||||
--top-k 20 \
|
||||
--top-p 0.95
|
||||
```
|
||||
|
||||
You can also override preset arguments by specifying them on the command line:
|
||||
|
||||
```sh
|
||||
# Force temp = 0.1, overriding the preset value
|
||||
llama-cli -hf username/my-model-with-preset --temp 0.1
|
||||
```
|
||||
|
||||
If you want to define multiple preset configurations for one or more GGUF models, you can create a blank HF repo for each preset. Each HF repo should contain a `preset.ini` file that references the actual model(s):
|
||||
|
||||
```ini
|
||||
hf-repo = user/my-model-main
|
||||
hf-repo-draft = user/my-model-draft
|
||||
temp = 0.8
|
||||
ctx-size = 1024
|
||||
; (and other configurations)
|
||||
```
|
||||
@@ -234,6 +234,11 @@
|
||||
|
||||
#if UINTPTR_MAX == 0xFFFFFFFF
|
||||
#define GGML_MEM_ALIGN 4
|
||||
#elif defined(__EMSCRIPTEN__)
|
||||
// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
|
||||
// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18628
|
||||
#define GGML_MEM_ALIGN 8
|
||||
#else
|
||||
#define GGML_MEM_ALIGN 16
|
||||
#endif
|
||||
|
||||
@@ -144,7 +144,7 @@ extern "C" {
|
||||
// device description: short informative description of the device, could be the model name
|
||||
const char * (*get_description)(ggml_backend_dev_t dev);
|
||||
|
||||
// device memory in bytes
|
||||
// device memory in bytes: 0 bytes to indicate no memory to report
|
||||
void (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total);
|
||||
|
||||
// device type
|
||||
|
||||
@@ -4287,8 +4287,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_
|
||||
}
|
||||
|
||||
static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
*free = 1;
|
||||
*total = 1;
|
||||
*free = 0;
|
||||
*total = 0;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
+3
-1
@@ -1292,7 +1292,9 @@ extern "C" {
|
||||
// available samplers:
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
||||
|
||||
/// seed == LLAMA_DEFAULT_SEED to use a random seed.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
/// Setting k <= 0 makes this a noop
|
||||
|
||||
+20
-8
@@ -4,12 +4,13 @@
|
||||
#
|
||||
# - creates a new remote using the fork's clone URL
|
||||
# - creates a local branch tracking the remote branch
|
||||
# - creates a new worktree in a parent folder, suffixed with "-pr-${PR}"
|
||||
# - creates a new worktree in a parent folder, suffixed with "-pr-$PR"
|
||||
#
|
||||
# sample usage:
|
||||
# ./scripts/pr2wt.sh 12345
|
||||
# ./scripts/pr2wt.sh 12345 opencode
|
||||
# ./scripts/pr2wt.sh 12345 "cmake -B build && cmake --build build"
|
||||
# ./scripts/pr2wt.sh 12345 "bash -l"
|
||||
|
||||
function usage() {
|
||||
echo "usage: $0 <pr_number> [cmd]"
|
||||
@@ -39,7 +40,7 @@ org_repo=${org_repo%.git}
|
||||
|
||||
echo "org/repo: $org_repo"
|
||||
|
||||
meta=$(curl -sSf -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${org_repo}/pulls/${PR}")
|
||||
meta=$(curl -sSf -H "Accept: application/vnd.github+json" "https://api.github.com/repos/$org_repo/pulls/$PR")
|
||||
|
||||
url_remote=$(echo "$meta" | jq -r '.head.repo.clone_url')
|
||||
head_ref=$(echo "$meta" | jq -r '.head.ref')
|
||||
@@ -47,21 +48,32 @@ head_ref=$(echo "$meta" | jq -r '.head.ref')
|
||||
echo "url: $url_remote"
|
||||
echo "head_ref: $head_ref"
|
||||
|
||||
git remote rm pr/${PR} 2> /dev/null
|
||||
git remote add pr/${PR} $url_remote
|
||||
git fetch pr/${PR} $head_ref
|
||||
url_remote_cur=$(git config --get "remote.pr/$PR.url" 2>/dev/null || true)
|
||||
|
||||
if [[ "$url_remote_cur" != "$url_remote" ]]; then
|
||||
git remote rm pr/$PR 2> /dev/null
|
||||
git remote add pr/$PR "$url_remote"
|
||||
fi
|
||||
|
||||
git fetch "pr/$PR" "$head_ref"
|
||||
|
||||
dir=$(basename $(pwd))
|
||||
|
||||
git branch -D pr/$PR 2> /dev/null
|
||||
git worktree add -b pr/$PR ../$dir-pr-$PR pr/$PR/${head_ref} 2> /dev/null
|
||||
git worktree add -b pr/$PR ../$dir-pr-$PR pr/$PR/$head_ref 2> /dev/null
|
||||
|
||||
wt_path=$(cd ../$dir-pr-$PR && pwd)
|
||||
|
||||
echo "git worktree created in $wt_path"
|
||||
|
||||
# if a command was provided, execute it
|
||||
cd $wt_path
|
||||
git branch --set-upstream-to=pr/$PR/$head_ref
|
||||
git pull --ff-only || {
|
||||
echo "error: failed to pull pr/$PR"
|
||||
exit 1
|
||||
}
|
||||
|
||||
if [[ $# -eq 2 ]]; then
|
||||
cd ../$dir-pr-$PR
|
||||
echo "executing: $2"
|
||||
eval "$2"
|
||||
fi
|
||||
|
||||
+12
-4
@@ -2452,6 +2452,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
pimpl->gpu_buft_list.emplace(dev, std::move(buft_list));
|
||||
}
|
||||
|
||||
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
|
||||
// calculate the split points
|
||||
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; });
|
||||
std::vector<float> splits(n_devices());
|
||||
@@ -2462,6 +2467,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
size_t total;
|
||||
size_t free;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
|
||||
// devices can return 0 bytes for free and total memory if they do not
|
||||
// have any to report. in this case, we will use the host memory as a fallback
|
||||
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
|
||||
if (free == 0 && total == 0) {
|
||||
ggml_backend_dev_memory(cpu_dev, &free, &total);
|
||||
}
|
||||
splits[i] = free;
|
||||
}
|
||||
} else {
|
||||
@@ -2478,10 +2490,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
splits[i] /= split_sum;
|
||||
}
|
||||
|
||||
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0);
|
||||
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1);
|
||||
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
|
||||
|
||||
@@ -2142,7 +2142,7 @@ struct llama_sampler_xtc {
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
|
||||
std::mt19937 rng;
|
||||
std::mt19937 rng;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
||||
|
||||
+13
-1
@@ -111,8 +111,20 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < ret.size(); i++) {
|
||||
size_t free, total;
|
||||
size_t free;
|
||||
size_t total;
|
||||
ggml_backend_dev_memory(model->devices[i], &free, &total);
|
||||
|
||||
// devices can return 0 bytes for free and total memory if they do not
|
||||
// have any to report. in this case, we will use the host memory as a fallback
|
||||
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
|
||||
if (free == 0 && total == 0) {
|
||||
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
ggml_backend_dev_memory(cpu_dev, &free, &total);
|
||||
}
|
||||
ret[i].free = free;
|
||||
ret[i].total = total;
|
||||
}
|
||||
|
||||
+410
-336
@@ -4,7 +4,6 @@
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
@@ -16,7 +15,6 @@
|
||||
#include <cstddef>
|
||||
#include <cinttypes>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <filesystem>
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
@@ -81,6 +79,8 @@ struct server_slot {
|
||||
|
||||
common_speculative * spec = nullptr;
|
||||
|
||||
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
||||
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
||||
std::unique_ptr<const server_task> task;
|
||||
std::unique_ptr<const server_task> task_prev; // used for debugging
|
||||
|
||||
@@ -155,7 +155,7 @@ struct server_slot {
|
||||
|
||||
common_sampler_ptr smpl;
|
||||
|
||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||
llama_tokens drafted;
|
||||
|
||||
// stats
|
||||
@@ -203,12 +203,46 @@ struct server_slot {
|
||||
alora_invocation_start = -1;
|
||||
}
|
||||
|
||||
// remove cached prompt + tokens
|
||||
void clear(bool allow_processing) {
|
||||
if (!allow_processing) {
|
||||
GGML_ASSERT(!is_processing());
|
||||
}
|
||||
|
||||
SLT_INF(*this, "clearing slot with %zu tokens\n", prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
|
||||
prompt.tokens.clear();
|
||||
}
|
||||
|
||||
void init_sampler() const {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
common_sampler_reset(smpl.get());
|
||||
|
||||
int n_text = 0;
|
||||
|
||||
for (int i = 0; i < (int) prompt.tokens.size(); i++) {
|
||||
const llama_token id = prompt.tokens[i];
|
||||
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
common_sampler_accept(smpl.get(), id, false);
|
||||
n_text++;
|
||||
}
|
||||
}
|
||||
|
||||
SLT_INF(*this, "init sampler, took %0.2f ms, tokens: text = %d, total = %d\n",
|
||||
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
|
||||
}
|
||||
|
||||
// TODO: move to server_task
|
||||
bool need_embd() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
return server_task_type_need_embd(task->type);
|
||||
}
|
||||
|
||||
// TODO: move to server_task
|
||||
bool need_logits() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
@@ -260,10 +294,13 @@ struct server_slot {
|
||||
SLT_WRN(*this, "%s", "slot is not processing\n");
|
||||
return;
|
||||
}
|
||||
|
||||
generated_token_probs.push_back(token);
|
||||
}
|
||||
|
||||
int get_n_draft_max() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
if (!can_speculate()) {
|
||||
return 0;
|
||||
}
|
||||
@@ -289,12 +326,14 @@ struct server_slot {
|
||||
}
|
||||
|
||||
// note: a slot can also be either a parent or a child
|
||||
// TODO: move to server_task
|
||||
bool is_parent() const {
|
||||
return is_processing() && task->n_children > 0;
|
||||
return task->n_children > 0;
|
||||
}
|
||||
|
||||
// TODO: move to server_task
|
||||
bool is_child() const {
|
||||
return is_processing() && task->id_parent >= 0;
|
||||
return task->id_parent >= 0;
|
||||
}
|
||||
|
||||
void release() {
|
||||
@@ -303,10 +342,16 @@ struct server_slot {
|
||||
|
||||
SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
|
||||
|
||||
t_last_used = ggml_time_us();
|
||||
t_last_used = ggml_time_us();
|
||||
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
|
||||
|
||||
state = SLOT_STATE_IDLE;
|
||||
|
||||
// do not keep context of the child slots - the parent's context is enough
|
||||
if (is_child()) {
|
||||
clear(false);
|
||||
}
|
||||
|
||||
task_prev = std::move(task);
|
||||
task.reset();
|
||||
|
||||
@@ -427,14 +472,22 @@ struct server_slot {
|
||||
}
|
||||
|
||||
void copy_state_to(server_slot & other) const {
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
|
||||
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
|
||||
GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1);
|
||||
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1);
|
||||
|
||||
other.n_decoded = n_decoded;
|
||||
other.n_remaining = n_remaining;
|
||||
other.i_batch = i_batch;
|
||||
|
||||
other.t_start_process_prompt = t_start_process_prompt;
|
||||
other.t_prompt_processing = t_prompt_processing;
|
||||
other.n_prompt_tokens_cache = n_prompt_tokens_cache;
|
||||
other.n_prompt_tokens_processed = n_prompt_tokens_processed;
|
||||
|
||||
other.prompt = prompt.clone();
|
||||
other.init_sampler();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -747,6 +800,7 @@ private:
|
||||
}
|
||||
|
||||
slots.clear();
|
||||
|
||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||
server_slot slot;
|
||||
|
||||
@@ -995,7 +1049,7 @@ private:
|
||||
ret->prompt_save(*prompt_cache);
|
||||
|
||||
if (!ret->prompt_load(*prompt_cache, task.tokens)) {
|
||||
clear_slot(*ret);
|
||||
ret->clear(false);
|
||||
}
|
||||
|
||||
prompt_cache->update();
|
||||
@@ -1007,17 +1061,6 @@ private:
|
||||
return ret;
|
||||
}
|
||||
|
||||
void clear_slot(server_slot & slot, bool allow_processing = false) const {
|
||||
if (!allow_processing) {
|
||||
GGML_ASSERT(!slot.is_processing());
|
||||
}
|
||||
|
||||
SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
slot.prompt.tokens.clear();
|
||||
}
|
||||
|
||||
// return true if at least one slot has been cleared
|
||||
// TODO: improve logic
|
||||
// - smarter decision which slot to clear (LRU or longest prompt?)
|
||||
@@ -1038,7 +1081,7 @@ private:
|
||||
if (slot.prompt.n_tokens() > 0) {
|
||||
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
|
||||
|
||||
clear_slot(slot);
|
||||
slot.clear(false);
|
||||
|
||||
res = true;
|
||||
|
||||
@@ -1184,7 +1227,7 @@ private:
|
||||
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
|
||||
: SLOT_STATE_STARTED;
|
||||
|
||||
SLT_INF(slot, "%s", "processing task\n");
|
||||
SLT_INF(slot, "processing task, is_child = %d\n", slot.is_child());
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -1821,7 +1864,7 @@ private:
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->prompt.tokens.size();
|
||||
|
||||
clear_slot(*slot);
|
||||
slot->clear(false);
|
||||
|
||||
auto res = std::make_unique<server_task_result_slot_erase>();
|
||||
res->id = task.id;
|
||||
@@ -2055,293 +2098,317 @@ private:
|
||||
continue;
|
||||
}
|
||||
|
||||
// this slot still has a prompt to be processed
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||
const auto & input_tokens = slot.task->tokens;
|
||||
// check if this is a child slot
|
||||
if (slot.state == SLOT_STATE_WAIT_OTHER) {
|
||||
SLT_DBG(slot, "%s", "waiting for parent slot to complete\n");
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: maybe move branch to outside of this loop in the future
|
||||
if (slot.state == SLOT_STATE_STARTED) {
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_generation = 0;
|
||||
|
||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||
|
||||
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
|
||||
slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
|
||||
|
||||
// print prompt tokens (for debugging)
|
||||
/*if (1) {
|
||||
// first 16 tokens (avoid flooding logs)
|
||||
for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
|
||||
// SLOT_STATE_STARTED -> SLOT_STATE_PROCESSING_PROMPT
|
||||
// TODO: maybe move branch to outside of this loop in the future
|
||||
if (slot.state == SLOT_STATE_STARTED) {
|
||||
// wait for all children to be launched
|
||||
if (slot.is_parent()) {
|
||||
int n_launched = 0;
|
||||
for (auto & other : slots) {
|
||||
if (other.is_processing() && other.is_child() && other.task->id_parent == slot.task->id) {
|
||||
++n_launched;
|
||||
}
|
||||
} else {
|
||||
// all
|
||||
for (int i = 0; i < (int) input_tokens.size(); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
|
||||
}
|
||||
}*/
|
||||
|
||||
// keep track how many tokens we can reuse from the previous state
|
||||
int n_past = 0;
|
||||
|
||||
// empty prompt passed -> release the slot and send empty response
|
||||
if (input_tokens.empty()) {
|
||||
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
|
||||
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
slot.release();
|
||||
}
|
||||
|
||||
if (n_launched < slot.task->n_children) {
|
||||
SLT_DBG(slot, "waiting for children to be launched, n_children = %d, n_launched = %d\n", slot.task->n_children, n_launched);
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: support memory-less logits computation
|
||||
if (slot.need_logits() && !llama_get_memory(ctx)) {
|
||||
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!slot.can_split()) {
|
||||
if (slot.task->n_tokens() > n_ubatch) {
|
||||
send_error(slot,
|
||||
string_format(
|
||||
"input (%d tokens) is too large to process. increase the physical batch "
|
||||
"size (current batch size: %d)",
|
||||
slot.task->n_tokens(), n_ubatch),
|
||||
ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.task->n_tokens() > slot.n_ctx) {
|
||||
send_error(
|
||||
slot,
|
||||
string_format(
|
||||
"input (%d tokens) is larger than the max context size (%d tokens). skipping",
|
||||
slot.task->n_tokens(), slot.n_ctx),
|
||||
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (slot.task->n_tokens() >= slot.n_ctx) {
|
||||
send_error(slot,
|
||||
string_format("request (%d tokens) exceeds the available context size (%d "
|
||||
"tokens), try increasing it",
|
||||
slot.task->n_tokens(), slot.n_ctx),
|
||||
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.task->params.cache_prompt) {
|
||||
// reuse any previously computed tokens that are common with the new prompt
|
||||
n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
|
||||
|
||||
// if there is an alora invoked, don't cache after the invocation start
|
||||
if (slot.alora_invocation_start > 0) {
|
||||
SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
|
||||
n_past = std::min(n_past, slot.alora_invocation_start - 1);
|
||||
}
|
||||
|
||||
const auto n_cache_reuse = slot.task->params.n_cache_reuse;
|
||||
|
||||
const bool can_cache_reuse =
|
||||
llama_memory_can_shift(llama_get_memory(ctx)) &&
|
||||
!slot.prompt.tokens.has_mtmd;
|
||||
|
||||
if (!can_cache_reuse && n_cache_reuse > 0) {
|
||||
SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
|
||||
}
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (can_cache_reuse && n_cache_reuse > 0) {
|
||||
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
|
||||
|
||||
size_t head_c = n_past; // cache
|
||||
size_t head_p = n_past; // current prompt
|
||||
|
||||
if (mctx) {
|
||||
// we should never reach this
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);
|
||||
|
||||
while (head_c < slot.prompt.tokens.size() &&
|
||||
head_p < input_tokens.size()) {
|
||||
|
||||
size_t n_match = 0;
|
||||
while (head_c + n_match < slot.prompt.tokens.size() &&
|
||||
head_p + n_match < input_tokens.size() &&
|
||||
slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
|
||||
n_match++;
|
||||
}
|
||||
|
||||
if (n_match >= (size_t) n_cache_reuse) {
|
||||
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
|
||||
//for (size_t i = head_p; i < head_p + n_match; i++) {
|
||||
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
//}
|
||||
|
||||
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
|
||||
|
||||
llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
|
||||
llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);
|
||||
|
||||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
|
||||
n_past++;
|
||||
}
|
||||
|
||||
head_c += n_match;
|
||||
head_p += n_match;
|
||||
} else {
|
||||
head_c += 1;
|
||||
}
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
|
||||
}
|
||||
} else {
|
||||
// if we don't cache the prompt, we have to remove all previous tokens
|
||||
n_past = 0;
|
||||
}
|
||||
|
||||
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
|
||||
const auto n_swa = std::max(1, llama_model_n_swa(model));
|
||||
|
||||
// the largest pos_min required for a checkpoint to be useful
|
||||
const auto pos_min_thold = std::max(0, n_past - n_swa);
|
||||
|
||||
// note: disallow with mtmd contexts for now
|
||||
// https://github.com/ggml-org/llama.cpp/issues/17043
|
||||
if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
if (pos_min == -1) {
|
||||
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
||||
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
|
||||
}
|
||||
|
||||
// when the prompt prefix does not match, print the tokens around the mismatch
|
||||
// this is useful for debugging prompt caching
|
||||
if (slots_debug) {
|
||||
const int np0 = std::max<int>(n_past - 4, 0);
|
||||
const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
|
||||
|
||||
std::stringstream ss0;
|
||||
std::stringstream ss1;
|
||||
|
||||
std::stringstream st0;
|
||||
std::stringstream st1;
|
||||
|
||||
ss0 << "old: ... ";
|
||||
ss1 << "new: ... ";
|
||||
|
||||
for (int i = np0; i < np1; i++) {
|
||||
if (i == n_past) {
|
||||
ss0 << " | ";
|
||||
ss1 << " | ";
|
||||
}
|
||||
|
||||
{
|
||||
const auto token = slot.prompt.tokens[i];
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
|
||||
ss0 << piece;
|
||||
st0 << std::setw(8) << token;
|
||||
}
|
||||
|
||||
{
|
||||
const auto token = slot.task->tokens[i];
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
|
||||
ss1 << piece;
|
||||
st1 << std::setw(8) << token;
|
||||
}
|
||||
}
|
||||
|
||||
SLT_WRN(slot, "%s\n", ss0.str().c_str());
|
||||
SLT_WRN(slot, "%s\n", ss1.str().c_str());
|
||||
|
||||
SLT_WRN(slot, "%s\n", st0.str().c_str());
|
||||
SLT_WRN(slot, "%s\n", st1.str().c_str());
|
||||
}
|
||||
|
||||
if (pos_min > pos_min_thold) {
|
||||
// TODO: support can be added in the future when corresponding vision models get released
|
||||
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
|
||||
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
|
||||
// search for a context checkpoint
|
||||
const auto it = std::find_if(
|
||||
slot.prompt.checkpoints.rbegin(),
|
||||
slot.prompt.checkpoints.rend(),
|
||||
[&](const auto & cur) {
|
||||
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
|
||||
return cur.pos_min < pos_min_thold;
|
||||
}
|
||||
);
|
||||
|
||||
bool do_reset = it == slot.prompt.checkpoints.rend();
|
||||
|
||||
if (!do_reset) {
|
||||
// restore the context checkpoint
|
||||
const size_t checkpoint_size = it->data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
if (n != checkpoint_size) {
|
||||
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
|
||||
do_reset = true;
|
||||
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
||||
} else {
|
||||
n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
|
||||
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
|
||||
if (do_reset) {
|
||||
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
|
||||
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
||||
n_past = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// erase any checkpoints with pos_min > pos_min_thold
|
||||
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
|
||||
const auto & cur = *it;
|
||||
if (cur.pos_min > pos_min_thold) {
|
||||
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
|
||||
it = slot.prompt.checkpoints.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// [TAG_PROMPT_LOGITS]
|
||||
if (n_past == slot.task->n_tokens() && n_past > 0) {
|
||||
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
|
||||
n_past--;
|
||||
SLT_WRN(slot, "n_past was set to %d\n", n_past);
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_cache = n_past;
|
||||
slot.n_prompt_tokens_processed = 0;
|
||||
|
||||
slot.prompt.tokens.keep_first(n_past);
|
||||
|
||||
// send initial 0% progress update if needed
|
||||
// this is to signal the client that the request has started processing
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
send_partial_response(slot, {}, true);
|
||||
}
|
||||
}
|
||||
|
||||
const auto & input_tokens = slot.task->tokens;
|
||||
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_generation = 0;
|
||||
|
||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||
|
||||
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
|
||||
slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
|
||||
|
||||
// print prompt tokens (for debugging)
|
||||
/*if (1) {
|
||||
// first 16 tokens (avoid flooding logs)
|
||||
for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
|
||||
}
|
||||
} else {
|
||||
// all
|
||||
for (int i = 0; i < (int) input_tokens.size(); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
|
||||
}
|
||||
}*/
|
||||
|
||||
// keep track how many tokens we can reuse from the previous state
|
||||
int n_past = 0;
|
||||
|
||||
// empty prompt passed -> release the slot and send empty response
|
||||
if (input_tokens.empty()) {
|
||||
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
|
||||
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
slot.release();
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: support memory-less logits computation
|
||||
if (slot.need_logits() && !llama_get_memory(ctx)) {
|
||||
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!slot.can_split()) {
|
||||
if (slot.task->n_tokens() > n_ubatch) {
|
||||
send_error(slot,
|
||||
string_format(
|
||||
"input (%d tokens) is too large to process. increase the physical batch "
|
||||
"size (current batch size: %d)",
|
||||
slot.task->n_tokens(), n_ubatch),
|
||||
ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.task->n_tokens() > slot.n_ctx) {
|
||||
send_error(
|
||||
slot,
|
||||
string_format(
|
||||
"input (%d tokens) is larger than the max context size (%d tokens). skipping",
|
||||
slot.task->n_tokens(), slot.n_ctx),
|
||||
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (slot.task->n_tokens() >= slot.n_ctx) {
|
||||
send_error(slot,
|
||||
string_format("request (%d tokens) exceeds the available context size (%d "
|
||||
"tokens), try increasing it",
|
||||
slot.task->n_tokens(), slot.n_ctx),
|
||||
ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.task->params.cache_prompt) {
|
||||
// reuse any previously computed tokens that are common with the new prompt
|
||||
n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
|
||||
|
||||
// if there is an alora invoked, don't cache after the invocation start
|
||||
if (slot.alora_invocation_start > 0) {
|
||||
SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
|
||||
n_past = std::min(n_past, slot.alora_invocation_start - 1);
|
||||
}
|
||||
|
||||
const auto n_cache_reuse = slot.task->params.n_cache_reuse;
|
||||
|
||||
const bool can_cache_reuse =
|
||||
llama_memory_can_shift(llama_get_memory(ctx)) &&
|
||||
!slot.prompt.tokens.has_mtmd;
|
||||
|
||||
if (!can_cache_reuse && n_cache_reuse > 0) {
|
||||
SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
|
||||
}
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (can_cache_reuse && n_cache_reuse > 0) {
|
||||
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
|
||||
|
||||
size_t head_c = n_past; // cache
|
||||
size_t head_p = n_past; // current prompt
|
||||
|
||||
if (mctx) {
|
||||
// we should never reach this
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);
|
||||
|
||||
while (head_c < slot.prompt.tokens.size() &&
|
||||
head_p < input_tokens.size()) {
|
||||
|
||||
size_t n_match = 0;
|
||||
while (head_c + n_match < slot.prompt.tokens.size() &&
|
||||
head_p + n_match < input_tokens.size() &&
|
||||
slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
|
||||
n_match++;
|
||||
}
|
||||
|
||||
if (n_match >= (size_t) n_cache_reuse) {
|
||||
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
|
||||
//for (size_t i = head_p; i < head_p + n_match; i++) {
|
||||
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
//}
|
||||
|
||||
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
|
||||
|
||||
llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
|
||||
llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);
|
||||
|
||||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
|
||||
n_past++;
|
||||
}
|
||||
|
||||
head_c += n_match;
|
||||
head_p += n_match;
|
||||
} else {
|
||||
head_c += 1;
|
||||
}
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
|
||||
}
|
||||
} else {
|
||||
// if we don't cache the prompt, we have to remove all previous tokens
|
||||
n_past = 0;
|
||||
}
|
||||
|
||||
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
|
||||
const auto n_swa = std::max(1, llama_model_n_swa(model));
|
||||
|
||||
// the largest pos_min required for a checkpoint to be useful
|
||||
const auto pos_min_thold = std::max(0, n_past - n_swa);
|
||||
|
||||
// note: disallow with mtmd contexts for now
|
||||
// https://github.com/ggml-org/llama.cpp/issues/17043
|
||||
if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
if (pos_min == -1) {
|
||||
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
||||
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
|
||||
}
|
||||
|
||||
// when the prompt prefix does not match, print the tokens around the mismatch
|
||||
// this is useful for debugging prompt caching
|
||||
if (slots_debug) {
|
||||
const int np0 = std::max<int>(n_past - 4, 0);
|
||||
const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
|
||||
|
||||
std::stringstream ss0;
|
||||
std::stringstream ss1;
|
||||
|
||||
std::stringstream st0;
|
||||
std::stringstream st1;
|
||||
|
||||
ss0 << "old: ... ";
|
||||
ss1 << "new: ... ";
|
||||
|
||||
for (int i = np0; i < np1; i++) {
|
||||
if (i == n_past) {
|
||||
ss0 << " | ";
|
||||
ss1 << " | ";
|
||||
}
|
||||
|
||||
{
|
||||
const auto token = slot.prompt.tokens[i];
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
|
||||
ss0 << piece;
|
||||
st0 << std::setw(8) << token;
|
||||
}
|
||||
|
||||
{
|
||||
const auto token = slot.task->tokens[i];
|
||||
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
|
||||
ss1 << piece;
|
||||
st1 << std::setw(8) << token;
|
||||
}
|
||||
}
|
||||
|
||||
SLT_WRN(slot, "%s\n", ss0.str().c_str());
|
||||
SLT_WRN(slot, "%s\n", ss1.str().c_str());
|
||||
|
||||
SLT_WRN(slot, "%s\n", st0.str().c_str());
|
||||
SLT_WRN(slot, "%s\n", st1.str().c_str());
|
||||
}
|
||||
|
||||
if (pos_min > pos_min_thold) {
|
||||
// TODO: support can be added in the future when corresponding vision models get released
|
||||
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
|
||||
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
|
||||
// search for a context checkpoint
|
||||
const auto it = std::find_if(
|
||||
slot.prompt.checkpoints.rbegin(),
|
||||
slot.prompt.checkpoints.rend(),
|
||||
[&](const auto & cur) {
|
||||
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
|
||||
return cur.pos_min < pos_min_thold;
|
||||
}
|
||||
);
|
||||
|
||||
bool do_reset = it == slot.prompt.checkpoints.rend();
|
||||
|
||||
if (!do_reset) {
|
||||
// restore the context checkpoint
|
||||
const size_t checkpoint_size = it->data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
if (n != checkpoint_size) {
|
||||
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
|
||||
do_reset = true;
|
||||
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
||||
} else {
|
||||
n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
|
||||
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
|
||||
if (do_reset) {
|
||||
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
|
||||
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
||||
n_past = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// erase any checkpoints with pos_min > pos_min_thold
|
||||
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
|
||||
const auto & cur = *it;
|
||||
if (cur.pos_min > pos_min_thold) {
|
||||
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
|
||||
it = slot.prompt.checkpoints.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// [TAG_PROMPT_LOGITS]
|
||||
if (n_past == slot.task->n_tokens() && n_past > 0) {
|
||||
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
|
||||
n_past--;
|
||||
SLT_WRN(slot, "n_past was set to %d\n", n_past);
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_cache = n_past;
|
||||
slot.n_prompt_tokens_processed = 0;
|
||||
|
||||
slot.prompt.tokens.keep_first(n_past);
|
||||
|
||||
// send initial 0% progress update if needed
|
||||
// this is to signal the client that the request has started processing
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
send_partial_response(slot, {}, true);
|
||||
}
|
||||
}
|
||||
|
||||
// SLOT_STATE_PROCESSING_PROMPT -> SLOT_STATE_DONE_PROMPT
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
|
||||
const auto & input_tokens = slot.task->tokens;
|
||||
|
||||
if (!slot.can_split()) {
|
||||
// cannot fit the prompt in the current batch - will try next iter
|
||||
if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
|
||||
@@ -2357,7 +2424,7 @@ private:
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
|
||||
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
|
||||
|
||||
clear_slot(slot, /*allow_processing=*/true);
|
||||
slot.clear(true);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_prompt_tokens_cache = 0;
|
||||
@@ -2457,16 +2524,6 @@ private:
|
||||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
|
||||
common_sampler_reset(slot.smpl.get());
|
||||
|
||||
// Process all prompt tokens through sampler system
|
||||
for (int i = 0; i < slot.task->n_tokens(); ++i) {
|
||||
llama_token id = input_tokens[i];
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
common_sampler_accept(slot.smpl.get(), id, false);
|
||||
}
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
@@ -2475,6 +2532,8 @@ private:
|
||||
|
||||
SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
|
||||
|
||||
slot.init_sampler();
|
||||
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||
|
||||
@@ -2521,11 +2580,6 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
SRV_WRN("%s", "no tokens to decode\n");
|
||||
return;
|
||||
}
|
||||
|
||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
||||
|
||||
if (slot_batched) {
|
||||
@@ -2542,6 +2596,10 @@ private:
|
||||
llama_set_embeddings(ctx, slot_batched->need_embd());
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
SRV_WRN("%s", "no tokens to decode\n");
|
||||
}
|
||||
|
||||
int32_t i_next = 0;
|
||||
|
||||
// process the created batch of tokens
|
||||
@@ -2593,7 +2651,7 @@ private:
|
||||
|
||||
// note: it's complicated to keep track of how much of the current batch has been
|
||||
// processed before the error occurred, so we simply clear the entire context
|
||||
clear_slot(slot);
|
||||
slot.clear(false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2617,31 +2675,34 @@ private:
|
||||
// on successful decode, restore the original batch size
|
||||
n_batch = llama_n_batch(ctx);
|
||||
|
||||
// technically, measuring the time here excludes the sampling time for the last batch
|
||||
// but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
|
||||
const int64_t t_current = ggml_time_us();
|
||||
|
||||
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
|
||||
for (auto & slot : slots) {
|
||||
// may need to copy state to other slots
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
|
||||
std::vector<server_slot *> child_slots;
|
||||
SLT_INF(slot, "parent task prompt done, n_children = %d\n", slot.task->n_children);
|
||||
|
||||
std::vector<server_slot *> children;
|
||||
for (auto & other : slots) {
|
||||
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
|
||||
child_slots.push_back(&other);
|
||||
children.push_back(&other);
|
||||
}
|
||||
}
|
||||
|
||||
// we can only proceed if all child slots are having the correct tasks
|
||||
if (child_slots.size() == slot.task->n_children) {
|
||||
if (slot.task->n_children == (int) children.size()) {
|
||||
// copy state to the child slots
|
||||
for (auto & child : child_slots) {
|
||||
SLT_INF(slot, "copying state to child %d\n", child->id);
|
||||
for (auto & child : children) {
|
||||
SLT_INF(slot, " - copying state to child %d\n", child->id);
|
||||
|
||||
GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
|
||||
|
||||
slot.copy_state_to(*child);
|
||||
child->state = SLOT_STATE_DONE_PROMPT;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto & slot : slots) {
|
||||
// optionally send prompt processing progress
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
@@ -2687,6 +2748,9 @@ private:
|
||||
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
|
||||
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
|
||||
const int64_t t_current = ggml_time_us();
|
||||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
if (slot.n_decoded == 1) {
|
||||
@@ -2723,13 +2787,15 @@ private:
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t n_draft = slot.drafted.size();
|
||||
const size_t n_draft = slot.drafted.size();
|
||||
|
||||
// the accepted tokens from the speculation
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
|
||||
slot.i_batch_dft.clear();
|
||||
slot.drafted.clear();
|
||||
|
||||
const int64_t t_current = ggml_time_us();
|
||||
|
||||
slot.n_decoded += ids.size();
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||
@@ -2924,17 +2990,25 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.oaicompat_model = meta->model_name;
|
||||
|
||||
// prepare child tasks
|
||||
if (task.params.n_cmpl > 1) {
|
||||
task.n_children = task.params.n_cmpl - 1;
|
||||
for (size_t j = 0; j < task.n_children; j++) {
|
||||
server_task child = task.create_child(
|
||||
task.id,
|
||||
rd.get_new_id());
|
||||
|
||||
for (int j = 0; j < task.n_children; j++) {
|
||||
server_task child = task.create_child(task.id, rd.get_new_id());
|
||||
|
||||
// use different sampling seed for each child
|
||||
// note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
|
||||
if (child.params.sampling.seed != LLAMA_DEFAULT_SEED) {
|
||||
child.params.sampling.seed += j + 1;
|
||||
}
|
||||
|
||||
tasks.push_back(std::move(child));
|
||||
}
|
||||
}
|
||||
|
||||
tasks.push_back(std::move(task));
|
||||
// note: the parent task always launches first
|
||||
tasks.insert(tasks.begin(), std::move(task));
|
||||
}
|
||||
|
||||
rd.post_tasks(std::move(tasks));
|
||||
|
||||
@@ -121,8 +121,8 @@ struct server_task {
|
||||
int id_slot = -1;
|
||||
|
||||
// used by parallel sampling (multiple completions from same prompt)
|
||||
size_t n_children = 0; // number of tasks reusing this prompt
|
||||
int id_parent = -1;
|
||||
int n_children = 0; // number of tasks reusing this prompt
|
||||
int id_parent = -1;
|
||||
|
||||
// used by SERVER_TASK_TYPE_INFERENCE
|
||||
task_params params;
|
||||
@@ -173,11 +173,13 @@ struct server_task {
|
||||
|
||||
server_task create_child(int id_parent, int id_child) const {
|
||||
server_task copy;
|
||||
|
||||
copy.id = id_child;
|
||||
copy.id_parent = id_parent;
|
||||
copy.params = params;
|
||||
copy.type = type;
|
||||
copy.tokens = tokens.clone();
|
||||
|
||||
return copy;
|
||||
}
|
||||
|
||||
|
||||
@@ -503,5 +503,4 @@ def test_chat_completions_multiple_choices():
|
||||
assert len(res.body["choices"]) == 2
|
||||
for choice in res.body["choices"]:
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex("Suddenly", choice["message"]["content"])
|
||||
assert choice["finish_reason"] == "length"
|
||||
|
||||
Reference in New Issue
Block a user