Compare commits

...

7 Commits

Author SHA1 Message Date
Xuan Son Nguyen b093e46873 case: router with only one model 2026-06-23 16:47:30 +02:00
Xuan Son Nguyen 1401fc3ca7 cli support router mode
Co-authored-by: Piotr Wilkin <ilintar@gmail.com>
2026-06-23 16:43:58 +02:00
Xuan Son Nguyen 85c58bbcd0 remote server ok 2026-06-23 16:19:28 +02:00
Xuan Son Nguyen 19296c1735 working 2026-06-23 16:09:09 +02:00
Xuan Son Nguyen 90c111bf98 Merge branch 'master' into xsn/cli_http_based 2026-06-23 13:29:22 +02:00
Xuan Son Nguyen f7421eabe8 wip 2026-06-23 13:28:14 +02:00
Xuan Son Nguyen 59797670dc cli: move to HTTP-based implementation 2026-06-23 13:14:28 +02:00
13 changed files with 1313 additions and 712 deletions
+9 -3
View File
@@ -603,9 +603,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty()
&& !params.usage
&& !params.completion) {
bool can_skip_model = params.usage || params.completion || !params.server_base.empty();
if (!can_skip_model && params.model.path.empty()) {
throw std::invalid_argument("error: --model is required\n");
}
}
@@ -1119,6 +1118,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.completion = true;
}
));
add_opt(common_arg(
{"--server-base"}, "URL",
string_format("connect to this server instead of starting a new one, example: 'http://localhost:8080' (default: none)"),
[](common_params & params, const std::string & value) {
params.server_base = value;
}
).set_examples({LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--verbose-prompt"},
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
+3
View File
@@ -631,6 +631,9 @@ struct common_params {
std::map<std::string, std::string> default_template_kwargs;
// CLI params
std::string server_base; // if set, connect to this server instead of starting a new one
// UI configs
bool ui = true;
bool ui_mcp_proxy = false;
+70
View File
@@ -2,6 +2,16 @@
#include <cpp-httplib/httplib.h>
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#endif
struct common_http_url {
std::string scheme;
std::string user;
@@ -97,3 +107,63 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
static std::string common_http_show_masked_url(const common_http_url & parts) {
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
}
static int common_http_get_free_port() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
return -1;
}
typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
struct sockaddr_in serv_addr;
std::memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
serv_addr.sin_port = htons(0);
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
#ifdef _WIN32
int namelen = sizeof(serv_addr);
#else
socklen_t namelen = sizeof(serv_addr);
#endif
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
int port = ntohs(serv_addr.sin_port);
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return port;
}
+4 -2
View File
@@ -2,11 +2,13 @@
set(TARGET llama-cli-impl)
add_library(${TARGET} cli.cpp)
add_library(${TARGET} cli.cpp
cli-client.cpp
cli-context.cpp)
set_target_properties(${TARGET} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
target_include_directories(${TARGET} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ../server)
target_link_libraries(${TARGET} PUBLIC server-context llama-common ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PUBLIC llama-server-impl llama-common ${CMAKE_THREAD_LIBS_INIT})
if(LLAMA_TOOLS_INSTALL)
install(TARGETS ${TARGET} LIBRARY)
+164
View File
@@ -0,0 +1,164 @@
#include "cli-client.h"
#include "http.h"
#include <algorithm>
#include <chrono>
#include <thread>
// generation can stall for a long time during prompt processing, so the
// read timeout must be generous
static constexpr time_t CLI_HTTP_READ_TIMEOUT_SEC = 3600;
// upper bound for the accumulated response body kept for error reporting
static constexpr size_t CLI_HTTP_MAX_ERROR_BODY = 1024 * 1024;
// returns the path with the base url's path prefix prepended (if any)
static std::string join_path(const common_http_url & parts, const std::string & path) {
if (parts.path.empty() || parts.path == "/") {
return path;
}
std::string prefix = parts.path;
if (prefix.back() == '/') {
prefix.pop_back();
}
return prefix + path;
}
json cli_client::get(const std::string & path) {
auto [cli, parts] = common_http_client(server_base);
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
auto path_with_model = path + (model.empty() ? "" : ("?model=" + model));
auto res = cli.Get(join_path(parts, path_with_model));
if (!res) {
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
}
if (res->status < 200 || res->status >= 300) {
throw std::runtime_error("GET " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
}
json result = json::parse(res->body, nullptr, false);
if (result.is_discarded()) {
throw std::runtime_error("GET " + path + " returned invalid JSON");
}
return result;
}
json cli_client::post(const std::string & path, const json & body) {
auto [cli, parts] = common_http_client(server_base);
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
auto body_with_model = body;
if (!model.empty()) {
body_with_model["model"] = model;
}
auto res = cli.Post(join_path(parts, path), body_with_model.dump(), "application/json");
if (!res) {
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
}
if (res->status < 200 || res->status >= 300) {
throw std::runtime_error("POST " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
}
json result = json::parse(res->body, nullptr, false);
if (result.is_discarded()) {
throw std::runtime_error("POST " + path + " returned invalid JSON");
}
return result;
}
json cli_client::post_sse(const std::string & path,
const json & body,
const std::function<bool()> & should_stop,
const std::function<void(const json &)> & on_data) {
auto [cli, parts] = common_http_client(server_base);
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
std::string pending; // buffer for incomplete SSE lines
std::string raw_body; // accumulated body, used only for error reporting
auto receiver = [&](const char * data, size_t len) -> bool {
if (should_stop()) {
return false; // aborts the request
}
if (raw_body.size() < CLI_HTTP_MAX_ERROR_BODY) {
raw_body.append(data, std::min(len, CLI_HTTP_MAX_ERROR_BODY - raw_body.size()));
}
pending.append(data, len);
size_t pos;
while ((pos = pending.find('\n')) != std::string::npos) {
std::string line = pending.substr(0, pos);
pending.erase(0, pos + 1);
if (!line.empty() && line.back() == '\r') {
line.pop_back();
}
if (line.rfind("data: ", 0) != 0) {
continue;
}
std::string payload = line.substr(6);
if (payload == "[DONE]") {
continue;
}
json event = json::parse(payload, nullptr, false);
if (!event.is_discarded()) {
on_data(event);
}
}
return true;
};
httplib::Headers headers = {{"Accept", "text/event-stream"}};
auto body_with_model = body;
if (!model.empty()) {
body_with_model["model"] = model;
}
auto res = cli.Post(join_path(parts, path), headers, body_with_model.dump(), "application/json", receiver);
if (!res) {
if (res.error() == httplib::Error::Canceled && should_stop()) {
return json(); // cancelled by the user
}
return json {{"error", {{"message", "failed to connect to " + server_base + ": " + httplib::to_string(res.error())}}}};
}
if (res->status < 200 || res->status >= 300) {
json error_body = json::parse(raw_body, nullptr, false);
if (!error_body.is_discarded() && error_body.contains("error")) {
return error_body;
}
return json {{"error", {{"message", "request failed with status " + std::to_string(res->status)}}}};
}
return json();
}
bool cli_client::wait_health(const std::function<bool()> & is_aborted) {
int connect_attempts = 0;
while (!is_aborted()) {
auto [cli, parts] = common_http_client(server_base);
cli.set_connection_timeout(1, 0);
auto res = cli.Get(join_path(parts, "/health"));
if (res) {
if (res->status == 200) {
return true;
}
// any other status means the server is up but not ready yet
// (e.g. 503 while the model is still loading)
} else if (++connect_attempts >= 10) {
last_error = "failed to connect to " + server_base + ": " + httplib::to_string(res.error());
return false;
}
std::this_thread::sleep_for(std::chrono::milliseconds(300));
}
last_error = "aborted while waiting for the server to become ready";
return false;
}
std::vector<std::string> cli_client::list_models() {
json resp = get("/v1/models");
if (!resp.contains("data") || !resp.at("data").is_array()) {
throw std::runtime_error("invalid response from /v1/models");
}
std::vector<std::string> models;
for (const auto & m : resp.at("data")) {
if (m.contains("id") && m.at("id").is_string()) {
models.push_back(m.at("id").get<std::string>());
}
}
return models;
}
+56
View File
@@ -0,0 +1,56 @@
#pragma once
#include "ggml.h"
#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
#include <functional>
#include <string>
using json = nlohmann::ordered_json;
// openai-like client for CLI
struct cli_client {
std::string server_base; // base url, for example "http://127.0.0.1:8080"
std::string last_error; // set when wait_health() fails
std::string model; // optional, set when the server has multiple models (router mode)
// simple GET request, returns the response json
// throws std::runtime_error on transport error or non-2xx status
json get(const std::string & path);
// simple POST request, returns the response json
// throws std::runtime_error on transport error or non-2xx status
json post(const std::string & path, const json & body);
// POST request with an SSE streaming response; on_data is invoked once
// per "data:" event; the function returns after the stream is finished:
// a null json on graceful exit (incl. cancellation via should_stop),
// the error response json otherwise
json post_sse(const std::string & path,
const json & body,
const std::function<bool()> & should_stop,
const std::function<void(const json &)> & on_data);
// poll /health until the server is ready to accept requests
// returns false if is_aborted returned true or the server is unreachable
bool wait_health(const std::function<bool()> & is_aborted);
//
// higher-level wrappers
//
json create_chat_completion(const json & request,
const std::function<bool()> & should_stop,
const std::function<void(const json &)> & on_data) {
return post_sse("/v1/chat/completions", request, should_stop, on_data);
}
json get_props() {
return get("/props");
}
std::vector<std::string> list_models();
};
+559
View File
@@ -0,0 +1,559 @@
#include "cli-context.h"
#include "cli-view.h"
#include "arg.h"
#include "base64.hpp"
#include "log.h"
#include "console.h"
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <map>
#include <set>
std::atomic<bool> g_cli_interrupted = false;
static bool should_stop() {
return g_cli_interrupted.load();
}
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
const char * LLAMA_ASCII_LOGO = R"(
)";
// number of values an arg consumes on the command line
static int arg_num_values(const common_arg & opt) {
if (opt.value_hint_2 != nullptr) {
return 2;
}
if (opt.value_hint != nullptr) {
return 1;
}
return 0;
}
static std::string format_error_message(const json & err) {
if (err.contains("error") && err.at("error").is_object()) {
const auto & e = err.at("error");
if (e.contains("message") && e.at("message").is_string()) {
return e.at("message").get<std::string>();
}
}
return err.dump();
}
static std::string media_type_from_ext(const std::string & fname) {
std::string ext = std::filesystem::path(fname).extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
if (ext == ".wav" || ext == ".mp3") {
return "audio";
}
if (ext == ".mp4" || ext == ".avi" || ext == ".mkv" || ext == ".mov" || ext == ".webm") {
return "video";
}
return "image";
}
bool cli_context::init() {
view::init(params);
std::optional<view::spinner> spinner;
bool use_external_server = !params.server_base.empty();
if (use_external_server) {
std::string base = params.server_base;
while (!base.empty() && base.back() == '/') {
base.pop_back();
}
client.server_base = base;
spinner.emplace("Connecting to server at " + base);
} else {
if (params.model.path.empty() && params.model.url.empty() &&
params.model.hf_repo.empty() && params.model.docker_repo.empty()) {
view::show_error(
"no model specified",
"use -m <file.gguf> or -hf <user/repo> to run a local model,\n"
"or --server-base <url> to connect to a running llama-server"
);
return false;
}
spinner.emplace("\n\nLoading model...");
server.emplace();
if (!server->start(params)) {
view::show_error("server start failed");
return false;
}
if (!server->wait_ready(should_stop)) {
if (!should_stop()) {
view::show_error("the server exited before becoming ready");
}
return false;
}
client.server_base = server->address();
}
// for --server-base this is the main availability check; for a spawned
// server it is a cheap sanity check on top of the ready signal
auto is_aborted = [this]() {
return should_stop() || (server && !server->alive());
};
bool healthy = false;
try {
healthy = client.wait_health(is_aborted);
} catch (const std::exception & e) {
client.last_error = e.what();
}
if (!healthy) {
if (!should_stop()) {
view::show_error(client.last_error);
}
return false;
}
if (use_external_server) {
spinner.reset();
if (!list_and_ask_models()) {
return false;
}
// restore the spinner for the next step
spinner.emplace("Waiting for server...");
}
fetch_server_props();
return true;
}
void cli_context::fetch_server_props() {
try {
json props = client.get_props();
model_name = props.value("model_alias", "");
if (model_name.empty()) {
const std::string path = props.value("model_path", "");
if (!path.empty()) {
model_name = std::filesystem::path(path).filename().string();
}
}
build_info = props.value("build_info", "");
if (props.contains("modalities") && props.at("modalities").is_object()) {
const auto & modalities = props.at("modalities");
has_vision = modalities.value("vision", false);
has_audio = modalities.value("audio", false);
has_video = modalities.value("video", false);
}
} catch (const std::exception & e) {
// /props can be disabled on remote servers; not fatal
LOG_DBG("failed to fetch /props: %s\n", e.what());
}
}
bool cli_context::list_and_ask_models() {
auto models = client.list_models();
// only one model: use it without asking
if (models.size() == 1) {
model_name = models[0];
client.model = model_name;
return true;
}
std::string message = "\nAvailable models:";
if (!models.empty()) {
for (size_t i = 0; i < models.size(); ++i) {
message += "\n " + std::to_string(i + 1) + ". " + models[i];
}
}
message += "\n";
view::show_message(message);
std::string selection;
while (selection.empty()) {
if (should_stop()) {
return false;
}
view::user_turn user_turn;
selection = user_turn.read_input(false, "Select model by number: ");
if (selection.empty()) {
continue;
}
try {
size_t idx = std::stoul(selection);
if (idx > 0 && idx <= models.size()) {
model_name = models[idx - 1];
client.model = model_name;
view::show_message("Selected model: " + model_name);
break;
}
} catch (...) {
// ignore
}
view::show_error("Invalid selection. Please enter a valid number.");
selection.clear();
continue;
}
return true;
}
void cli_context::add_system_prompt() {
if (!params.system_prompt.empty()) {
messages.push_back({
{"role", "system"},
{"content", params.system_prompt}
});
}
}
void cli_context::push_user_message(const std::string & text) {
json content;
if (pending_media.empty()) {
content = text;
} else {
// multimodal message: media parts first, then the text
content = pending_media;
content.push_back({
{"type", "text"},
{"text", text}
});
pending_media = json::array();
}
messages.push_back({
{"role", "user"},
{"content", content}
});
}
bool cli_context::stage_media_file(const std::string & fname, const std::string & type) {
std::ifstream file(fname, std::ios::binary);
if (!file) {
return false;
}
std::string data((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
std::string encoded = base64::encode(data);
if (type == "audio") {
std::string ext = std::filesystem::path(fname).extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
pending_media.push_back({
{"type", "input_audio"},
{"input_audio", {
{"data", encoded},
{"format", ext == ".mp3" ? "mp3" : "wav"}
}}
});
} else if (type == "video") {
pending_media.push_back({
{"type", "input_video"},
{"input_video", {
{"data", encoded}
}}
});
} else {
// the server detects the actual image type from the data
pending_media.push_back({
{"type", "image_url"},
{"image_url", {
{"url", "data:image/unknown;base64," + encoded}
}}
});
}
return true;
}
bool cli_context::generate_completion(std::string & assistant_content, cli_timings & timings) {
json body = {
{"messages", messages},
{"stream", true},
// in order to get timings even when we cancel mid-way
{"timings_per_token", true},
};
bool stream_error = false;
view::assistant_turn a;
json err = client.create_chat_completion(body, should_stop, [&](const json & chunk) {
if (chunk.contains("error")) {
stream_error = true;
view::show_error(format_error_message(chunk));
return;
}
if (chunk.contains("timings")) {
const auto & t = chunk.at("timings");
timings.prompt_per_second = t.value("prompt_per_second", 0.0);
timings.predicted_per_second = t.value("predicted_per_second", 0.0);
}
if (!chunk.contains("choices") || !chunk.at("choices").is_array() || chunk.at("choices").empty()) {
return;
}
const auto & choice = chunk.at("choices").at(0);
if (!choice.contains("delta")) {
return;
}
const auto & delta = choice.at("delta");
if (delta.contains("reasoning_content") && delta.at("reasoning_content").is_string()) {
const std::string text = delta.at("reasoning_content").get<std::string>();
if (!text.empty()) {
a.push(view::ASSISTANT_DISPLAY_MODE_REASONING, text);
}
}
if (delta.contains("content") && delta.at("content").is_string()) {
const std::string text = delta.at("content").get<std::string>();
if (!text.empty()) {
assistant_content += text;
a.push(view::ASSISTANT_DISPLAY_MODE_CONTENT, text);
}
}
});
g_cli_interrupted.store(false);
if (!err.is_null()) {
view::show_error(format_error_message(err));
return false;
}
return !stream_error;
}
int cli_context::run() {
add_system_prompt();
std::string modalities = "text";
if (has_vision) {
modalities += ", vision";
}
if (has_audio) {
modalities += ", audio";
}
if (has_video) {
modalities += ", video";
}
std::string banner;
banner += "\n";
banner += LLAMA_ASCII_LOGO;
banner += "\n";
banner += "build : " + build_info + "\n";
banner += "model : " + model_name + "\n";
banner += "modalities : " + modalities + "\n";
if (!params.system_prompt.empty()) {
banner += "using custom system prompt\n";
}
banner += "\n";
banner += "available commands:\n";
banner += " /exit or Ctrl+C stop or exit\n";
banner += " /regen regenerate the last response\n";
banner += " /clear clear the chat history\n";
banner += " /read <file> add a text file\n";
banner += " /glob <pattern> add text files using globbing pattern\n";
if (has_vision) {
banner += " /image <file> add an image file\n";
}
if (has_audio) {
banner += " /audio <file> add an audio file\n";
}
if (has_video) {
banner += " /video <file> add a video file\n";
}
banner += "\n";
view::show_message(banner);
// interactive loop
std::string cur_msg;
auto add_text_file = [&](const std::string & fname) -> bool {
std::ifstream file(fname, std::ios::binary);
if (!file) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
return false;
}
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
cur_msg += content;
view::show_message(string_format("Loaded text from '%s'", fname.c_str()));
return true;
};
while (true) {
std::string buffer;
{
view::user_turn user_turn;
if (params.prompt.empty()) {
buffer = user_turn.read_input(params.multiline_input);
} else {
// process input prompt from args
for (auto & fname : params.image) {
if (!stage_media_file(fname, media_type_from_ext(fname))) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
break;
}
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
}
buffer = params.prompt;
user_turn.echo(buffer);
params.prompt.clear(); // only use it once
}
}
if (should_stop()) {
g_cli_interrupted.store(false);
break;
}
// remove trailing newline
if (!buffer.empty() && buffer.back() == '\n') {
buffer.pop_back();
}
// skip empty messages
if (buffer.empty()) {
continue;
}
bool add_user_msg = true;
// process commands
if (string_starts_with(buffer, "/exit")) {
break;
} else if (string_starts_with(buffer, "/regen")) {
if (messages.size() >= 2) {
size_t last_idx = messages.size() - 1;
messages.erase(last_idx);
add_user_msg = false;
} else {
view::show_error("No message to regenerate.");
continue;
}
} else if (string_starts_with(buffer, "/clear")) {
messages.clear();
add_system_prompt();
pending_media = json::array();
view::show_message("Chat history cleared.");
continue;
} else if (
(string_starts_with(buffer, "/image ") && has_vision) ||
(string_starts_with(buffer, "/audio ") && has_audio) ||
(string_starts_with(buffer, "/video ") && has_video)) {
std::string type = buffer.substr(1, 5);
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
std::string fname = string_strip(buffer.substr(7));
if (!stage_media_file(fname, type)) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
continue;
}
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
continue;
} else if (string_starts_with(buffer, "/read ")) {
std::string fname = string_strip(buffer.substr(6));
add_text_file(fname);
continue;
} else if (string_starts_with(buffer, "/glob ")) {
std::error_code ec;
size_t count = 0;
auto curdir = std::filesystem::current_path();
std::string pattern = string_strip(buffer.substr(6));
std::filesystem::path rel_path;
auto startglob = pattern.find_first_of("![*?");
if (startglob != std::string::npos && startglob != 0) {
auto endpath = pattern.substr(0, startglob).find_last_of('/');
if (endpath != std::string::npos) {
std::string rel_pattern = pattern.substr(0, endpath);
#if !defined(_WIN32)
if (string_starts_with(rel_pattern, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
rel_pattern = home + rel_pattern.substr(1);
}
}
#endif
rel_path = rel_pattern;
pattern.erase(0, endpath + 1);
curdir /= rel_path;
}
}
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
std::filesystem::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) {
continue;
}
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
if (ec) {
ec.clear();
continue;
}
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(pattern, rel)) {
continue;
}
if (!add_text_file((rel_path / rel).string())) {
continue;
}
if (++count >= FILE_GLOB_MAX_RESULTS) {
view::show_error(string_format("Maximum number of globbed files allowed (%zu) reached.", FILE_GLOB_MAX_RESULTS));
break;
}
}
continue;
} else {
// not a command
cur_msg += buffer;
}
// generate response
if (add_user_msg) {
push_user_message(cur_msg);
cur_msg.clear();
}
cli_timings timings;
std::string assistant_content;
generate_completion(assistant_content, timings);
messages.push_back({
{"role", "assistant"},
{"content", assistant_content}
});
if (params.show_timings) {
view::show_info(string_format(
"\n[ Prompt: %.1f t/s | Generation: %.1f t/s ]",
timings.prompt_per_second,
timings.predicted_per_second
));
}
if (params.single_turn) {
break;
}
}
view::show_message("\n\nExiting...");
return 0;
}
void cli_context::shutdown() {
if (server) {
server->stop();
server.reset();
}
}
+62
View File
@@ -0,0 +1,62 @@
#pragma once
#include "common.h"
#include "cli-client.h"
#include "cli-server.h"
#include <atomic>
#include <optional>
#include <string>
struct cli_timings {
double prompt_per_second = 0.0;
double predicted_per_second = 0.0;
};
// set by the SIGINT handler; cleared once the interrupt has been handled
extern std::atomic<bool> g_cli_interrupted;
struct cli_context {
common_params params;
cli_client client; // always initialized
std::optional<cli_server> server; // only set when no --server-base is given
json messages = json::array();
json pending_media = json::array(); // staged multimodal content parts
// properties of the connected server
// will be populated by fetch_server_props()
std::string model_name;
std::string build_info;
bool has_vision = false;
bool has_audio = false;
bool has_video = false;
cli_context(const common_params & params) : params(params) {}
// connect to --server-base or spawn a local llama-server child;
// argc/argv are needed to forward the server-relevant args to the child
bool init();
// run the interactive chat loop, returns the process exit code
int run();
// stop the local server child (if any)
void shutdown();
private:
bool generate_completion(std::string & assistant_content, cli_timings & timings);
void fetch_server_props();
void add_system_prompt();
void push_user_message(const std::string & text);
// check if server have multiple models (router mode)
// if yes, list them then ask; do nothing otherwise
bool list_and_ask_models();
// read a file and stage it as a multimodal content part; type is one of
// "image", "audio", "video"; returns false if the file cannot be read
bool stage_media_file(const std::string & fname, const std::string & type);
};
+85
View File
@@ -0,0 +1,85 @@
#pragma once
#include <thread>
#include "http.h"
// spawn llama-server in a thread and interact with it via a random port
// note: in the future, we may have a server running as daemon and the CLI can connect to it automatically
// llama_server will be available as a dynamic library symbol
int llama_server(common_params & params, int argc, char ** argv);
void llama_server_terminate();
struct cli_server {
std::thread th;
int port = -1;
std::atomic<bool> is_alive = false;
std::atomic<bool> is_stopping = false;
~cli_server() {
stop();
}
void stop() {
if (alive() && !is_stopping.exchange(true)) {
llama_server_terminate();
th.join();
}
}
bool start(common_params & params) {
port = common_http_get_free_port();
if (port <= 0) {
fprintf(stderr, "failed to get a free port\n");
exit(1);
}
is_alive.store(true, std::memory_order_release);
th = std::thread([&]() {
common_params server_params = params; // copy
server_params.port = port;
// argc / argv are only used in router mode, we can skip them for now
int res = llama_server(server_params, 0, nullptr);
if (res != 0) {
fprintf(stderr, "llama_server exited with code %d\n", res);
}
is_alive.store(false, std::memory_order_release);
});
return true;
}
std::string address() const {
return "http://127.0.0.1:" + std::to_string(port);
}
bool wait_ready(std::function<bool()> should_stop) {
if (!alive()) {
return false;
}
while (!should_stop()) {
auto [cli, parts] = common_http_client(address());
cli.set_connection_timeout(1, 0);
auto res = cli.Get("/health");
if (res) {
if (res->status == 200) {
return true;
}
// any other status means the server is up but not ready yet
// (e.g. 503 while the model is still loading)
}
if (!alive()) {
// in case server die permanently
return false;
}
std::this_thread::sleep_for(std::chrono::milliseconds(200));
}
return true;
}
bool alive() const {
return is_alive.load(std::memory_order_acquire);
}
};
+250
View File
@@ -0,0 +1,250 @@
#pragma once
#include "common.h"
#include "console.h"
#include <array>
#include <algorithm>
#include <filesystem>
#include <string_view>
// TODO?: Make this reusable, enums, docs
static const std::array<std::string_view, 8> cmds = {
"/audio ",
"/clear",
"/exit",
"/glob ",
"/image ",
"/read ",
"/regen",
"/video ",
};
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
std::vector<std::pair<std::string, size_t>> matches;
std::string cmd;
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return string_starts_with(line, prefix);
})) {
auto it = cmds.begin();
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
return string_starts_with(cmd_line, line);
})) != cmds.end()) {
matches.emplace_back(*it, it->length());
++it;
}
} else {
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return prefix.back() == ' ' && string_starts_with(line, prefix);
});
if (it != cmds.end()) {
cmd = *it;
}
}
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
auto cur_dir = std::filesystem::current_path();
std::string cur_dir_str = cur_dir.string();
std::string expanded_prefix = path_prefix;
#if !defined(_WIN32)
if (string_starts_with(path_prefix, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
expanded_prefix = home + path_prefix.substr(1);
}
}
if (string_starts_with(expanded_prefix, '/')) {
#else
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
#endif
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
cur_dir_str.clear();
} else if (!path_prefix.empty()) {
cur_dir /= std::filesystem::path(path_prefix).parent_path();
}
std::error_code ec;
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
if (ec) {
break;
}
if (!entry.exists(ec)) {
ec.clear();
continue;
}
const std::string path_full = entry.path().string();
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
if (entry.is_directory(ec)) {
path_entry.push_back(std::filesystem::path::preferred_separator);
}
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
const std::string updated_line = cmd + path_entry;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
if (ec) {
ec.clear();
}
}
if (matches.empty()) {
const std::string updated_line = cmd + path_prefix;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
// Add the longest common prefix
if (!expanded_prefix.empty() && matches.size() > 1) {
const std::string_view match0(matches[0].first);
const std::string_view match1(matches[1].first);
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
size_t len = it.first - match0.begin();
for (size_t i = 2; i < matches.size(); ++i) {
const std::string_view matchi(matches[i].first);
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
}
const std::string updated_line = std::string(match0.substr(0, len));
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
});
}
return matches;
}
// note: make this view implementation generic, so that we can move to TUI in the future if we want to
namespace view {
static void init(const common_params & params) {
// TODO: avoid using atexit() here by making `console` a singleton
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
console::set_completion_callback(auto_completion_callback);
}
struct spinner {
spinner(const std::string & message) {
if (!message.empty()) {
console::log("%s ", message.c_str());
}
console::spinner::start();
}
~spinner() {
console::spinner::stop();
}
};
struct user_turn {
user_turn() {
console::set_display(DISPLAY_TYPE_USER_INPUT);
}
~user_turn() {
console::set_display(DISPLAY_TYPE_RESET);
}
void echo(const std::string & buffer) {
if (buffer.size() > 500) {
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
} else {
console::log("\n> %s\n", buffer.c_str());
}
}
std::string read_input(bool multiline_input, const char * prompt = nullptr) {
if (prompt) {
console::log("%s", prompt);
} else {
console::log("\n> ");
}
std::string buffer;
std::string line;
bool another_line = true;
do {
another_line = console::readline(line, multiline_input);
buffer += line;
} while (another_line);
return buffer;
}
};
enum assistant_display_mode {
ASSISTANT_DISPLAY_MODE_REASONING,
ASSISTANT_DISPLAY_MODE_CONTENT,
};
struct assistant_turn {
assistant_display_mode mode = ASSISTANT_DISPLAY_MODE_CONTENT;
bool trailing_newline = true;
bool is_inside_reasoning = false;
assistant_turn() {
console::set_display(DISPLAY_TYPE_RESET);
}
~assistant_turn() {
console::set_display(DISPLAY_TYPE_RESET);
add_newline_if_needed();
}
void push(assistant_display_mode m, const std::string & buffer) {
if (m != mode) {
add_newline_if_needed();
switch (m) {
case ASSISTANT_DISPLAY_MODE_CONTENT:
{
if (is_inside_reasoning) {
console::log("[End thinking]\n\n");
is_inside_reasoning = false;
}
console::set_display(DISPLAY_TYPE_RESET);
} break;
case ASSISTANT_DISPLAY_MODE_REASONING:
{
console::set_display(DISPLAY_TYPE_REASONING);
is_inside_reasoning = true;
console::log("\n[Start thinking]\n\n");
} break;
}
}
mode = m;
if (buffer.empty()) {
return;
}
trailing_newline = buffer.back() == '\n';
console::log("%s", buffer.c_str());
console::flush();
}
void add_newline_if_needed() {
if (!trailing_newline) {
console::log("\n");
console::flush();
}
}
};
static void show_error(const std::string & title, const std::string & message = "") {
console::spinner::stop();
console::error("Error: %s\n", title.c_str());
if (!message.empty()) {
console::log("%s\n", message.c_str());
}
}
static void show_message(const std::string & message) {
console::log("%s\n", message.c_str());
}
static void show_info(const std::string & message) {
console::set_display(DISPLAY_TYPE_INFO);
console::log("%s\n", message.c_str());
console::set_display(DISPLAY_TYPE_RESET);
}
}
+13 -622
View File
@@ -1,20 +1,10 @@
#include "chat.h"
#include "common.h"
#include "arg.h"
#include "console.h"
#include "fit.h"
// #include "log.h"
#include "common.h"
#include "log.h"
#include "server-common.h"
#include "server-context.h"
#include "server-task.h"
#include "cli-context.h"
#include "cli-view.h"
#include <array>
#include <atomic>
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <thread>
#include <signal.h>
#if defined(_WIN32)
@@ -25,342 +15,19 @@
#include <windows.h>
#endif
const char * LLAMA_ASCII_LOGO = R"(
)";
static std::atomic<bool> g_is_interrupted = false;
static bool should_stop() {
return g_is_interrupted.load();
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void signal_handler(int) {
if (g_is_interrupted.load()) {
if (g_cli_interrupted.load()) {
// second Ctrl+C - exit immediately
// make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
fprintf(stdout, "\033[0m\n");
fflush(stdout);
std::exit(130);
}
g_is_interrupted.store(true);
g_cli_interrupted.store(true);
}
#endif
struct cli_context {
server_context ctx_server;
json messages = json::array();
std::vector<raw_buffer> input_files;
task_params defaults;
bool verbose_prompt;
// thread for showing "loading" animation
std::atomic<bool> loading_show;
cli_context(const common_params & params) {
defaults.sampling = params.sampling;
defaults.speculative = params.speculative;
defaults.n_keep = params.n_keep;
defaults.n_predict = params.n_predict;
defaults.antiprompt = params.antiprompt;
defaults.stream = true; // make sure we always use streaming mode
defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
// defaults.return_progress = true; // TODO: show progress
verbose_prompt = params.verbose_prompt;
}
std::string generate_completion(result_timings & out_timings) {
server_response_reader rd = ctx_server.get_response_reader();
auto chat_params = format_chat();
{
// TODO: reduce some copies here in the future
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
task.id = rd.get_new_id();
task.index = 0;
task.params = defaults; // copy
task.cli_prompt = chat_params.prompt; // copy
task.cli_files = input_files; // copy
task.cli = true;
// chat template settings
task.params.chat_parser_params = common_chat_parser_params(chat_params);
task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
if (!chat_params.parser.empty()) {
task.params.chat_parser_params.parser.load(chat_params.parser);
}
// Copy the preserved tokens into the sampling params
const llama_vocab * vocab = llama_model_get_vocab(
llama_get_model(ctx_server.get_llama_context()));
for (const auto & token : chat_params.preserved_tokens) {
auto ids = common_tokenize(vocab, token, false, true);
if (ids.size() == 1) {
task.params.sampling.preserved_tokens.insert(ids[0]);
}
}
// reasoning budget sampler
if (!chat_params.thinking_end_tag.empty()) {
task.params.sampling.reasoning_budget_tokens = defaults.sampling.reasoning_budget_tokens;
task.params.sampling.generation_prompt = chat_params.generation_prompt;
if (!chat_params.thinking_start_tag.empty()) {
task.params.sampling.reasoning_budget_start =
common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
}
task.params.sampling.reasoning_budget_end =
common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
task.params.sampling.reasoning_budget_forced =
common_tokenize(vocab, defaults.sampling.reasoning_budget_message + chat_params.thinking_end_tag, false, true);
}
rd.post_task({std::move(task)});
}
if (verbose_prompt) {
console::set_display(DISPLAY_TYPE_PROMPT);
console::log("%s\n\n", chat_params.prompt.c_str());
console::set_display(DISPLAY_TYPE_RESET);
}
// wait for first result
console::spinner::start();
server_task_result_ptr result = rd.next(should_stop);
while (true) {
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
if (res_partial && res_partial->is_begin) {
// this is the "send 200 status to client" signal in streaming mode
// skip, do not stop the spinner
result = rd.next(should_stop);
} else {
console::spinner::stop();
break;
}
}
std::string curr_content;
bool is_thinking = false;
while (result) {
if (should_stop()) {
break;
}
if (result->is_error()) {
json err_data = result->to_json();
if (err_data.contains("message")) {
console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
} else {
console::error("Error: %s\n", err_data.dump().c_str());
}
return curr_content;
}
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
if (res_partial) {
out_timings = std::move(res_partial->timings);
for (const auto & diff : res_partial->oaicompat_msg_diffs) {
if (!diff.content_delta.empty()) {
if (is_thinking) {
console::log("\n[End thinking]\n\n");
console::set_display(DISPLAY_TYPE_RESET);
is_thinking = false;
}
curr_content += diff.content_delta;
console::log("%s", diff.content_delta.c_str());
console::flush();
}
if (!diff.reasoning_content_delta.empty()) {
console::set_display(DISPLAY_TYPE_REASONING);
if (!is_thinking) {
console::log("[Start thinking]\n");
}
is_thinking = true;
console::log("%s", diff.reasoning_content_delta.c_str());
console::flush();
}
}
}
auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
if (res_final) {
out_timings = std::move(res_final->timings);
break;
}
result = rd.next(should_stop);
}
g_is_interrupted.store(false);
// server_response_reader automatically cancels pending tasks upon destruction
return curr_content;
}
// TODO: support remote files in the future (http, https, etc)
std::string load_input_file(const std::string & fname, bool is_media) {
std::ifstream file = fs_open_ifstream(fname, std::ios::binary);
if (!file) {
return "";
}
if (is_media) {
raw_buffer buf;
buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
input_files.push_back(std::move(buf));
return get_media_marker();
} else {
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
return content;
}
}
common_chat_params format_chat() {
auto meta = ctx_server.get_meta();
auto & chat_params = meta.chat_params;
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = {}; // TODO
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
inputs.json_schema = ""; // TODO
inputs.grammar = ""; // TODO
inputs.use_jinja = chat_params.use_jinja;
inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"];
inputs.add_generation_prompt = true;
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
inputs.force_pure_content = chat_params.force_pure_content;
inputs.enable_thinking = chat_params.enable_thinking ? common_chat_templates_support_enable_thinking(chat_params.tmpls.get()) : false;
// Apply chat template to the list of messages
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
}
};
// TODO?: Make this reusable, enums, docs
static const std::array<std::string_view, 8> cmds = {
"/audio ",
"/clear",
"/exit",
"/glob ",
"/image ",
"/read ",
"/regen",
"/video ",
};
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
std::vector<std::pair<std::string, size_t>> matches;
std::string cmd;
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return string_starts_with(line, prefix);
})) {
auto it = cmds.begin();
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
return string_starts_with(cmd_line, line);
})) != cmds.end()) {
matches.emplace_back(*it, it->length());
++it;
}
} else {
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return prefix.back() == ' ' && string_starts_with(line, prefix);
});
if (it != cmds.end()) {
cmd = *it;
}
}
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
auto cur_dir = std::filesystem::current_path();
std::string cur_dir_str = cur_dir.string();
std::string expanded_prefix = path_prefix;
#if !defined(_WIN32)
if (string_starts_with(path_prefix, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
expanded_prefix = home + path_prefix.substr(1);
}
}
if (string_starts_with(expanded_prefix, '/')) {
#else
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
#endif
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
cur_dir_str.clear();
} else if (!path_prefix.empty()) {
cur_dir /= std::filesystem::path(path_prefix).parent_path();
}
std::error_code ec;
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
if (ec) {
break;
}
if (!entry.exists(ec)) {
ec.clear();
continue;
}
const std::string path_full = entry.path().string();
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
if (entry.is_directory(ec)) {
path_entry.push_back(std::filesystem::path::preferred_separator);
}
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
const std::string updated_line = cmd + path_entry;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
if (ec) {
ec.clear();
}
}
if (matches.empty()) {
const std::string updated_line = cmd + path_prefix;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
// Add the longest common prefix
if (!expanded_prefix.empty() && matches.size() > 1) {
const std::string_view match0(matches[0].first);
const std::string_view match1(matches[1].first);
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
size_t len = it.first - match0.begin();
for (size_t i = 2; i < matches.size(); ++i) {
const std::string_view matchi(matches[i].first);
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
}
const std::string updated_line = std::string(match0.substr(0, len));
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
});
}
return matches;
}
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
// satisfies -Wmissing-declarations
int llama_cli(int argc, char ** argv);
@@ -375,25 +42,6 @@ int llama_cli(int argc, char ** argv) {
return 1;
}
// TODO: maybe support it later?
if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
console::error("--no-conversation is not supported by llama-cli\n");
console::error("please use llama-completion instead\n");
}
// struct that contains llama context and inference
cli_context ctx_cli(params);
llama_backend_init();
llama_numa_init(params.numa);
// TODO: avoid using atexit() here by making `console` a singleton
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
console::set_display(DISPLAY_TYPE_RESET);
console::set_completion_callback(auto_completion_callback);
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
@@ -408,273 +56,16 @@ int llama_cli(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
console::log("\nLoading model... "); // followed by loading animation
console::spinner::start();
if (!ctx_cli.ctx_server.load_model(params)) {
console::spinner::stop();
console::error("\nFailed to load the model\n");
cli_context ctx_cli(params);
if (!ctx_cli.init()) {
ctx_cli.shutdown();
return 1;
}
ctx_cli.defaults.sampling = params.sampling;
int ret = ctx_cli.run();
console::spinner::stop();
console::log("\n");
ctx_cli.shutdown();
std::thread inference_thread([&ctx_cli]() {
ctx_cli.ctx_server.start_loop();
});
auto inf = ctx_cli.ctx_server.get_meta();
std::string modalities = "text";
if (inf.has_inp_image) {
modalities += ", vision";
}
if (inf.has_inp_audio) {
modalities += ", audio";
}
auto add_system_prompt = [&]() {
if (!params.system_prompt.empty()) {
ctx_cli.messages.push_back({
{"role", "system"},
{"content", params.system_prompt}
});
}
};
add_system_prompt();
console::log("\n");
console::log("%s\n", LLAMA_ASCII_LOGO);
console::log("build : %s\n", inf.build_info.c_str());
console::log("model : %s\n", inf.model_name.c_str());
console::log("modalities : %s\n", modalities.c_str());
if (!params.system_prompt.empty()) {
console::log("using custom system prompt\n");
}
console::log("\n");
console::log("available commands:\n");
console::log(" /exit or Ctrl+C stop or exit\n");
console::log(" /regen regenerate the last response\n");
console::log(" /clear clear the chat history\n");
console::log(" /read <file> add a text file\n");
console::log(" /glob <pattern> add text files using globbing pattern\n");
if (inf.has_inp_image) {
console::log(" /image <file> add an image file\n");
}
if (inf.has_inp_audio) {
console::log(" /audio <file> add an audio file\n");
}
if (inf.has_inp_video) {
console::log(" /video <file> add a video file\n");
}
console::log("\n");
// interactive loop
std::string cur_msg;
auto add_text_file = [&](const std::string & fname) -> bool {
std::string marker = ctx_cli.load_input_file(fname, false);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
return false;
}
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
cur_msg += fname;
cur_msg.push_back('\n');
} else {
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
}
cur_msg += marker;
console::log("Loaded text from '%s'\n", fname.c_str());
return true;
};
while (true) {
std::string buffer;
console::set_display(DISPLAY_TYPE_USER_INPUT);
if (params.prompt.empty()) {
console::log("\n> ");
std::string line;
bool another_line = true;
do {
another_line = console::readline(line, params.multiline_input);
buffer += line;
} while (another_line);
} else {
// process input prompt from args
for (auto & fname : params.image) {
std::string marker = ctx_cli.load_input_file(fname, true);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
break;
}
console::log("Loaded media from '%s'\n", fname.c_str());
cur_msg += marker;
}
buffer = params.prompt;
if (buffer.size() > 500) {
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
} else {
console::log("\n> %s\n", buffer.c_str());
}
params.prompt.clear(); // only use it once
}
console::set_display(DISPLAY_TYPE_RESET);
console::log("\n");
if (should_stop()) {
g_is_interrupted.store(false);
break;
}
// remove trailing newline
if (!buffer.empty() &&buffer.back() == '\n') {
buffer.pop_back();
}
// skip empty messages
if (buffer.empty()) {
continue;
}
bool add_user_msg = true;
// process commands
if (string_starts_with(buffer, "/exit")) {
break;
} else if (string_starts_with(buffer, "/regen")) {
if (ctx_cli.messages.size() >= 2) {
size_t last_idx = ctx_cli.messages.size() - 1;
ctx_cli.messages.erase(last_idx);
add_user_msg = false;
} else {
console::error("No message to regenerate.\n");
continue;
}
} else if (string_starts_with(buffer, "/clear")) {
ctx_cli.messages.clear();
add_system_prompt();
ctx_cli.input_files.clear();
console::log("Chat history cleared.\n");
continue;
} else if (
(string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio) ||
(string_starts_with(buffer, "/video ") && inf.has_inp_video)) {
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
std::string fname = string_strip(buffer.substr(7));
std::string marker = ctx_cli.load_input_file(fname, true);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
continue;
}
cur_msg += marker;
console::log("Loaded media from '%s'\n", fname.c_str());
continue;
} else if (string_starts_with(buffer, "/read ")) {
std::string fname = string_strip(buffer.substr(6));
add_text_file(fname);
continue;
} else if (string_starts_with(buffer, "/glob ")) {
std::error_code ec;
size_t count = 0;
auto curdir = std::filesystem::current_path();
std::string pattern = string_strip(buffer.substr(6));
std::filesystem::path rel_path;
auto startglob = pattern.find_first_of("![*?");
if (startglob != std::string::npos && startglob != 0) {
auto endpath = pattern.substr(0, startglob).find_last_of('/');
if (endpath != std::string::npos) {
std::string rel_pattern = pattern.substr(0, endpath);
#if !defined(_WIN32)
if (string_starts_with(rel_pattern, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
rel_pattern = home + rel_pattern.substr(1);
}
}
#endif
rel_path = rel_pattern;
pattern.erase(0, endpath + 1);
curdir /= rel_path;
}
}
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
std::filesystem::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) {
continue;
}
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
if (ec) {
ec.clear();
continue;
}
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(pattern, rel)) {
continue;
}
if (!add_text_file((rel_path / rel).string())) {
continue;
}
if (++count >= FILE_GLOB_MAX_RESULTS) {
console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS);
break;
}
}
continue;
} else {
// not a command
cur_msg += buffer;
}
// generate response
if (add_user_msg) {
ctx_cli.messages.push_back({
{"role", "user"},
{"content", cur_msg}
});
cur_msg.clear();
}
result_timings timings;
std::string assistant_content = ctx_cli.generate_completion(timings);
ctx_cli.messages.push_back({
{"role", "assistant"},
{"content", assistant_content}
});
console::log("\n");
if (params.show_timings) {
console::set_display(DISPLAY_TYPE_INFO);
console::log("\n");
console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
console::set_display(DISPLAY_TYPE_RESET);
}
if (params.single_turn) {
break;
}
}
console::set_display(DISPLAY_TYPE_RESET);
console::log("\nExiting...\n");
ctx_cli.ctx_server.terminate();
inference_thread.join();
// bump the log level to display timings
common_log_set_verbosity_thold(LOG_LEVEL_INFO);
common_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());
return 0;
return ret;
}
+3 -69
View File
@@ -5,6 +5,7 @@
#include "build-info.h"
#include "preset.h"
#include "download.h"
#include "http.h"
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
#include <sheredom/subprocess.h>
@@ -25,14 +26,7 @@
#include <sstream>
#include <cstring>
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#ifndef _WIN32
extern char **environ;
#endif
@@ -704,66 +698,6 @@ std::optional<server_model_meta> server_models::get_meta(const std::string & nam
return std::nullopt;
}
static int get_free_port() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
return -1;
}
typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
struct sockaddr_in serv_addr;
std::memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
serv_addr.sin_port = htons(0);
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
#ifdef _WIN32
int namelen = sizeof(serv_addr);
#else
socklen_t namelen = sizeof(serv_addr);
#endif
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
int port = ntohs(serv_addr.sin_port);
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return port;
}
// helper to convert vector<string> to char **
// pointers are only valid as long as the original vector is valid
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
@@ -867,7 +801,7 @@ void server_models::load(const std::string & name, const load_options & opts) {
// prepare new instance info
instance_t inst;
inst.meta = meta;
inst.meta.port = get_free_port();
inst.meta.port = common_http_get_free_port();
inst.meta.status = SERVER_MODEL_STATUS_LOADING;
inst.meta.loaded_info = json{};
inst.meta.last_used = ggml_time_ms();
+35 -16
View File
@@ -35,6 +35,19 @@ static inline void signal_handler(int signal) {
shutdown_handler(signal);
}
// satisfies -Wmissing-declarations (used by llama command)
int llama_server(int argc, char ** argv);
// to be used via CLI (argc / argv are used by router mode only)
int llama_server(common_params & params, int argc, char ** argv);
void llama_server_terminate();
void llama_server_terminate() {
if (shutdown_handler) {
shutdown_handler(0);
}
}
// wrapper function that handles exceptions and logs errors
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
@@ -71,9 +84,6 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
};
}
// satisfies -Wmissing-declarations
int llama_server(int argc, char ** argv);
int llama_server(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
@@ -89,8 +99,14 @@ int llama_server(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);
return llama_server(params, argc, argv);
}
int llama_server(common_params & params, int argc, char ** argv) {
bool is_run_by_cli = (argv == nullptr);
// note: router mode also accepts -hf remote-preset, so we need to check that first
if (!params.model.hf_repo.empty()) {
if (!is_run_by_cli && !params.model.hf_repo.empty()) {
try {
common_params_handle_models_params handle_params;
handle_params.preset_only = true;
@@ -272,8 +288,9 @@ int llama_server(int argc, char ** argv) {
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
return child.run_download(params);
} else if (!is_router_server) {
} else if (!is_router_server && !is_run_by_cli) {
// single-model mode (NOT spawned by router)
// if this is invoked by CLI, model downloading should already handled
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
}
@@ -356,20 +373,22 @@ int llama_server(int argc, char ** argv) {
};
}
// TODO: refactor in common/console
// register signal handler is not running by CLI
if (!is_run_by_cli) {
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
}
if (is_router_server) {
SRV_INF("router server is listening on %s\n", ctx_http.listening_address.c_str());