mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 16:17:40 +02:00
560 lines
18 KiB
C++
560 lines
18 KiB
C++
#include "cli-context.h"
|
|
#include "cli-view.h"
|
|
|
|
#include "arg.h"
|
|
#include "base64.hpp"
|
|
#include "log.h"
|
|
#include "console.h"
|
|
|
|
#include <algorithm>
|
|
#include <filesystem>
|
|
#include <fstream>
|
|
#include <map>
|
|
#include <set>
|
|
|
|
// Process-wide cancellation flag. This file clears it after each streamed
// completion and when the interactive loop exits on interrupt; it is
// presumably raised by a Ctrl+C handler defined elsewhere (NOTE(review): confirm).
std::atomic<bool> g_cli_interrupted{false};

// Polling predicate handed to the wait/stream helpers (server->wait_ready,
// client.create_chat_completion): true once an interrupt was requested.
static bool should_stop() {
    const bool interrupted = g_cli_interrupted.load();
    return interrupted;
}
|
|
|
|
// Upper bound on how many files a single /glob command may load.
static constexpr size_t FILE_GLOB_MAX_RESULTS{100};
|
|
|
|
// ASCII-art logo appended at the top of the startup banner in cli_context::run().
// NOTE(review): the lone '|' lines below sit inside the raw string literal and
// would be printed verbatim; they look like copy/scrape residue — confirm
// against the upstream file.
const char * LLAMA_ASCII_LOGO = R"(
|
|
▄▄ ▄▄
|
|
██ ██
|
|
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
|
|
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
|
|
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
|
|
██ ██
|
|
▀▀ ▀▀
|
|
)";
|
|
|
|
// number of values an arg consumes on the command line
|
|
static int arg_num_values(const common_arg & opt) {
|
|
if (opt.value_hint_2 != nullptr) {
|
|
return 2;
|
|
}
|
|
if (opt.value_hint != nullptr) {
|
|
return 1;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
static std::string format_error_message(const json & err) {
|
|
if (err.contains("error") && err.at("error").is_object()) {
|
|
const auto & e = err.at("error");
|
|
if (e.contains("message") && e.at("message").is_string()) {
|
|
return e.at("message").get<std::string>();
|
|
}
|
|
}
|
|
return err.dump();
|
|
}
|
|
|
|
// Classifies a media file by its (case-insensitive) extension:
//   ".wav", ".mp3"                             -> "audio"
//   ".mp4", ".avi", ".mkv", ".mov", ".webm"    -> "video"
//   anything else, including no extension      -> "image"
static std::string media_type_from_ext(const std::string & fname) {
    std::string ext = std::filesystem::path(fname).extension().string();
    // go through unsigned char: handing a negative char (e.g. a UTF-8 byte of a
    // non-ASCII extension) to tolower() is undefined behavior
    std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
        return static_cast<char>(::tolower(c));
    });

    if (ext == ".wav" || ext == ".mp3") {
        return "audio";
    }
    if (ext == ".mp4" || ext == ".avi" || ext == ".mkv" || ext == ".mov" || ext == ".webm") {
        return "video";
    }
    return "image";
}
|
|
|
|
bool cli_context::init() {
|
|
view::init(params);
|
|
|
|
std::optional<view::spinner> spinner;
|
|
|
|
bool use_external_server = !params.server_base.empty();
|
|
if (use_external_server) {
|
|
std::string base = params.server_base;
|
|
while (!base.empty() && base.back() == '/') {
|
|
base.pop_back();
|
|
}
|
|
client.server_base = base;
|
|
|
|
spinner.emplace("Connecting to server at " + base);
|
|
} else {
|
|
if (params.model.path.empty() && params.model.url.empty() &&
|
|
params.model.hf_repo.empty() && params.model.docker_repo.empty()) {
|
|
view::show_error(
|
|
"no model specified",
|
|
"use -m <file.gguf> or -hf <user/repo> to run a local model,\n"
|
|
"or --server-base <url> to connect to a running llama-server"
|
|
);
|
|
return false;
|
|
}
|
|
|
|
spinner.emplace("\n\nLoading model...");
|
|
|
|
server.emplace();
|
|
if (!server->start(params)) {
|
|
view::show_error("server start failed");
|
|
return false;
|
|
}
|
|
if (!server->wait_ready(should_stop)) {
|
|
if (!should_stop()) {
|
|
view::show_error("the server exited before becoming ready");
|
|
}
|
|
return false;
|
|
}
|
|
client.server_base = server->address();
|
|
}
|
|
|
|
// for --server-base this is the main availability check; for a spawned
|
|
// server it is a cheap sanity check on top of the ready signal
|
|
auto is_aborted = [this]() {
|
|
return should_stop() || (server && !server->alive());
|
|
};
|
|
bool healthy = false;
|
|
try {
|
|
healthy = client.wait_health(is_aborted);
|
|
} catch (const std::exception & e) {
|
|
client.last_error = e.what();
|
|
}
|
|
if (!healthy) {
|
|
if (!should_stop()) {
|
|
view::show_error(client.last_error);
|
|
}
|
|
return false;
|
|
}
|
|
|
|
if (use_external_server) {
|
|
spinner.reset();
|
|
if (!list_and_ask_models()) {
|
|
return false;
|
|
}
|
|
// restore the spinner for the next step
|
|
spinner.emplace("Waiting for server...");
|
|
}
|
|
|
|
fetch_server_props();
|
|
|
|
return true;
|
|
}
|
|
|
|
void cli_context::fetch_server_props() {
|
|
try {
|
|
json props = client.get_props();
|
|
model_name = props.value("model_alias", "");
|
|
if (model_name.empty()) {
|
|
const std::string path = props.value("model_path", "");
|
|
if (!path.empty()) {
|
|
model_name = std::filesystem::path(path).filename().string();
|
|
}
|
|
}
|
|
build_info = props.value("build_info", "");
|
|
if (props.contains("modalities") && props.at("modalities").is_object()) {
|
|
const auto & modalities = props.at("modalities");
|
|
has_vision = modalities.value("vision", false);
|
|
has_audio = modalities.value("audio", false);
|
|
has_video = modalities.value("video", false);
|
|
}
|
|
} catch (const std::exception & e) {
|
|
// /props can be disabled on remote servers; not fatal
|
|
LOG_DBG("failed to fetch /props: %s\n", e.what());
|
|
}
|
|
}
|
|
|
|
bool cli_context::list_and_ask_models() {
|
|
auto models = client.list_models();
|
|
|
|
// only one model: use it without asking
|
|
if (models.size() == 1) {
|
|
model_name = models[0];
|
|
client.model = model_name;
|
|
return true;
|
|
}
|
|
|
|
std::string message = "\nAvailable models:";
|
|
if (!models.empty()) {
|
|
for (size_t i = 0; i < models.size(); ++i) {
|
|
message += "\n " + std::to_string(i + 1) + ". " + models[i];
|
|
}
|
|
}
|
|
message += "\n";
|
|
view::show_message(message);
|
|
std::string selection;
|
|
while (selection.empty()) {
|
|
if (should_stop()) {
|
|
return false;
|
|
}
|
|
view::user_turn user_turn;
|
|
selection = user_turn.read_input(false, "Select model by number: ");
|
|
if (selection.empty()) {
|
|
continue;
|
|
}
|
|
try {
|
|
size_t idx = std::stoul(selection);
|
|
if (idx > 0 && idx <= models.size()) {
|
|
model_name = models[idx - 1];
|
|
client.model = model_name;
|
|
view::show_message("Selected model: " + model_name);
|
|
break;
|
|
}
|
|
} catch (...) {
|
|
// ignore
|
|
}
|
|
view::show_error("Invalid selection. Please enter a valid number.");
|
|
selection.clear();
|
|
continue;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
void cli_context::add_system_prompt() {
|
|
if (!params.system_prompt.empty()) {
|
|
messages.push_back({
|
|
{"role", "system"},
|
|
{"content", params.system_prompt}
|
|
});
|
|
}
|
|
}
|
|
|
|
void cli_context::push_user_message(const std::string & text) {
|
|
json content;
|
|
if (pending_media.empty()) {
|
|
content = text;
|
|
} else {
|
|
// multimodal message: media parts first, then the text
|
|
content = pending_media;
|
|
content.push_back({
|
|
{"type", "text"},
|
|
{"text", text}
|
|
});
|
|
pending_media = json::array();
|
|
}
|
|
messages.push_back({
|
|
{"role", "user"},
|
|
{"content", content}
|
|
});
|
|
}
|
|
|
|
bool cli_context::stage_media_file(const std::string & fname, const std::string & type) {
|
|
std::ifstream file(fname, std::ios::binary);
|
|
if (!file) {
|
|
return false;
|
|
}
|
|
std::string data((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
|
std::string encoded = base64::encode(data);
|
|
|
|
if (type == "audio") {
|
|
std::string ext = std::filesystem::path(fname).extension().string();
|
|
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
|
pending_media.push_back({
|
|
{"type", "input_audio"},
|
|
{"input_audio", {
|
|
{"data", encoded},
|
|
{"format", ext == ".mp3" ? "mp3" : "wav"}
|
|
}}
|
|
});
|
|
} else if (type == "video") {
|
|
pending_media.push_back({
|
|
{"type", "input_video"},
|
|
{"input_video", {
|
|
{"data", encoded}
|
|
}}
|
|
});
|
|
} else {
|
|
// the server detects the actual image type from the data
|
|
pending_media.push_back({
|
|
{"type", "image_url"},
|
|
{"image_url", {
|
|
{"url", "data:image/unknown;base64," + encoded}
|
|
}}
|
|
});
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool cli_context::generate_completion(std::string & assistant_content, cli_timings & timings) {
|
|
json body = {
|
|
{"messages", messages},
|
|
{"stream", true},
|
|
// in order to get timings even when we cancel mid-way
|
|
{"timings_per_token", true},
|
|
};
|
|
|
|
bool stream_error = false;
|
|
|
|
view::assistant_turn a;
|
|
|
|
json err = client.create_chat_completion(body, should_stop, [&](const json & chunk) {
|
|
if (chunk.contains("error")) {
|
|
stream_error = true;
|
|
view::show_error(format_error_message(chunk));
|
|
return;
|
|
}
|
|
if (chunk.contains("timings")) {
|
|
const auto & t = chunk.at("timings");
|
|
timings.prompt_per_second = t.value("prompt_per_second", 0.0);
|
|
timings.predicted_per_second = t.value("predicted_per_second", 0.0);
|
|
}
|
|
if (!chunk.contains("choices") || !chunk.at("choices").is_array() || chunk.at("choices").empty()) {
|
|
return;
|
|
}
|
|
const auto & choice = chunk.at("choices").at(0);
|
|
if (!choice.contains("delta")) {
|
|
return;
|
|
}
|
|
const auto & delta = choice.at("delta");
|
|
if (delta.contains("reasoning_content") && delta.at("reasoning_content").is_string()) {
|
|
const std::string text = delta.at("reasoning_content").get<std::string>();
|
|
if (!text.empty()) {
|
|
a.push(view::ASSISTANT_DISPLAY_MODE_REASONING, text);
|
|
}
|
|
}
|
|
if (delta.contains("content") && delta.at("content").is_string()) {
|
|
const std::string text = delta.at("content").get<std::string>();
|
|
if (!text.empty()) {
|
|
assistant_content += text;
|
|
a.push(view::ASSISTANT_DISPLAY_MODE_CONTENT, text);
|
|
}
|
|
}
|
|
});
|
|
|
|
g_cli_interrupted.store(false);
|
|
|
|
if (!err.is_null()) {
|
|
view::show_error(format_error_message(err));
|
|
return false;
|
|
}
|
|
return !stream_error;
|
|
}
|
|
|
|
int cli_context::run() {
|
|
add_system_prompt();
|
|
|
|
std::string modalities = "text";
|
|
if (has_vision) {
|
|
modalities += ", vision";
|
|
}
|
|
if (has_audio) {
|
|
modalities += ", audio";
|
|
}
|
|
if (has_video) {
|
|
modalities += ", video";
|
|
}
|
|
|
|
std::string banner;
|
|
banner += "\n";
|
|
banner += LLAMA_ASCII_LOGO;
|
|
banner += "\n";
|
|
banner += "build : " + build_info + "\n";
|
|
banner += "model : " + model_name + "\n";
|
|
banner += "modalities : " + modalities + "\n";
|
|
if (!params.system_prompt.empty()) {
|
|
banner += "using custom system prompt\n";
|
|
}
|
|
banner += "\n";
|
|
banner += "available commands:\n";
|
|
banner += " /exit or Ctrl+C stop or exit\n";
|
|
banner += " /regen regenerate the last response\n";
|
|
banner += " /clear clear the chat history\n";
|
|
banner += " /read <file> add a text file\n";
|
|
banner += " /glob <pattern> add text files using globbing pattern\n";
|
|
if (has_vision) {
|
|
banner += " /image <file> add an image file\n";
|
|
}
|
|
if (has_audio) {
|
|
banner += " /audio <file> add an audio file\n";
|
|
}
|
|
if (has_video) {
|
|
banner += " /video <file> add a video file\n";
|
|
}
|
|
banner += "\n";
|
|
|
|
view::show_message(banner);
|
|
|
|
// interactive loop
|
|
std::string cur_msg;
|
|
|
|
auto add_text_file = [&](const std::string & fname) -> bool {
|
|
std::ifstream file(fname, std::ios::binary);
|
|
if (!file) {
|
|
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
|
return false;
|
|
}
|
|
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
|
cur_msg += "--- File: ";
|
|
cur_msg += fname;
|
|
cur_msg += " ---\n";
|
|
cur_msg += content;
|
|
view::show_message(string_format("Loaded text from '%s'", fname.c_str()));
|
|
return true;
|
|
};
|
|
|
|
while (true) {
|
|
std::string buffer;
|
|
{
|
|
view::user_turn user_turn;
|
|
|
|
if (params.prompt.empty()) {
|
|
buffer = user_turn.read_input(params.multiline_input);
|
|
} else {
|
|
// process input prompt from args
|
|
for (auto & fname : params.image) {
|
|
if (!stage_media_file(fname, media_type_from_ext(fname))) {
|
|
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
|
break;
|
|
}
|
|
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
|
}
|
|
buffer = params.prompt;
|
|
user_turn.echo(buffer);
|
|
params.prompt.clear(); // only use it once
|
|
}
|
|
}
|
|
|
|
if (should_stop()) {
|
|
g_cli_interrupted.store(false);
|
|
break;
|
|
}
|
|
|
|
// remove trailing newline
|
|
if (!buffer.empty() && buffer.back() == '\n') {
|
|
buffer.pop_back();
|
|
}
|
|
|
|
// skip empty messages
|
|
if (buffer.empty()) {
|
|
continue;
|
|
}
|
|
|
|
bool add_user_msg = true;
|
|
|
|
// process commands
|
|
if (string_starts_with(buffer, "/exit")) {
|
|
break;
|
|
} else if (string_starts_with(buffer, "/regen")) {
|
|
if (messages.size() >= 2) {
|
|
size_t last_idx = messages.size() - 1;
|
|
messages.erase(last_idx);
|
|
add_user_msg = false;
|
|
} else {
|
|
view::show_error("No message to regenerate.");
|
|
continue;
|
|
}
|
|
} else if (string_starts_with(buffer, "/clear")) {
|
|
messages.clear();
|
|
add_system_prompt();
|
|
|
|
pending_media = json::array();
|
|
view::show_message("Chat history cleared.");
|
|
continue;
|
|
} else if (
|
|
(string_starts_with(buffer, "/image ") && has_vision) ||
|
|
(string_starts_with(buffer, "/audio ") && has_audio) ||
|
|
(string_starts_with(buffer, "/video ") && has_video)) {
|
|
std::string type = buffer.substr(1, 5);
|
|
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
|
|
std::string fname = string_strip(buffer.substr(7));
|
|
if (!stage_media_file(fname, type)) {
|
|
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
|
continue;
|
|
}
|
|
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
|
continue;
|
|
} else if (string_starts_with(buffer, "/read ")) {
|
|
std::string fname = string_strip(buffer.substr(6));
|
|
add_text_file(fname);
|
|
continue;
|
|
} else if (string_starts_with(buffer, "/glob ")) {
|
|
std::error_code ec;
|
|
size_t count = 0;
|
|
auto curdir = std::filesystem::current_path();
|
|
std::string pattern = string_strip(buffer.substr(6));
|
|
std::filesystem::path rel_path;
|
|
|
|
auto startglob = pattern.find_first_of("![*?");
|
|
if (startglob != std::string::npos && startglob != 0) {
|
|
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
|
if (endpath != std::string::npos) {
|
|
std::string rel_pattern = pattern.substr(0, endpath);
|
|
#if !defined(_WIN32)
|
|
if (string_starts_with(rel_pattern, '~')) {
|
|
const char * home = std::getenv("HOME");
|
|
if (home && home[0]) {
|
|
rel_pattern = home + rel_pattern.substr(1);
|
|
}
|
|
}
|
|
#endif
|
|
rel_path = rel_pattern;
|
|
pattern.erase(0, endpath + 1);
|
|
curdir /= rel_path;
|
|
}
|
|
}
|
|
|
|
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
|
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
|
if (!entry.is_regular_file()) {
|
|
continue;
|
|
}
|
|
|
|
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
|
if (ec) {
|
|
ec.clear();
|
|
continue;
|
|
}
|
|
std::replace(rel.begin(), rel.end(), '\\', '/');
|
|
|
|
if (!glob_match(pattern, rel)) {
|
|
continue;
|
|
}
|
|
|
|
if (!add_text_file((rel_path / rel).string())) {
|
|
continue;
|
|
}
|
|
|
|
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
|
view::show_error(string_format("Maximum number of globbed files allowed (%zu) reached.", FILE_GLOB_MAX_RESULTS));
|
|
break;
|
|
}
|
|
}
|
|
continue;
|
|
} else {
|
|
// not a command
|
|
cur_msg += buffer;
|
|
}
|
|
|
|
// generate response
|
|
if (add_user_msg) {
|
|
push_user_message(cur_msg);
|
|
cur_msg.clear();
|
|
}
|
|
cli_timings timings;
|
|
std::string assistant_content;
|
|
generate_completion(assistant_content, timings);
|
|
messages.push_back({
|
|
{"role", "assistant"},
|
|
{"content", assistant_content}
|
|
});
|
|
|
|
if (params.show_timings) {
|
|
view::show_info(string_format(
|
|
"\n[ Prompt: %.1f t/s | Generation: %.1f t/s ]",
|
|
timings.prompt_per_second,
|
|
timings.predicted_per_second
|
|
));
|
|
}
|
|
|
|
if (params.single_turn) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
view::show_message("\n\nExiting...");
|
|
|
|
return 0;
|
|
}
|
|
|
|
void cli_context::shutdown() {
|
|
if (server) {
|
|
server->stop();
|
|
server.reset();
|
|
}
|
|
}
|