Files
llama.cpp/tools/cli/cli-context.cpp
T
2026-06-23 16:47:30 +02:00

560 lines
18 KiB
C++

#include "cli-context.h"
#include "cli-view.h"
#include "arg.h"
#include "base64.hpp"
#include "log.h"
#include "console.h"
#include <algorithm>
#include <cctype>
#include <filesystem>
#include <fstream>
#include <map>
#include <set>
#include <utility>
// Interruption flag. It is polled through should_stop() and reset to false by
// generate_completion() and run() once they have reacted to it.
// NOTE(review): presumably raised by a Ctrl+C handler defined elsewhere - confirm.
std::atomic<bool> g_cli_interrupted{false};

// Tells whether an interruption is currently pending.
static bool should_stop() {
    const bool interrupted = g_cli_interrupted.load();
    return interrupted;
}
// upper bound on the number of files a single /glob command may load
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
// logo text placed at the top of the startup banner built in cli_context::run()
const char * LLAMA_ASCII_LOGO = R"(
)";
// how many values follow this option on the command line:
// 2 when a second value hint is set, 1 when only the first hint is set, else 0
static int arg_num_values(const common_arg & opt) {
    const bool takes_two = opt.value_hint_2 != nullptr;
    const bool takes_one = opt.value_hint   != nullptr;
    return takes_two ? 2 : (takes_one ? 1 : 0);
}
static std::string format_error_message(const json & err) {
if (err.contains("error") && err.at("error").is_object()) {
const auto & e = err.at("error");
if (e.contains("message") && e.at("message").is_string()) {
return e.at("message").get<std::string>();
}
}
return err.dump();
}
// Classifies a media file by its extension, ignoring case.
// Returns "audio" for .wav/.mp3, "video" for .mp4/.avi/.mkv/.mov/.webm and
// "image" for everything else, including names without an extension.
static std::string media_type_from_ext(const std::string & fname) {
    std::string ext = std::filesystem::path(fname).extension().string();
    // convert through unsigned char: handing a negative char (any non-ASCII
    // byte) to tolower is undefined behavior
    std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
        return static_cast<char>(std::tolower(c));
    });

    if (ext == ".wav" || ext == ".mp3") {
        return "audio";
    }
    if (ext == ".mp4" || ext == ".avi" || ext == ".mkv" || ext == ".mov" || ext == ".webm") {
        return "video";
    }
    return "image";
}
// Prepares the session. Initializes the view, then either attaches to the
// server named by params.server_base or starts a local one, waits for the
// health check to pass and finally fetches the server properties.
// Returns false when any step fails or the user interrupts the startup.
bool cli_context::init() {
    view::init(params);

    std::optional<view::spinner> spinner;
    const bool use_external_server = !params.server_base.empty();

    if (!use_external_server) {
        // a local server needs at least one model source
        const bool has_model_source =
            !params.model.path.empty() ||
            !params.model.url.empty() ||
            !params.model.hf_repo.empty() ||
            !params.model.docker_repo.empty();
        if (!has_model_source) {
            view::show_error(
                "no model specified",
                "use -m <file.gguf> or -hf <user/repo> to run a local model,\n"
                "or --server-base <url> to connect to a running llama-server"
            );
            return false;
        }

        spinner.emplace("\n\nLoading model...");
        server.emplace();
        if (!server->start(params)) {
            view::show_error("server start failed");
            return false;
        }
        if (!server->wait_ready(should_stop)) {
            // stay quiet when the user asked to stop
            if (!should_stop()) {
                view::show_error("the server exited before becoming ready");
            }
            return false;
        }
        client.server_base = server->address();
    } else {
        // drop every trailing '/' from the configured base address
        std::string base = params.server_base;
        base.erase(base.find_last_not_of('/') + 1);
        client.server_base = base;
        spinner.emplace("Connecting to server at " + base);
    }

    // for --server-base this is the main availability check; for a spawned
    // server it is a cheap sanity check on top of the ready signal
    auto is_aborted = [this]() {
        if (should_stop()) {
            return true;
        }
        return server && !server->alive();
    };

    bool healthy = false;
    try {
        healthy = client.wait_health(is_aborted);
    } catch (const std::exception & e) {
        client.last_error = e.what();
    }
    if (!healthy) {
        if (!should_stop()) {
            view::show_error(client.last_error);
        }
        return false;
    }

    if (use_external_server) {
        // the spinner is removed while the model list is shown
        spinner.reset();
        if (!list_and_ask_models()) {
            return false;
        }
        // restore the spinner for the next step
        spinner.emplace("Waiting for server...");
    }

    fetch_server_props();
    return true;
}
void cli_context::fetch_server_props() {
try {
json props = client.get_props();
model_name = props.value("model_alias", "");
if (model_name.empty()) {
const std::string path = props.value("model_path", "");
if (!path.empty()) {
model_name = std::filesystem::path(path).filename().string();
}
}
build_info = props.value("build_info", "");
if (props.contains("modalities") && props.at("modalities").is_object()) {
const auto & modalities = props.at("modalities");
has_vision = modalities.value("vision", false);
has_audio = modalities.value("audio", false);
has_video = modalities.value("video", false);
}
} catch (const std::exception & e) {
// /props can be disabled on remote servers; not fatal
LOG_DBG("failed to fetch /props: %s\n", e.what());
}
}
// Picks the model to talk to on an external server.
//
// A single reported model is selected automatically. With several models the
// list is printed and the user is asked for a 1-based index until a valid one
// is entered. The choice is stored in both model_name and client.model.
//
// Returns false when the server reports no models at all (no selection could
// ever be valid) or when the user interrupts the prompt; true otherwise.
bool cli_context::list_and_ask_models() {
    const auto models = client.list_models();

    // nothing to choose from: fail instead of prompting forever
    if (models.empty()) {
        view::show_error("the server did not report any available models");
        return false;
    }

    // only one model: use it without asking
    if (models.size() == 1) {
        model_name = models[0];
        client.model = model_name;
        return true;
    }

    std::string message = "\nAvailable models:";
    for (size_t i = 0; i < models.size(); ++i) {
        message += "\n " + std::to_string(i + 1) + ". " + models[i];
    }
    message += "\n";
    view::show_message(message);

    while (true) {
        if (should_stop()) {
            return false;
        }

        view::user_turn user_turn;
        const std::string selection = user_turn.read_input(false, "Select model by number: ");
        if (selection.empty()) {
            continue;
        }

        // 0 is never a valid choice, so it doubles as the "could not parse" marker
        size_t idx = 0;
        try {
            idx = std::stoul(selection);
        } catch (const std::exception &) {
            idx = 0;
        }

        if (idx > 0 && idx <= models.size()) {
            model_name = models[idx - 1];
            client.model = model_name;
            view::show_message("Selected model: " + model_name);
            return true;
        }

        view::show_error("Invalid selection. Please enter a valid number.");
    }
}
// Appends the configured system prompt to the conversation as a "system"
// message; does nothing when no system prompt was given.
void cli_context::add_system_prompt() {
    if (params.system_prompt.empty()) {
        return;
    }

    messages.push_back({
        {"role", "system"},
        {"content", params.system_prompt}
    });
}
// Appends a "user" message to the conversation.
//
// Without staged media the content is the plain text. With staged media the
// content is the list of media parts followed by one text part, and the
// staging area is emptied afterwards.
//
// The staged parts hold base64 payloads that can be large, so they are moved
// into the message rather than copied.
void cli_context::push_user_message(const std::string & text) {
    json content;
    if (pending_media.empty()) {
        content = text;
    } else {
        // multimodal message: media parts first, then the text
        content = std::move(pending_media);
        content.push_back({
            {"type", "text"},
            {"text", text}
        });
        // leave the staging area as a fresh, empty array
        pending_media = json::array();
    }

    messages.push_back({
        {"role", "user"},
        {"content", std::move(content)}
    });
}
// Reads a media file, base64-encodes it and stages it in pending_media so it
// is attached to the next user message.
//
// type selects the part layout: "audio" produces an input_audio part (format
// "mp3" for a .mp3 extension, "wav" otherwise), "video" produces an
// input_video part, and anything else produces an image_url part carrying a
// data URL.
//
// Returns false when the file cannot be opened, true otherwise.
bool cli_context::stage_media_file(const std::string & fname, const std::string & type) {
    std::ifstream file(fname, std::ios::binary);
    if (!file) {
        return false;
    }

    const std::string data((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
    std::string encoded = base64::encode(data);

    if (type == "audio") {
        std::string ext = std::filesystem::path(fname).extension().string();
        // convert through unsigned char: handing a negative char (any
        // non-ASCII byte) to tolower is undefined behavior
        std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
            return static_cast<char>(std::tolower(c));
        });
        pending_media.push_back({
            {"type", "input_audio"},
            {"input_audio", {
                // the payload can be large: move it instead of copying
                {"data", std::move(encoded)},
                {"format", ext == ".mp3" ? "mp3" : "wav"}
            }}
        });
    } else if (type == "video") {
        pending_media.push_back({
            {"type", "input_video"},
            {"input_video", {
                {"data", std::move(encoded)}
            }}
        });
    } else {
        // the server detects the actual image type from the data
        pending_media.push_back({
            {"type", "image_url"},
            {"image_url", {
                {"url", "data:image/unknown;base64," + encoded}
            }}
        });
    }
    return true;
}
bool cli_context::generate_completion(std::string & assistant_content, cli_timings & timings) {
json body = {
{"messages", messages},
{"stream", true},
// in order to get timings even when we cancel mid-way
{"timings_per_token", true},
};
bool stream_error = false;
view::assistant_turn a;
json err = client.create_chat_completion(body, should_stop, [&](const json & chunk) {
if (chunk.contains("error")) {
stream_error = true;
view::show_error(format_error_message(chunk));
return;
}
if (chunk.contains("timings")) {
const auto & t = chunk.at("timings");
timings.prompt_per_second = t.value("prompt_per_second", 0.0);
timings.predicted_per_second = t.value("predicted_per_second", 0.0);
}
if (!chunk.contains("choices") || !chunk.at("choices").is_array() || chunk.at("choices").empty()) {
return;
}
const auto & choice = chunk.at("choices").at(0);
if (!choice.contains("delta")) {
return;
}
const auto & delta = choice.at("delta");
if (delta.contains("reasoning_content") && delta.at("reasoning_content").is_string()) {
const std::string text = delta.at("reasoning_content").get<std::string>();
if (!text.empty()) {
a.push(view::ASSISTANT_DISPLAY_MODE_REASONING, text);
}
}
if (delta.contains("content") && delta.at("content").is_string()) {
const std::string text = delta.at("content").get<std::string>();
if (!text.empty()) {
assistant_content += text;
a.push(view::ASSISTANT_DISPLAY_MODE_CONTENT, text);
}
}
});
g_cli_interrupted.store(false);
if (!err.is_null()) {
view::show_error(format_error_message(err));
return false;
}
return !stream_error;
}
int cli_context::run() {
add_system_prompt();
std::string modalities = "text";
if (has_vision) {
modalities += ", vision";
}
if (has_audio) {
modalities += ", audio";
}
if (has_video) {
modalities += ", video";
}
std::string banner;
banner += "\n";
banner += LLAMA_ASCII_LOGO;
banner += "\n";
banner += "build : " + build_info + "\n";
banner += "model : " + model_name + "\n";
banner += "modalities : " + modalities + "\n";
if (!params.system_prompt.empty()) {
banner += "using custom system prompt\n";
}
banner += "\n";
banner += "available commands:\n";
banner += " /exit or Ctrl+C stop or exit\n";
banner += " /regen regenerate the last response\n";
banner += " /clear clear the chat history\n";
banner += " /read <file> add a text file\n";
banner += " /glob <pattern> add text files using globbing pattern\n";
if (has_vision) {
banner += " /image <file> add an image file\n";
}
if (has_audio) {
banner += " /audio <file> add an audio file\n";
}
if (has_video) {
banner += " /video <file> add a video file\n";
}
banner += "\n";
view::show_message(banner);
// interactive loop
std::string cur_msg;
auto add_text_file = [&](const std::string & fname) -> bool {
std::ifstream file(fname, std::ios::binary);
if (!file) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
return false;
}
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
cur_msg += content;
view::show_message(string_format("Loaded text from '%s'", fname.c_str()));
return true;
};
while (true) {
std::string buffer;
{
view::user_turn user_turn;
if (params.prompt.empty()) {
buffer = user_turn.read_input(params.multiline_input);
} else {
// process input prompt from args
for (auto & fname : params.image) {
if (!stage_media_file(fname, media_type_from_ext(fname))) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
break;
}
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
}
buffer = params.prompt;
user_turn.echo(buffer);
params.prompt.clear(); // only use it once
}
}
if (should_stop()) {
g_cli_interrupted.store(false);
break;
}
// remove trailing newline
if (!buffer.empty() && buffer.back() == '\n') {
buffer.pop_back();
}
// skip empty messages
if (buffer.empty()) {
continue;
}
bool add_user_msg = true;
// process commands
if (string_starts_with(buffer, "/exit")) {
break;
} else if (string_starts_with(buffer, "/regen")) {
if (messages.size() >= 2) {
size_t last_idx = messages.size() - 1;
messages.erase(last_idx);
add_user_msg = false;
} else {
view::show_error("No message to regenerate.");
continue;
}
} else if (string_starts_with(buffer, "/clear")) {
messages.clear();
add_system_prompt();
pending_media = json::array();
view::show_message("Chat history cleared.");
continue;
} else if (
(string_starts_with(buffer, "/image ") && has_vision) ||
(string_starts_with(buffer, "/audio ") && has_audio) ||
(string_starts_with(buffer, "/video ") && has_video)) {
std::string type = buffer.substr(1, 5);
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
std::string fname = string_strip(buffer.substr(7));
if (!stage_media_file(fname, type)) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
continue;
}
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
continue;
} else if (string_starts_with(buffer, "/read ")) {
std::string fname = string_strip(buffer.substr(6));
add_text_file(fname);
continue;
} else if (string_starts_with(buffer, "/glob ")) {
std::error_code ec;
size_t count = 0;
auto curdir = std::filesystem::current_path();
std::string pattern = string_strip(buffer.substr(6));
std::filesystem::path rel_path;
auto startglob = pattern.find_first_of("![*?");
if (startglob != std::string::npos && startglob != 0) {
auto endpath = pattern.substr(0, startglob).find_last_of('/');
if (endpath != std::string::npos) {
std::string rel_pattern = pattern.substr(0, endpath);
#if !defined(_WIN32)
if (string_starts_with(rel_pattern, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
rel_pattern = home + rel_pattern.substr(1);
}
}
#endif
rel_path = rel_pattern;
pattern.erase(0, endpath + 1);
curdir /= rel_path;
}
}
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
std::filesystem::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) {
continue;
}
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
if (ec) {
ec.clear();
continue;
}
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(pattern, rel)) {
continue;
}
if (!add_text_file((rel_path / rel).string())) {
continue;
}
if (++count >= FILE_GLOB_MAX_RESULTS) {
view::show_error(string_format("Maximum number of globbed files allowed (%zu) reached.", FILE_GLOB_MAX_RESULTS));
break;
}
}
continue;
} else {
// not a command
cur_msg += buffer;
}
// generate response
if (add_user_msg) {
push_user_message(cur_msg);
cur_msg.clear();
}
cli_timings timings;
std::string assistant_content;
generate_completion(assistant_content, timings);
messages.push_back({
{"role", "assistant"},
{"content", assistant_content}
});
if (params.show_timings) {
view::show_info(string_format(
"\n[ Prompt: %.1f t/s | Generation: %.1f t/s ]",
timings.prompt_per_second,
timings.predicted_per_second
));
}
if (params.single_turn) {
break;
}
}
view::show_message("\n\nExiting...");
return 0;
}
// Stops the locally started server, if there is one, and releases it.
// Nothing happens when no local server is held.
void cli_context::shutdown() {
    if (!server) {
        return;
    }
    server->stop();
    server.reset();
}