mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-24 06:37:41 +02:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a432e6f863 | |||
| 5d67f69f59 | |||
| beef5cf077 | |||
| b093e46873 | |||
| 1401fc3ca7 | |||
| 85c58bbcd0 | |||
| 19296c1735 | |||
| 90c111bf98 | |||
| f7421eabe8 | |||
| 59797670dc |
+9
-3
@@ -603,9 +603,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty()
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
bool can_skip_model = params.usage || params.completion || !params.server_base.empty();
|
||||
if (!can_skip_model && params.model.path.empty()) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
}
|
||||
@@ -1119,6 +1118,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.completion = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--server-base"}, "URL",
|
||||
string_format("connect to this server instead of starting a new one, example: 'http://localhost:8080' (default: none)"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.server_base = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--verbose-prompt"},
|
||||
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
|
||||
|
||||
@@ -631,6 +631,9 @@ struct common_params {
|
||||
|
||||
std::map<std::string, std::string> default_template_kwargs;
|
||||
|
||||
// CLI params
|
||||
std::string server_base; // if set, connect to this server instead of starting a new one
|
||||
|
||||
// UI configs
|
||||
bool ui = true;
|
||||
bool ui_mcp_proxy = false;
|
||||
|
||||
@@ -2,6 +2,16 @@
|
||||
|
||||
#include <cpp-httplib/httplib.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
struct common_http_url {
|
||||
std::string scheme;
|
||||
std::string user;
|
||||
@@ -97,3 +107,63 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||
static std::string common_http_show_masked_url(const common_http_url & parts) {
|
||||
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
|
||||
}
|
||||
|
||||
static int common_http_get_free_port() {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
|
||||
return -1;
|
||||
}
|
||||
typedef SOCKET native_socket_t;
|
||||
#define INVALID_SOCKET_VAL INVALID_SOCKET
|
||||
#define CLOSE_SOCKET(s) closesocket(s)
|
||||
#else
|
||||
typedef int native_socket_t;
|
||||
#define INVALID_SOCKET_VAL -1
|
||||
#define CLOSE_SOCKET(s) close(s)
|
||||
#endif
|
||||
|
||||
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock == INVALID_SOCKET_VAL) {
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct sockaddr_in serv_addr;
|
||||
std::memset(&serv_addr, 0, sizeof(serv_addr));
|
||||
serv_addr.sin_family = AF_INET;
|
||||
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
serv_addr.sin_port = htons(0);
|
||||
|
||||
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
int namelen = sizeof(serv_addr);
|
||||
#else
|
||||
socklen_t namelen = sizeof(serv_addr);
|
||||
#endif
|
||||
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
int port = ntohs(serv_addr.sin_port);
|
||||
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
|
||||
return port;
|
||||
}
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
|
||||
set(TARGET llama-cli-impl)
|
||||
|
||||
add_library(${TARGET} cli.cpp)
|
||||
add_library(${TARGET} cli.cpp
|
||||
cli-client.cpp
|
||||
cli-context.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ../server)
|
||||
target_link_libraries(${TARGET} PUBLIC server-context llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PUBLIC llama-server-impl llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} LIBRARY)
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
#include "cli-client.h"
|
||||
|
||||
#include "http.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
// generation can stall for a long time during prompt processing, so the
|
||||
// read timeout must be generous
|
||||
static constexpr time_t CLI_HTTP_READ_TIMEOUT_SEC = 3600;
|
||||
|
||||
// upper bound for the accumulated response body kept for error reporting
|
||||
static constexpr size_t CLI_HTTP_MAX_ERROR_BODY = 1024 * 1024;
|
||||
|
||||
// returns the path with the base url's path prefix prepended (if any)
|
||||
static std::string join_path(const common_http_url & parts, const std::string & path) {
|
||||
if (parts.path.empty() || parts.path == "/") {
|
||||
return path;
|
||||
}
|
||||
std::string prefix = parts.path;
|
||||
if (prefix.back() == '/') {
|
||||
prefix.pop_back();
|
||||
}
|
||||
return prefix + path;
|
||||
}
|
||||
|
||||
json cli_client::get(const std::string & path) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto path_with_model = path + (model.empty() ? "" : ("?model=" + model));
|
||||
auto res = cli.Get(join_path(parts, path_with_model));
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
throw std::runtime_error("GET " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
|
||||
}
|
||||
json result = json::parse(res->body, nullptr, false);
|
||||
if (result.is_discarded()) {
|
||||
throw std::runtime_error("GET " + path + " returned invalid JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
json cli_client::post(const std::string & path, const json & body) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), body_with_model.dump(), "application/json");
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
throw std::runtime_error("POST " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
|
||||
}
|
||||
json result = json::parse(res->body, nullptr, false);
|
||||
if (result.is_discarded()) {
|
||||
throw std::runtime_error("POST " + path + " returned invalid JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
json cli_client::post_sse(const std::string & path,
|
||||
const json & body,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
|
||||
std::string pending; // buffer for incomplete SSE lines
|
||||
std::string raw_body; // accumulated body, used only for error reporting
|
||||
|
||||
auto receiver = [&](const char * data, size_t len) -> bool {
|
||||
if (should_stop()) {
|
||||
return false; // aborts the request
|
||||
}
|
||||
if (raw_body.size() < CLI_HTTP_MAX_ERROR_BODY) {
|
||||
raw_body.append(data, std::min(len, CLI_HTTP_MAX_ERROR_BODY - raw_body.size()));
|
||||
}
|
||||
pending.append(data, len);
|
||||
size_t pos;
|
||||
while ((pos = pending.find('\n')) != std::string::npos) {
|
||||
std::string line = pending.substr(0, pos);
|
||||
pending.erase(0, pos + 1);
|
||||
if (!line.empty() && line.back() == '\r') {
|
||||
line.pop_back();
|
||||
}
|
||||
if (line.rfind("data: ", 0) != 0) {
|
||||
continue;
|
||||
}
|
||||
std::string payload = line.substr(6);
|
||||
if (payload == "[DONE]") {
|
||||
continue;
|
||||
}
|
||||
json event = json::parse(payload, nullptr, false);
|
||||
if (!event.is_discarded()) {
|
||||
on_data(event);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
httplib::Headers headers = {{"Accept", "text/event-stream"}};
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), headers, body_with_model.dump(), "application/json", receiver);
|
||||
|
||||
if (!res) {
|
||||
if (res.error() == httplib::Error::Canceled && should_stop()) {
|
||||
return json(); // cancelled by the user
|
||||
}
|
||||
return json {{"error", {{"message", "failed to connect to " + server_base + ": " + httplib::to_string(res.error())}}}};
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
json error_body = json::parse(raw_body, nullptr, false);
|
||||
if (!error_body.is_discarded() && error_body.contains("error")) {
|
||||
return error_body;
|
||||
}
|
||||
return json {{"error", {{"message", "request failed with status " + std::to_string(res->status)}}}};
|
||||
}
|
||||
return json();
|
||||
}
|
||||
|
||||
bool cli_client::wait_health(const std::function<bool()> & is_aborted) {
|
||||
int connect_attempts = 0;
|
||||
while (!is_aborted()) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_connection_timeout(1, 0);
|
||||
auto res = cli.Get(join_path(parts, "/health"));
|
||||
if (res) {
|
||||
if (res->status == 200) {
|
||||
return true;
|
||||
}
|
||||
// any other status means the server is up but not ready yet
|
||||
// (e.g. 503 while the model is still loading)
|
||||
} else if (++connect_attempts >= 10) {
|
||||
last_error = "failed to connect to " + server_base + ": " + httplib::to_string(res.error());
|
||||
return false;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(300));
|
||||
}
|
||||
last_error = "aborted while waiting for the server to become ready";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> cli_client::list_models() {
|
||||
json resp = get("/v1/models");
|
||||
if (!resp.contains("data") || !resp.at("data").is_array()) {
|
||||
throw std::runtime_error("invalid response from /v1/models");
|
||||
}
|
||||
std::vector<std::string> models;
|
||||
for (const auto & m : resp.at("data")) {
|
||||
if (m.contains("id") && m.at("id").is_string()) {
|
||||
models.push_back(m.at("id").get<std::string>());
|
||||
}
|
||||
}
|
||||
return models;
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// openai-like client for CLI
|
||||
struct cli_client {
|
||||
std::string server_base; // base url, for example "http://127.0.0.1:8080"
|
||||
std::string last_error; // set when wait_health() fails
|
||||
|
||||
std::string model; // optional, set when the server has multiple models (router mode)
|
||||
|
||||
// simple GET request, returns the response json
|
||||
// throws std::runtime_error on transport error or non-2xx status
|
||||
json get(const std::string & path);
|
||||
|
||||
// simple POST request, returns the response json
|
||||
// throws std::runtime_error on transport error or non-2xx status
|
||||
json post(const std::string & path, const json & body);
|
||||
|
||||
// POST request with an SSE streaming response; on_data is invoked once
|
||||
// per "data:" event; the function returns after the stream is finished:
|
||||
// a null json on graceful exit (incl. cancellation via should_stop),
|
||||
// the error response json otherwise
|
||||
json post_sse(const std::string & path,
|
||||
const json & body,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data);
|
||||
|
||||
// poll /health until the server is ready to accept requests
|
||||
// returns false if is_aborted returned true or the server is unreachable
|
||||
bool wait_health(const std::function<bool()> & is_aborted);
|
||||
|
||||
//
|
||||
// higher-level wrappers
|
||||
//
|
||||
|
||||
json create_chat_completion(const json & request,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data) {
|
||||
return post_sse("/v1/chat/completions", request, should_stop, on_data);
|
||||
}
|
||||
|
||||
json get_props() {
|
||||
return get("/props");
|
||||
}
|
||||
|
||||
std::vector<std::string> list_models();
|
||||
};
|
||||
@@ -0,0 +1,559 @@
|
||||
#include "cli-context.h"
|
||||
#include "cli-view.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "base64.hpp"
|
||||
#include "log.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
std::atomic<bool> g_cli_interrupted = false;
|
||||
|
||||
static bool should_stop() {
|
||||
return g_cli_interrupted.load();
|
||||
}
|
||||
|
||||
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
|
||||
|
||||
const char * LLAMA_ASCII_LOGO = R"(
|
||||
▄▄ ▄▄
|
||||
██ ██
|
||||
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
|
||||
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
|
||||
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
|
||||
██ ██
|
||||
▀▀ ▀▀
|
||||
)";
|
||||
|
||||
// number of values an arg consumes on the command line
|
||||
static int arg_num_values(const common_arg & opt) {
|
||||
if (opt.value_hint_2 != nullptr) {
|
||||
return 2;
|
||||
}
|
||||
if (opt.value_hint != nullptr) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::string format_error_message(const json & err) {
|
||||
if (err.contains("error") && err.at("error").is_object()) {
|
||||
const auto & e = err.at("error");
|
||||
if (e.contains("message") && e.at("message").is_string()) {
|
||||
return e.at("message").get<std::string>();
|
||||
}
|
||||
}
|
||||
return err.dump();
|
||||
}
|
||||
|
||||
static std::string media_type_from_ext(const std::string & fname) {
|
||||
std::string ext = std::filesystem::path(fname).extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
if (ext == ".wav" || ext == ".mp3") {
|
||||
return "audio";
|
||||
}
|
||||
if (ext == ".mp4" || ext == ".avi" || ext == ".mkv" || ext == ".mov" || ext == ".webm") {
|
||||
return "video";
|
||||
}
|
||||
return "image";
|
||||
}
|
||||
|
||||
bool cli_context::init() {
|
||||
view::init(params);
|
||||
|
||||
std::optional<view::spinner> spinner;
|
||||
|
||||
bool use_external_server = !params.server_base.empty();
|
||||
if (use_external_server) {
|
||||
std::string base = params.server_base;
|
||||
while (!base.empty() && base.back() == '/') {
|
||||
base.pop_back();
|
||||
}
|
||||
client.server_base = base;
|
||||
|
||||
spinner.emplace("Connecting to server at " + base);
|
||||
} else {
|
||||
if (params.model.path.empty() && params.model.url.empty() &&
|
||||
params.model.hf_repo.empty() && params.model.docker_repo.empty()) {
|
||||
view::show_error(
|
||||
"no model specified",
|
||||
"use -m <file.gguf> or -hf <user/repo> to run a local model,\n"
|
||||
"or --server-base <url> to connect to a running llama-server"
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
spinner.emplace("\n\nLoading model...");
|
||||
|
||||
server.emplace();
|
||||
if (!server->start(params)) {
|
||||
view::show_error("server start failed");
|
||||
return false;
|
||||
}
|
||||
if (!server->wait_ready(should_stop)) {
|
||||
if (!should_stop()) {
|
||||
view::show_error("the server exited before becoming ready");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
client.server_base = server->address();
|
||||
}
|
||||
|
||||
// for --server-base this is the main availability check; for a spawned
|
||||
// server it is a cheap sanity check on top of the ready signal
|
||||
auto is_aborted = [this]() {
|
||||
return should_stop() || (server && !server->alive());
|
||||
};
|
||||
bool healthy = false;
|
||||
try {
|
||||
healthy = client.wait_health(is_aborted);
|
||||
} catch (const std::exception & e) {
|
||||
client.last_error = e.what();
|
||||
}
|
||||
if (!healthy) {
|
||||
if (!should_stop()) {
|
||||
view::show_error(client.last_error);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (use_external_server) {
|
||||
spinner.reset();
|
||||
if (!list_and_ask_models()) {
|
||||
return false;
|
||||
}
|
||||
// restore the spinner for the next step
|
||||
spinner.emplace("Waiting for server...");
|
||||
}
|
||||
|
||||
fetch_server_props();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void cli_context::fetch_server_props() {
|
||||
try {
|
||||
json props = client.get_props();
|
||||
model_name = props.value("model_alias", "");
|
||||
if (model_name.empty()) {
|
||||
const std::string path = props.value("model_path", "");
|
||||
if (!path.empty()) {
|
||||
model_name = std::filesystem::path(path).filename().string();
|
||||
}
|
||||
}
|
||||
build_info = props.value("build_info", "");
|
||||
if (props.contains("modalities") && props.at("modalities").is_object()) {
|
||||
const auto & modalities = props.at("modalities");
|
||||
has_vision = modalities.value("vision", false);
|
||||
has_audio = modalities.value("audio", false);
|
||||
has_video = modalities.value("video", false);
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
// /props can be disabled on remote servers; not fatal
|
||||
LOG_DBG("failed to fetch /props: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool cli_context::list_and_ask_models() {
|
||||
auto models = client.list_models();
|
||||
|
||||
// only one model: use it without asking
|
||||
if (models.size() == 1) {
|
||||
model_name = models[0];
|
||||
client.model = model_name;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string message = "\nAvailable models:";
|
||||
if (!models.empty()) {
|
||||
for (size_t i = 0; i < models.size(); ++i) {
|
||||
message += "\n " + std::to_string(i + 1) + ". " + models[i];
|
||||
}
|
||||
}
|
||||
message += "\n";
|
||||
view::show_message(message);
|
||||
std::string selection;
|
||||
while (selection.empty()) {
|
||||
if (should_stop()) {
|
||||
return false;
|
||||
}
|
||||
view::user_turn user_turn;
|
||||
selection = user_turn.read_input(false, "Select model by number: ");
|
||||
if (selection.empty()) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
size_t idx = std::stoul(selection);
|
||||
if (idx > 0 && idx <= models.size()) {
|
||||
model_name = models[idx - 1];
|
||||
client.model = model_name;
|
||||
view::show_message("Selected model: " + model_name);
|
||||
break;
|
||||
}
|
||||
} catch (...) {
|
||||
// ignore
|
||||
}
|
||||
view::show_error("Invalid selection. Please enter a valid number.");
|
||||
selection.clear();
|
||||
continue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void cli_context::add_system_prompt() {
|
||||
if (!params.system_prompt.empty()) {
|
||||
messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", params.system_prompt}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void cli_context::push_user_message(const std::string & text) {
|
||||
json content;
|
||||
if (pending_media.empty()) {
|
||||
content = text;
|
||||
} else {
|
||||
// multimodal message: media parts first, then the text
|
||||
content = pending_media;
|
||||
content.push_back({
|
||||
{"type", "text"},
|
||||
{"text", text}
|
||||
});
|
||||
pending_media = json::array();
|
||||
}
|
||||
messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", content}
|
||||
});
|
||||
}
|
||||
|
||||
bool cli_context::stage_media_file(const std::string & fname, const std::string & type) {
|
||||
std::ifstream file(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
std::string data((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
std::string encoded = base64::encode(data);
|
||||
|
||||
if (type == "audio") {
|
||||
std::string ext = std::filesystem::path(fname).extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
pending_media.push_back({
|
||||
{"type", "input_audio"},
|
||||
{"input_audio", {
|
||||
{"data", encoded},
|
||||
{"format", ext == ".mp3" ? "mp3" : "wav"}
|
||||
}}
|
||||
});
|
||||
} else if (type == "video") {
|
||||
pending_media.push_back({
|
||||
{"type", "input_video"},
|
||||
{"input_video", {
|
||||
{"data", encoded}
|
||||
}}
|
||||
});
|
||||
} else {
|
||||
// the server detects the actual image type from the data
|
||||
pending_media.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", "data:image/unknown;base64," + encoded}
|
||||
}}
|
||||
});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool cli_context::generate_completion(std::string & assistant_content, cli_timings & timings) {
|
||||
json body = {
|
||||
{"messages", messages},
|
||||
{"stream", true},
|
||||
// in order to get timings even when we cancel mid-way
|
||||
{"timings_per_token", true},
|
||||
};
|
||||
|
||||
bool stream_error = false;
|
||||
|
||||
view::assistant_turn a;
|
||||
|
||||
json err = client.create_chat_completion(body, should_stop, [&](const json & chunk) {
|
||||
if (chunk.contains("error")) {
|
||||
stream_error = true;
|
||||
view::show_error(format_error_message(chunk));
|
||||
return;
|
||||
}
|
||||
if (chunk.contains("timings")) {
|
||||
const auto & t = chunk.at("timings");
|
||||
timings.prompt_per_second = t.value("prompt_per_second", 0.0);
|
||||
timings.predicted_per_second = t.value("predicted_per_second", 0.0);
|
||||
}
|
||||
if (!chunk.contains("choices") || !chunk.at("choices").is_array() || chunk.at("choices").empty()) {
|
||||
return;
|
||||
}
|
||||
const auto & choice = chunk.at("choices").at(0);
|
||||
if (!choice.contains("delta")) {
|
||||
return;
|
||||
}
|
||||
const auto & delta = choice.at("delta");
|
||||
if (delta.contains("reasoning_content") && delta.at("reasoning_content").is_string()) {
|
||||
const std::string text = delta.at("reasoning_content").get<std::string>();
|
||||
if (!text.empty()) {
|
||||
a.push(view::ASSISTANT_DISPLAY_MODE_REASONING, text);
|
||||
}
|
||||
}
|
||||
if (delta.contains("content") && delta.at("content").is_string()) {
|
||||
const std::string text = delta.at("content").get<std::string>();
|
||||
if (!text.empty()) {
|
||||
assistant_content += text;
|
||||
a.push(view::ASSISTANT_DISPLAY_MODE_CONTENT, text);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
g_cli_interrupted.store(false);
|
||||
|
||||
if (!err.is_null()) {
|
||||
view::show_error(format_error_message(err));
|
||||
return false;
|
||||
}
|
||||
return !stream_error;
|
||||
}
|
||||
|
||||
int cli_context::run() {
|
||||
add_system_prompt();
|
||||
|
||||
std::string modalities = "text";
|
||||
if (has_vision) {
|
||||
modalities += ", vision";
|
||||
}
|
||||
if (has_audio) {
|
||||
modalities += ", audio";
|
||||
}
|
||||
if (has_video) {
|
||||
modalities += ", video";
|
||||
}
|
||||
|
||||
std::string banner;
|
||||
banner += "\n";
|
||||
banner += LLAMA_ASCII_LOGO;
|
||||
banner += "\n";
|
||||
banner += "build : " + build_info + "\n";
|
||||
banner += "model : " + model_name + "\n";
|
||||
banner += "modalities : " + modalities + "\n";
|
||||
if (!params.system_prompt.empty()) {
|
||||
banner += "using custom system prompt\n";
|
||||
}
|
||||
banner += "\n";
|
||||
banner += "available commands:\n";
|
||||
banner += " /exit or Ctrl+C stop or exit\n";
|
||||
banner += " /regen regenerate the last response\n";
|
||||
banner += " /clear clear the chat history\n";
|
||||
banner += " /read <file> add a text file\n";
|
||||
banner += " /glob <pattern> add text files using globbing pattern\n";
|
||||
if (has_vision) {
|
||||
banner += " /image <file> add an image file\n";
|
||||
}
|
||||
if (has_audio) {
|
||||
banner += " /audio <file> add an audio file\n";
|
||||
}
|
||||
if (has_video) {
|
||||
banner += " /video <file> add a video file\n";
|
||||
}
|
||||
banner += "\n";
|
||||
|
||||
view::show_message(banner);
|
||||
|
||||
// interactive loop
|
||||
std::string cur_msg;
|
||||
|
||||
auto add_text_file = [&](const std::string & fname) -> bool {
|
||||
std::ifstream file(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
return false;
|
||||
}
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
cur_msg += content;
|
||||
view::show_message(string_format("Loaded text from '%s'", fname.c_str()));
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
std::string buffer;
|
||||
{
|
||||
view::user_turn user_turn;
|
||||
|
||||
if (params.prompt.empty()) {
|
||||
buffer = user_turn.read_input(params.multiline_input);
|
||||
} else {
|
||||
// process input prompt from args
|
||||
for (auto & fname : params.image) {
|
||||
if (!stage_media_file(fname, media_type_from_ext(fname))) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
break;
|
||||
}
|
||||
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
||||
}
|
||||
buffer = params.prompt;
|
||||
user_turn.echo(buffer);
|
||||
params.prompt.clear(); // only use it once
|
||||
}
|
||||
}
|
||||
|
||||
if (should_stop()) {
|
||||
g_cli_interrupted.store(false);
|
||||
break;
|
||||
}
|
||||
|
||||
// remove trailing newline
|
||||
if (!buffer.empty() && buffer.back() == '\n') {
|
||||
buffer.pop_back();
|
||||
}
|
||||
|
||||
// skip empty messages
|
||||
if (buffer.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool add_user_msg = true;
|
||||
|
||||
// process commands
|
||||
if (string_starts_with(buffer, "/exit")) {
|
||||
break;
|
||||
} else if (string_starts_with(buffer, "/regen")) {
|
||||
if (messages.size() >= 2) {
|
||||
size_t last_idx = messages.size() - 1;
|
||||
messages.erase(last_idx);
|
||||
add_user_msg = false;
|
||||
} else {
|
||||
view::show_error("No message to regenerate.");
|
||||
continue;
|
||||
}
|
||||
} else if (string_starts_with(buffer, "/clear")) {
|
||||
messages.clear();
|
||||
add_system_prompt();
|
||||
|
||||
pending_media = json::array();
|
||||
view::show_message("Chat history cleared.");
|
||||
continue;
|
||||
} else if (
|
||||
(string_starts_with(buffer, "/image ") && has_vision) ||
|
||||
(string_starts_with(buffer, "/audio ") && has_audio) ||
|
||||
(string_starts_with(buffer, "/video ") && has_video)) {
|
||||
std::string type = buffer.substr(1, 5);
|
||||
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
|
||||
std::string fname = string_strip(buffer.substr(7));
|
||||
if (!stage_media_file(fname, type)) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
continue;
|
||||
}
|
||||
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/read ")) {
|
||||
std::string fname = string_strip(buffer.substr(6));
|
||||
add_text_file(fname);
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/glob ")) {
|
||||
std::error_code ec;
|
||||
size_t count = 0;
|
||||
auto curdir = std::filesystem::current_path();
|
||||
std::string pattern = string_strip(buffer.substr(6));
|
||||
std::filesystem::path rel_path;
|
||||
|
||||
auto startglob = pattern.find_first_of("![*?");
|
||||
if (startglob != std::string::npos && startglob != 0) {
|
||||
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
||||
if (endpath != std::string::npos) {
|
||||
std::string rel_pattern = pattern.substr(0, endpath);
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(rel_pattern, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
rel_pattern = home + rel_pattern.substr(1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
rel_path = rel_pattern;
|
||||
pattern.erase(0, endpath + 1);
|
||||
curdir /= rel_path;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
||||
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(pattern, rel)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!add_text_file((rel_path / rel).string())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
||||
view::show_error(string_format("Maximum number of globbed files allowed (%zu) reached.", FILE_GLOB_MAX_RESULTS));
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
// not a command
|
||||
cur_msg += buffer;
|
||||
}
|
||||
|
||||
// generate response
|
||||
if (add_user_msg) {
|
||||
push_user_message(cur_msg);
|
||||
cur_msg.clear();
|
||||
}
|
||||
cli_timings timings;
|
||||
std::string assistant_content;
|
||||
generate_completion(assistant_content, timings);
|
||||
messages.push_back({
|
||||
{"role", "assistant"},
|
||||
{"content", assistant_content}
|
||||
});
|
||||
|
||||
if (params.show_timings) {
|
||||
view::show_info(string_format(
|
||||
"\n[ Prompt: %.1f t/s | Generation: %.1f t/s ]",
|
||||
timings.prompt_per_second,
|
||||
timings.predicted_per_second
|
||||
));
|
||||
}
|
||||
|
||||
if (params.single_turn) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
view::show_message("\n\nExiting...");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void cli_context::shutdown() {
|
||||
if (server) {
|
||||
server->stop();
|
||||
server.reset();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include "cli-client.h"
|
||||
#include "cli-server.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
struct cli_timings {
|
||||
double prompt_per_second = 0.0;
|
||||
double predicted_per_second = 0.0;
|
||||
};
|
||||
|
||||
// set by the SIGINT handler; cleared once the interrupt has been handled
|
||||
extern std::atomic<bool> g_cli_interrupted;
|
||||
|
||||
struct cli_context {
|
||||
common_params params;
|
||||
|
||||
cli_client client; // always initialized
|
||||
std::optional<cli_server> server; // only set when no --server-base is given
|
||||
|
||||
json messages = json::array();
|
||||
json pending_media = json::array(); // staged multimodal content parts
|
||||
|
||||
// properties of the connected server
|
||||
// will be populated by fetch_server_props()
|
||||
std::string model_name;
|
||||
std::string build_info;
|
||||
bool has_vision = false;
|
||||
bool has_audio = false;
|
||||
bool has_video = false;
|
||||
|
||||
cli_context(const common_params & params) : params(params) {}
|
||||
~cli_context() {
|
||||
shutdown();
|
||||
}
|
||||
|
||||
// connect to --server-base or spawn a local llama-server child;
|
||||
// argc/argv are needed to forward the server-relevant args to the child
|
||||
bool init();
|
||||
|
||||
// run the interactive chat loop, returns the process exit code
|
||||
int run();
|
||||
|
||||
// stop the local server child (if any)
|
||||
void shutdown();
|
||||
|
||||
private:
|
||||
bool generate_completion(std::string & assistant_content, cli_timings & timings);
|
||||
void fetch_server_props();
|
||||
void add_system_prompt();
|
||||
void push_user_message(const std::string & text);
|
||||
|
||||
// check if server have multiple models (router mode)
|
||||
// if yes, list them then ask; do nothing otherwise
|
||||
bool list_and_ask_models();
|
||||
|
||||
// read a file and stage it as a multimodal content part; type is one of
|
||||
// "image", "audio", "video"; returns false if the file cannot be read
|
||||
bool stage_media_file(const std::string & fname, const std::string & type);
|
||||
};
|
||||
@@ -0,0 +1,83 @@
|
||||
#pragma once
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include "http.h"
|
||||
|
||||
// llama_server will be available as a dynamic library symbol
|
||||
int llama_server(common_params & params, int argc, char ** argv);
|
||||
void llama_server_terminate();
|
||||
|
||||
struct cli_server {
|
||||
std::thread th;
|
||||
int port = -1;
|
||||
std::atomic<bool> is_alive = false;
|
||||
std::atomic<bool> is_stopping = false;
|
||||
|
||||
~cli_server() {
|
||||
stop();
|
||||
}
|
||||
|
||||
void stop() {
|
||||
if (alive() && !is_stopping.exchange(true)) {
|
||||
llama_server_terminate();
|
||||
th.join();
|
||||
}
|
||||
}
|
||||
|
||||
// spawn llama-server in a thread and interact with it via a random port
|
||||
bool start(common_params & params) {
|
||||
port = common_http_get_free_port();
|
||||
if (port <= 0) {
|
||||
fprintf(stderr, "failed to get a free port\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
is_alive.store(true, std::memory_order_release);
|
||||
|
||||
th = std::thread([&]() {
|
||||
common_params server_params = params; // copy
|
||||
server_params.port = port;
|
||||
// argc / argv are only used in router mode, we can skip them for now
|
||||
int res = llama_server(server_params, 0, nullptr);
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "llama_server exited with code %d\n", res);
|
||||
}
|
||||
is_alive.store(false, std::memory_order_release);
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string address() const {
|
||||
return "http://127.0.0.1:" + std::to_string(port);
|
||||
}
|
||||
|
||||
bool wait_ready(std::function<bool()> should_stop) {
|
||||
if (!alive()) {
|
||||
return false;
|
||||
}
|
||||
while (!should_stop()) {
|
||||
auto [cli, parts] = common_http_client(address());
|
||||
cli.set_connection_timeout(1, 0);
|
||||
auto res = cli.Get("/health");
|
||||
if (res) {
|
||||
if (res->status == 200) {
|
||||
return true;
|
||||
}
|
||||
// any other status means the server is up but not ready yet
|
||||
// (e.g. 503 while the model is still loading)
|
||||
}
|
||||
if (!alive()) {
|
||||
// in case server die permanently
|
||||
return false;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool alive() const {
|
||||
return is_alive.load(std::memory_order_acquire);
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,250 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <string_view>
|
||||
|
||||
// TODO?: Make this reusable, enums, docs
|
||||
static const std::array<std::string_view, 8> cmds = {
|
||||
"/audio ",
|
||||
"/clear",
|
||||
"/exit",
|
||||
"/glob ",
|
||||
"/image ",
|
||||
"/read ",
|
||||
"/regen",
|
||||
"/video ",
|
||||
};
|
||||
|
||||
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
|
||||
std::vector<std::pair<std::string, size_t>> matches;
|
||||
std::string cmd;
|
||||
|
||||
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return string_starts_with(line, prefix);
|
||||
})) {
|
||||
auto it = cmds.begin();
|
||||
|
||||
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
|
||||
return string_starts_with(cmd_line, line);
|
||||
})) != cmds.end()) {
|
||||
matches.emplace_back(*it, it->length());
|
||||
++it;
|
||||
}
|
||||
} else {
|
||||
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return prefix.back() == ' ' && string_starts_with(line, prefix);
|
||||
});
|
||||
|
||||
if (it != cmds.end()) {
|
||||
cmd = *it;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
|
||||
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
|
||||
auto cur_dir = std::filesystem::current_path();
|
||||
std::string cur_dir_str = cur_dir.string();
|
||||
std::string expanded_prefix = path_prefix;
|
||||
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(path_prefix, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
expanded_prefix = home + path_prefix.substr(1);
|
||||
}
|
||||
}
|
||||
if (string_starts_with(expanded_prefix, '/')) {
|
||||
#else
|
||||
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
|
||||
#endif
|
||||
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
|
||||
cur_dir_str.clear();
|
||||
} else if (!path_prefix.empty()) {
|
||||
cur_dir /= std::filesystem::path(path_prefix).parent_path();
|
||||
}
|
||||
|
||||
std::error_code ec;
|
||||
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
|
||||
if (ec) {
|
||||
break;
|
||||
}
|
||||
if (!entry.exists(ec)) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string path_full = entry.path().string();
|
||||
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
|
||||
|
||||
if (entry.is_directory(ec)) {
|
||||
path_entry.push_back(std::filesystem::path::preferred_separator);
|
||||
}
|
||||
|
||||
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
|
||||
const std::string updated_line = cmd + path_entry;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (matches.empty()) {
|
||||
const std::string updated_line = cmd + path_prefix;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
// Add the longest common prefix
|
||||
if (!expanded_prefix.empty() && matches.size() > 1) {
|
||||
const std::string_view match0(matches[0].first);
|
||||
const std::string_view match1(matches[1].first);
|
||||
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
|
||||
size_t len = it.first - match0.begin();
|
||||
|
||||
for (size_t i = 2; i < matches.size(); ++i) {
|
||||
const std::string_view matchi(matches[i].first);
|
||||
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
|
||||
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
|
||||
}
|
||||
|
||||
const std::string updated_line = std::string(match0.substr(0, len));
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
|
||||
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
|
||||
});
|
||||
}
|
||||
|
||||
return matches;
|
||||
}
|
||||
|
||||
// note: make this view implementation generic, so that we can move to TUI in the future if we want to
|
||||
namespace view {
|
||||
static void init(const common_params & params) {
|
||||
// TODO: avoid using atexit() here by making `console` a singleton
|
||||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
console::set_completion_callback(auto_completion_callback);
|
||||
}
|
||||
|
||||
struct spinner {
|
||||
spinner(const std::string & message) {
|
||||
if (!message.empty()) {
|
||||
console::log("%s ", message.c_str());
|
||||
}
|
||||
console::spinner::start();
|
||||
}
|
||||
~spinner() {
|
||||
console::spinner::stop();
|
||||
}
|
||||
};
|
||||
|
||||
struct user_turn {
|
||||
user_turn() {
|
||||
console::set_display(DISPLAY_TYPE_USER_INPUT);
|
||||
}
|
||||
~user_turn() {
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
void echo(const std::string & buffer) {
|
||||
if (buffer.size() > 500) {
|
||||
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
|
||||
} else {
|
||||
console::log("\n> %s\n", buffer.c_str());
|
||||
}
|
||||
}
|
||||
std::string read_input(bool multiline_input, const char * prompt = nullptr) {
|
||||
if (prompt) {
|
||||
console::log("%s", prompt);
|
||||
} else {
|
||||
console::log("\n> ");
|
||||
}
|
||||
std::string buffer;
|
||||
std::string line;
|
||||
bool another_line = true;
|
||||
do {
|
||||
another_line = console::readline(line, multiline_input);
|
||||
buffer += line;
|
||||
} while (another_line);
|
||||
return buffer;
|
||||
}
|
||||
};
|
||||
|
||||
enum assistant_display_mode {
|
||||
ASSISTANT_DISPLAY_MODE_REASONING,
|
||||
ASSISTANT_DISPLAY_MODE_CONTENT,
|
||||
};
|
||||
struct assistant_turn {
|
||||
assistant_display_mode mode = ASSISTANT_DISPLAY_MODE_CONTENT;
|
||||
bool trailing_newline = true;
|
||||
bool is_inside_reasoning = false;
|
||||
assistant_turn() {
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
~assistant_turn() {
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
add_newline_if_needed();
|
||||
}
|
||||
void push(assistant_display_mode m, const std::string & buffer) {
|
||||
if (m != mode) {
|
||||
add_newline_if_needed();
|
||||
switch (m) {
|
||||
case ASSISTANT_DISPLAY_MODE_CONTENT:
|
||||
{
|
||||
if (is_inside_reasoning) {
|
||||
console::log("[End thinking]\n\n");
|
||||
is_inside_reasoning = false;
|
||||
}
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
} break;
|
||||
case ASSISTANT_DISPLAY_MODE_REASONING:
|
||||
{
|
||||
console::set_display(DISPLAY_TYPE_REASONING);
|
||||
is_inside_reasoning = true;
|
||||
console::log("\n[Start thinking]\n\n");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
mode = m;
|
||||
if (buffer.empty()) {
|
||||
return;
|
||||
}
|
||||
trailing_newline = buffer.back() == '\n';
|
||||
console::log("%s", buffer.c_str());
|
||||
console::flush();
|
||||
}
|
||||
void add_newline_if_needed() {
|
||||
if (!trailing_newline) {
|
||||
console::log("\n");
|
||||
console::flush();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static void show_error(const std::string & title, const std::string & message = "") {
|
||||
console::spinner::stop();
|
||||
console::error("Error: %s\n", title.c_str());
|
||||
if (!message.empty()) {
|
||||
console::log("%s\n", message.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
static void show_message(const std::string & message) {
|
||||
console::log("%s\n", message.c_str());
|
||||
}
|
||||
|
||||
static void show_info(const std::string & message) {
|
||||
console::set_display(DISPLAY_TYPE_INFO);
|
||||
console::log("%s\n", message.c_str());
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
}
|
||||
+10
-624
@@ -1,20 +1,10 @@
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "arg.h"
|
||||
#include "console.h"
|
||||
#include "fit.h"
|
||||
// #include "log.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include "server-common.h"
|
||||
#include "server-context.h"
|
||||
#include "server-task.h"
|
||||
#include "cli-context.h"
|
||||
#include "cli-view.h"
|
||||
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <thread>
|
||||
#include <signal.h>
|
||||
|
||||
#if defined(_WIN32)
|
||||
@@ -25,342 +15,19 @@
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
const char * LLAMA_ASCII_LOGO = R"(
|
||||
▄▄ ▄▄
|
||||
██ ██
|
||||
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
|
||||
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
|
||||
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
|
||||
██ ██
|
||||
▀▀ ▀▀
|
||||
)";
|
||||
|
||||
static std::atomic<bool> g_is_interrupted = false;
|
||||
static bool should_stop() {
|
||||
return g_is_interrupted.load();
|
||||
}
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void signal_handler(int) {
|
||||
if (g_is_interrupted.load()) {
|
||||
if (g_cli_interrupted.load()) {
|
||||
// second Ctrl+C - exit immediately
|
||||
// make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
|
||||
fprintf(stdout, "\033[0m\n");
|
||||
fflush(stdout);
|
||||
std::exit(130);
|
||||
}
|
||||
g_is_interrupted.store(true);
|
||||
g_cli_interrupted.store(true);
|
||||
}
|
||||
#endif
|
||||
|
||||
struct cli_context {
|
||||
server_context ctx_server;
|
||||
json messages = json::array();
|
||||
std::vector<raw_buffer> input_files;
|
||||
task_params defaults;
|
||||
bool verbose_prompt;
|
||||
|
||||
// thread for showing "loading" animation
|
||||
std::atomic<bool> loading_show;
|
||||
|
||||
cli_context(const common_params & params) {
|
||||
defaults.sampling = params.sampling;
|
||||
defaults.speculative = params.speculative;
|
||||
defaults.n_keep = params.n_keep;
|
||||
defaults.n_predict = params.n_predict;
|
||||
defaults.antiprompt = params.antiprompt;
|
||||
|
||||
defaults.stream = true; // make sure we always use streaming mode
|
||||
defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
|
||||
// defaults.return_progress = true; // TODO: show progress
|
||||
|
||||
verbose_prompt = params.verbose_prompt;
|
||||
}
|
||||
|
||||
std::string generate_completion(result_timings & out_timings) {
|
||||
server_response_reader rd = ctx_server.get_response_reader();
|
||||
auto chat_params = format_chat();
|
||||
{
|
||||
// TODO: reduce some copies here in the future
|
||||
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
|
||||
task.id = rd.get_new_id();
|
||||
task.index = 0;
|
||||
task.params = defaults; // copy
|
||||
task.cli_prompt = chat_params.prompt; // copy
|
||||
task.cli_files = input_files; // copy
|
||||
task.cli = true;
|
||||
|
||||
// chat template settings
|
||||
task.params.chat_parser_params = common_chat_parser_params(chat_params);
|
||||
task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
if (!chat_params.parser.empty()) {
|
||||
task.params.chat_parser_params.parser.load(chat_params.parser);
|
||||
}
|
||||
|
||||
// Copy the preserved tokens into the sampling params
|
||||
const llama_vocab * vocab = llama_model_get_vocab(
|
||||
llama_get_model(ctx_server.get_llama_context()));
|
||||
for (const auto & token : chat_params.preserved_tokens) {
|
||||
auto ids = common_tokenize(vocab, token, false, true);
|
||||
if (ids.size() == 1) {
|
||||
task.params.sampling.preserved_tokens.insert(ids[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler
|
||||
if (!chat_params.thinking_end_tag.empty()) {
|
||||
task.params.sampling.reasoning_budget_tokens = defaults.sampling.reasoning_budget_tokens;
|
||||
task.params.sampling.generation_prompt = chat_params.generation_prompt;
|
||||
|
||||
if (!chat_params.thinking_start_tag.empty()) {
|
||||
task.params.sampling.reasoning_budget_start =
|
||||
common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
|
||||
}
|
||||
task.params.sampling.reasoning_budget_end =
|
||||
common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
|
||||
task.params.sampling.reasoning_budget_forced =
|
||||
common_tokenize(vocab, defaults.sampling.reasoning_budget_message + chat_params.thinking_end_tag, false, true);
|
||||
}
|
||||
|
||||
rd.post_task({std::move(task)});
|
||||
}
|
||||
|
||||
if (verbose_prompt) {
|
||||
console::set_display(DISPLAY_TYPE_PROMPT);
|
||||
console::log("%s\n\n", chat_params.prompt.c_str());
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
|
||||
// wait for first result
|
||||
console::spinner::start();
|
||||
server_task_result_ptr result = rd.next(should_stop);
|
||||
|
||||
while (true) {
|
||||
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
|
||||
if (res_partial && res_partial->is_begin) {
|
||||
// this is the "send 200 status to client" signal in streaming mode
|
||||
// skip, do not stop the spinner
|
||||
result = rd.next(should_stop);
|
||||
} else {
|
||||
console::spinner::stop();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::string curr_content;
|
||||
bool is_thinking = false;
|
||||
|
||||
while (result) {
|
||||
if (should_stop()) {
|
||||
break;
|
||||
}
|
||||
if (result->is_error()) {
|
||||
json err_data = result->to_json();
|
||||
if (err_data.contains("message")) {
|
||||
console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
|
||||
} else {
|
||||
console::error("Error: %s\n", err_data.dump().c_str());
|
||||
}
|
||||
return curr_content;
|
||||
}
|
||||
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
|
||||
if (res_partial) {
|
||||
out_timings = std::move(res_partial->timings);
|
||||
for (const auto & diff : res_partial->oaicompat_msg_diffs) {
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (is_thinking) {
|
||||
console::log("\n[End thinking]\n\n");
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
is_thinking = false;
|
||||
}
|
||||
curr_content += diff.content_delta;
|
||||
console::log("%s", diff.content_delta.c_str());
|
||||
console::flush();
|
||||
}
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
console::set_display(DISPLAY_TYPE_REASONING);
|
||||
if (!is_thinking) {
|
||||
console::log("[Start thinking]\n");
|
||||
}
|
||||
is_thinking = true;
|
||||
console::log("%s", diff.reasoning_content_delta.c_str());
|
||||
console::flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
|
||||
if (res_final) {
|
||||
out_timings = std::move(res_final->timings);
|
||||
break;
|
||||
}
|
||||
result = rd.next(should_stop);
|
||||
}
|
||||
g_is_interrupted.store(false);
|
||||
// server_response_reader automatically cancels pending tasks upon destruction
|
||||
return curr_content;
|
||||
}
|
||||
|
||||
// TODO: support remote files in the future (http, https, etc)
|
||||
std::string load_input_file(const std::string & fname, bool is_media) {
|
||||
std::ifstream file = fs_open_ifstream(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
return "";
|
||||
}
|
||||
if (is_media) {
|
||||
raw_buffer buf;
|
||||
buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
input_files.push_back(std::move(buf));
|
||||
return get_media_marker();
|
||||
} else {
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
return content;
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_params format_chat() {
|
||||
auto meta = ctx_server.get_meta();
|
||||
auto & chat_params = meta.chat_params;
|
||||
|
||||
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = {}; // TODO
|
||||
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
inputs.json_schema = ""; // TODO
|
||||
inputs.grammar = ""; // TODO
|
||||
inputs.use_jinja = chat_params.use_jinja;
|
||||
inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"];
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
inputs.force_pure_content = chat_params.force_pure_content;
|
||||
inputs.enable_thinking = chat_params.enable_thinking ? common_chat_templates_support_enable_thinking(chat_params.tmpls.get()) : false;
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO?: Make this reusable, enums, docs
|
||||
static const std::array<std::string_view, 8> cmds = {
|
||||
"/audio ",
|
||||
"/clear",
|
||||
"/exit",
|
||||
"/glob ",
|
||||
"/image ",
|
||||
"/read ",
|
||||
"/regen",
|
||||
"/video ",
|
||||
};
|
||||
|
||||
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
|
||||
std::vector<std::pair<std::string, size_t>> matches;
|
||||
std::string cmd;
|
||||
|
||||
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return string_starts_with(line, prefix);
|
||||
})) {
|
||||
auto it = cmds.begin();
|
||||
|
||||
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
|
||||
return string_starts_with(cmd_line, line);
|
||||
})) != cmds.end()) {
|
||||
matches.emplace_back(*it, it->length());
|
||||
++it;
|
||||
}
|
||||
} else {
|
||||
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return prefix.back() == ' ' && string_starts_with(line, prefix);
|
||||
});
|
||||
|
||||
if (it != cmds.end()) {
|
||||
cmd = *it;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
|
||||
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
|
||||
auto cur_dir = std::filesystem::current_path();
|
||||
std::string cur_dir_str = cur_dir.string();
|
||||
std::string expanded_prefix = path_prefix;
|
||||
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(path_prefix, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
expanded_prefix = home + path_prefix.substr(1);
|
||||
}
|
||||
}
|
||||
if (string_starts_with(expanded_prefix, '/')) {
|
||||
#else
|
||||
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
|
||||
#endif
|
||||
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
|
||||
cur_dir_str.clear();
|
||||
} else if (!path_prefix.empty()) {
|
||||
cur_dir /= std::filesystem::path(path_prefix).parent_path();
|
||||
}
|
||||
|
||||
std::error_code ec;
|
||||
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
|
||||
if (ec) {
|
||||
break;
|
||||
}
|
||||
if (!entry.exists(ec)) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string path_full = entry.path().string();
|
||||
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
|
||||
|
||||
if (entry.is_directory(ec)) {
|
||||
path_entry.push_back(std::filesystem::path::preferred_separator);
|
||||
}
|
||||
|
||||
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
|
||||
const std::string updated_line = cmd + path_entry;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (matches.empty()) {
|
||||
const std::string updated_line = cmd + path_prefix;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
// Add the longest common prefix
|
||||
if (!expanded_prefix.empty() && matches.size() > 1) {
|
||||
const std::string_view match0(matches[0].first);
|
||||
const std::string_view match1(matches[1].first);
|
||||
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
|
||||
size_t len = it.first - match0.begin();
|
||||
|
||||
for (size_t i = 2; i < matches.size(); ++i) {
|
||||
const std::string_view matchi(matches[i].first);
|
||||
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
|
||||
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
|
||||
}
|
||||
|
||||
const std::string updated_line = std::string(match0.substr(0, len));
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
|
||||
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
|
||||
});
|
||||
}
|
||||
|
||||
return matches;
|
||||
}
|
||||
|
||||
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
|
||||
|
||||
// satisfies -Wmissing-declarations
|
||||
int llama_cli(int argc, char ** argv);
|
||||
|
||||
@@ -375,25 +42,6 @@ int llama_cli(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// TODO: maybe support it later?
|
||||
if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
|
||||
console::error("--no-conversation is not supported by llama-cli\n");
|
||||
console::error("please use llama-completion instead\n");
|
||||
}
|
||||
|
||||
// struct that contains llama context and inference
|
||||
cli_context ctx_cli(params);
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// TODO: avoid using atexit() here by making `console` a singleton
|
||||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
console::set_completion_callback(auto_completion_callback);
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
@@ -408,273 +56,11 @@ int llama_cli(int argc, char ** argv) {
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
console::log("\nLoading model... "); // followed by loading animation
|
||||
console::spinner::start();
|
||||
if (!ctx_cli.ctx_server.load_model(params)) {
|
||||
console::spinner::stop();
|
||||
console::error("\nFailed to load the model\n");
|
||||
cli_context ctx_cli(params);
|
||||
|
||||
if (!ctx_cli.init()) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ctx_cli.defaults.sampling = params.sampling;
|
||||
|
||||
console::spinner::stop();
|
||||
console::log("\n");
|
||||
|
||||
std::thread inference_thread([&ctx_cli]() {
|
||||
ctx_cli.ctx_server.start_loop();
|
||||
});
|
||||
|
||||
auto inf = ctx_cli.ctx_server.get_meta();
|
||||
std::string modalities = "text";
|
||||
if (inf.has_inp_image) {
|
||||
modalities += ", vision";
|
||||
}
|
||||
if (inf.has_inp_audio) {
|
||||
modalities += ", audio";
|
||||
}
|
||||
|
||||
auto add_system_prompt = [&]() {
|
||||
if (!params.system_prompt.empty()) {
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", params.system_prompt}
|
||||
});
|
||||
}
|
||||
};
|
||||
add_system_prompt();
|
||||
|
||||
console::log("\n");
|
||||
console::log("%s\n", LLAMA_ASCII_LOGO);
|
||||
console::log("build : %s\n", inf.build_info.c_str());
|
||||
console::log("model : %s\n", inf.model_name.c_str());
|
||||
console::log("modalities : %s\n", modalities.c_str());
|
||||
if (!params.system_prompt.empty()) {
|
||||
console::log("using custom system prompt\n");
|
||||
}
|
||||
console::log("\n");
|
||||
console::log("available commands:\n");
|
||||
console::log(" /exit or Ctrl+C stop or exit\n");
|
||||
console::log(" /regen regenerate the last response\n");
|
||||
console::log(" /clear clear the chat history\n");
|
||||
console::log(" /read <file> add a text file\n");
|
||||
console::log(" /glob <pattern> add text files using globbing pattern\n");
|
||||
if (inf.has_inp_image) {
|
||||
console::log(" /image <file> add an image file\n");
|
||||
}
|
||||
if (inf.has_inp_audio) {
|
||||
console::log(" /audio <file> add an audio file\n");
|
||||
}
|
||||
if (inf.has_inp_video) {
|
||||
console::log(" /video <file> add a video file\n");
|
||||
}
|
||||
console::log("\n");
|
||||
|
||||
// interactive loop
|
||||
std::string cur_msg;
|
||||
|
||||
auto add_text_file = [&](const std::string & fname) -> bool {
|
||||
std::string marker = ctx_cli.load_input_file(fname, false);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
return false;
|
||||
}
|
||||
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
|
||||
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
|
||||
cur_msg += fname;
|
||||
cur_msg.push_back('\n');
|
||||
} else {
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded text from '%s'\n", fname.c_str());
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
std::string buffer;
|
||||
console::set_display(DISPLAY_TYPE_USER_INPUT);
|
||||
if (params.prompt.empty()) {
|
||||
console::log("\n> ");
|
||||
std::string line;
|
||||
bool another_line = true;
|
||||
do {
|
||||
another_line = console::readline(line, params.multiline_input);
|
||||
buffer += line;
|
||||
} while (another_line);
|
||||
} else {
|
||||
// process input prompt from args
|
||||
for (auto & fname : params.image) {
|
||||
std::string marker = ctx_cli.load_input_file(fname, true);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
break;
|
||||
}
|
||||
console::log("Loaded media from '%s'\n", fname.c_str());
|
||||
cur_msg += marker;
|
||||
}
|
||||
buffer = params.prompt;
|
||||
if (buffer.size() > 500) {
|
||||
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
|
||||
} else {
|
||||
console::log("\n> %s\n", buffer.c_str());
|
||||
}
|
||||
params.prompt.clear(); // only use it once
|
||||
}
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
console::log("\n");
|
||||
|
||||
if (should_stop()) {
|
||||
g_is_interrupted.store(false);
|
||||
break;
|
||||
}
|
||||
|
||||
// remove trailing newline
|
||||
if (!buffer.empty() &&buffer.back() == '\n') {
|
||||
buffer.pop_back();
|
||||
}
|
||||
|
||||
// skip empty messages
|
||||
if (buffer.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool add_user_msg = true;
|
||||
|
||||
// process commands
|
||||
if (string_starts_with(buffer, "/exit")) {
|
||||
break;
|
||||
} else if (string_starts_with(buffer, "/regen")) {
|
||||
if (ctx_cli.messages.size() >= 2) {
|
||||
size_t last_idx = ctx_cli.messages.size() - 1;
|
||||
ctx_cli.messages.erase(last_idx);
|
||||
add_user_msg = false;
|
||||
} else {
|
||||
console::error("No message to regenerate.\n");
|
||||
continue;
|
||||
}
|
||||
} else if (string_starts_with(buffer, "/clear")) {
|
||||
ctx_cli.messages.clear();
|
||||
add_system_prompt();
|
||||
|
||||
ctx_cli.input_files.clear();
|
||||
console::log("Chat history cleared.\n");
|
||||
continue;
|
||||
} else if (
|
||||
(string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
|
||||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio) ||
|
||||
(string_starts_with(buffer, "/video ") && inf.has_inp_video)) {
|
||||
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
|
||||
std::string fname = string_strip(buffer.substr(7));
|
||||
std::string marker = ctx_cli.load_input_file(fname, true);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded media from '%s'\n", fname.c_str());
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/read ")) {
|
||||
std::string fname = string_strip(buffer.substr(6));
|
||||
add_text_file(fname);
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/glob ")) {
|
||||
std::error_code ec;
|
||||
size_t count = 0;
|
||||
auto curdir = std::filesystem::current_path();
|
||||
std::string pattern = string_strip(buffer.substr(6));
|
||||
std::filesystem::path rel_path;
|
||||
|
||||
auto startglob = pattern.find_first_of("![*?");
|
||||
if (startglob != std::string::npos && startglob != 0) {
|
||||
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
||||
if (endpath != std::string::npos) {
|
||||
std::string rel_pattern = pattern.substr(0, endpath);
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(rel_pattern, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
rel_pattern = home + rel_pattern.substr(1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
rel_path = rel_pattern;
|
||||
pattern.erase(0, endpath + 1);
|
||||
curdir /= rel_path;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
||||
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(pattern, rel)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!add_text_file((rel_path / rel).string())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
||||
console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS);
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
// not a command
|
||||
cur_msg += buffer;
|
||||
}
|
||||
|
||||
// generate response
|
||||
if (add_user_msg) {
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", cur_msg}
|
||||
});
|
||||
cur_msg.clear();
|
||||
}
|
||||
result_timings timings;
|
||||
std::string assistant_content = ctx_cli.generate_completion(timings);
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "assistant"},
|
||||
{"content", assistant_content}
|
||||
});
|
||||
console::log("\n");
|
||||
|
||||
if (params.show_timings) {
|
||||
console::set_display(DISPLAY_TYPE_INFO);
|
||||
console::log("\n");
|
||||
console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
|
||||
if (params.single_turn) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
|
||||
console::log("\nExiting...\n");
|
||||
ctx_cli.ctx_server.terminate();
|
||||
inference_thread.join();
|
||||
|
||||
// bump the log level to display timings
|
||||
common_log_set_verbosity_thold(LOG_LEVEL_INFO);
|
||||
common_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());
|
||||
|
||||
return 0;
|
||||
return ctx_cli.run();
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "build-info.h"
|
||||
#include "preset.h"
|
||||
#include "download.h"
|
||||
#include "http.h"
|
||||
|
||||
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
|
||||
#include <sheredom/subprocess.h>
|
||||
@@ -25,14 +26,7 @@
|
||||
#include <sstream>
|
||||
#include <cstring>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <unistd.h>
|
||||
#ifndef _WIN32
|
||||
extern char **environ;
|
||||
#endif
|
||||
|
||||
@@ -704,66 +698,6 @@ std::optional<server_model_meta> server_models::get_meta(const std::string & nam
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static int get_free_port() {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
|
||||
return -1;
|
||||
}
|
||||
typedef SOCKET native_socket_t;
|
||||
#define INVALID_SOCKET_VAL INVALID_SOCKET
|
||||
#define CLOSE_SOCKET(s) closesocket(s)
|
||||
#else
|
||||
typedef int native_socket_t;
|
||||
#define INVALID_SOCKET_VAL -1
|
||||
#define CLOSE_SOCKET(s) close(s)
|
||||
#endif
|
||||
|
||||
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock == INVALID_SOCKET_VAL) {
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct sockaddr_in serv_addr;
|
||||
std::memset(&serv_addr, 0, sizeof(serv_addr));
|
||||
serv_addr.sin_family = AF_INET;
|
||||
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
serv_addr.sin_port = htons(0);
|
||||
|
||||
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
int namelen = sizeof(serv_addr);
|
||||
#else
|
||||
socklen_t namelen = sizeof(serv_addr);
|
||||
#endif
|
||||
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
int port = ntohs(serv_addr.sin_port);
|
||||
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
|
||||
return port;
|
||||
}
|
||||
|
||||
// helper to convert vector<string> to char **
|
||||
// pointers are only valid as long as the original vector is valid
|
||||
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
|
||||
@@ -867,7 +801,7 @@ void server_models::load(const std::string & name, const load_options & opts) {
|
||||
// prepare new instance info
|
||||
instance_t inst;
|
||||
inst.meta = meta;
|
||||
inst.meta.port = get_free_port();
|
||||
inst.meta.port = common_http_get_free_port();
|
||||
inst.meta.status = SERVER_MODEL_STATUS_LOADING;
|
||||
inst.meta.loaded_info = json{};
|
||||
inst.meta.last_used = ggml_time_ms();
|
||||
|
||||
+35
-16
@@ -35,6 +35,19 @@ static inline void signal_handler(int signal) {
|
||||
shutdown_handler(signal);
|
||||
}
|
||||
|
||||
// satisfies -Wmissing-declarations (used by llama command)
|
||||
int llama_server(int argc, char ** argv);
|
||||
|
||||
// to be used via CLI (argc / argv are used by router mode only)
|
||||
int llama_server(common_params & params, int argc, char ** argv);
|
||||
void llama_server_terminate();
|
||||
void llama_server_terminate() {
|
||||
if (shutdown_handler) {
|
||||
shutdown_handler(0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// wrapper function that handles exceptions and logs errors
|
||||
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
|
||||
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
|
||||
@@ -71,9 +84,6 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
|
||||
};
|
||||
}
|
||||
|
||||
// satisfies -Wmissing-declarations
|
||||
int llama_server(int argc, char ** argv);
|
||||
|
||||
int llama_server(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
@@ -89,8 +99,14 @@ int llama_server(int argc, char ** argv) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
return llama_server(params, argc, argv);
|
||||
}
|
||||
|
||||
int llama_server(common_params & params, int argc, char ** argv) {
|
||||
bool is_run_by_cli = (argv == nullptr);
|
||||
|
||||
// note: router mode also accepts -hf remote-preset, so we need to check that first
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
if (!is_run_by_cli && !params.model.hf_repo.empty()) {
|
||||
try {
|
||||
common_params_handle_models_params handle_params;
|
||||
handle_params.preset_only = true;
|
||||
@@ -272,8 +288,9 @@ int llama_server(int argc, char ** argv) {
|
||||
|
||||
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
return child.run_download(params);
|
||||
} else if (!is_router_server) {
|
||||
} else if (!is_router_server && !is_run_by_cli) {
|
||||
// single-model mode (NOT spawned by router)
|
||||
// if this is invoked by CLI, model downloading should be already handled
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
|
||||
}
|
||||
|
||||
@@ -356,20 +373,22 @@ int llama_server(int argc, char ** argv) {
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: refactor in common/console
|
||||
// register signal handler if not running by CLI
|
||||
if (!is_run_by_cli) {
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
sigaction(SIGTERM, &sigint_action, NULL);
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
sigaction(SIGTERM, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (is_router_server) {
|
||||
SRV_INF("router server is listening on %s\n", ctx_http.listening_address.c_str());
|
||||
|
||||
Reference in New Issue
Block a user