Compare commits

...

3 Commits

Author SHA1 Message Date
Matt Thompson dec5ca5577 server : Add id to tool call responses api (#24882) 2026-06-22 23:03:12 +02:00
Mahdiou Diallo 9c0ac887f3 ui: Prioritize favorite models in model selection (#24766)
Updated model selection prioritization to include favorite models.
2026-06-22 21:00:21 +02:00
Xuan-Son Nguyen 721354fbdf server: (router) move model downloading to dedicated process (#24834)
* server: real-time model load progress tracking via /models/sse

* update docs

* server: move model download to child process

* rm unused

* fix most problems

* clean up

* nit fixes

* fix test case

* do not detach() thread

* shorter MODEL_DOWNLOAD_TIMEOUT in test

* throttle
2026-06-22 18:24:04 +02:00
11 changed files with 327 additions and 156 deletions
+10 -4
View File
@@ -396,7 +396,7 @@ static bool parse_bool_value(const std::string & value) {
// CLI argument parsing functions
//
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
@@ -408,6 +408,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
if (callback) {
opts.callback = callback;
}
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
// so we should not auto-discover mtp/mmproj siblings for them
common_download_opts sub_opts = opts;
@@ -584,8 +588,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
const bool skip_model_download =
// server will call common_params_handle_models() later, so we skip it here
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
// export_graph_ops loads only metadata
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
if (!skip_model_download) {
// handle model and download
@@ -594,7 +601,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty()
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
&& !params.usage
&& !params.completion) {
throw std::invalid_argument("error: --model is required\n");
+5 -1
View File
@@ -1,6 +1,7 @@
#pragma once
#include "common.h"
#include "download.h"
#include <set>
#include <map>
@@ -133,7 +134,10 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
// return true if the model is ready to use
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
bool common_params_handle_models(common_params & params, llama_example curr_ex);
bool common_params_handle_models(
common_params & params,
llama_example curr_ex,
common_download_callback * callback = nullptr);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
+3 -3
View File
@@ -204,9 +204,9 @@ Instead of building everything from the ground up (like what most AI agents will
The flow for downloading a new model:
- POST request comes in --> `post_router_models` --> validation
- `server_models::download()` is called
- Sets up a new thread `inst.th` and runs the download inside
- If a stop request comes in, set `stop_download` to `true`
- A new `llama-server` subprocess is spawned in the special `SERVER_CHILD_MODE_DOWNLOAD` mode
- The child process runs the download and reports its status back to the router via stdin/stdout
- If a stop request comes in, the router asks the child process to stop (the same mechanism used when running a model in a child process)
- Otherwise, upon completion, we call `load_models()` to refresh the list of models
### Notable Related PRs
+6
View File
@@ -931,6 +931,8 @@ private:
bool sleeping = false;
int64_t t_last_load_progress_ms = 0;
void destroy() {
spec.reset();
ctx_dft.reset();
@@ -1244,6 +1246,10 @@ private:
}
if (has_mmproj) {
if (callback_state) {
callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}});
}
if (!is_resume) {
mtmd_helper_log_set(common_log_default_callback, nullptr);
}
+3 -1
View File
@@ -53,7 +53,7 @@ struct server_context_meta {
};
enum server_state {
// SERVER_STATE_DOWNLOADING,
SERVER_STATE_DOWNLOADING,
SERVER_STATE_LOADING,
SERVER_STATE_READY,
SERVER_STATE_SLEEPING,
@@ -61,6 +61,7 @@ enum server_state {
static std::string server_state_to_str(server_state state) {
switch (state) {
case SERVER_STATE_DOWNLOADING: return "downloading";
case SERVER_STATE_LOADING: return "loading";
case SERVER_STATE_READY: return "ready";
case SERVER_STATE_SLEEPING: return "sleeping";
@@ -69,6 +70,7 @@ static std::string server_state_to_str(server_state state) {
}
static server_state server_state_from_str(const std::string & str) {
if (str == "downloading") return SERVER_STATE_DOWNLOADING;
if (str == "loading") return SERVER_STATE_LOADING;
if (str == "ready") return SERVER_STATE_READY;
if (str == "sleeping") return SERVER_STATE_SLEEPING;
+230 -130
View File
@@ -64,6 +64,17 @@ struct server_subproc {
return sproc.has_value() && subprocess_alive(&sproc.value());
}
// Ask the child process to exit by writing CMD_ROUTER_TO_CHILD_EXIT (plus a
// newline) to its stdin, then raise the `stopped` flag.
// The flag is raised unconditionally, i.e. also when no subprocess exists or
// when its stdin handle is unavailable.
void request_exit() {
    FILE * child_stdin = sproc.has_value() ? subprocess_stdin(&sproc.value()) : nullptr;
    if (child_stdin != nullptr) {
        fprintf(child_stdin, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
        // flush right away so the command is not left sitting in the stdio buffer
        fflush(child_stdin);
    }
    stopped.store(true, std::memory_order_relaxed);
}
void terminate() {
if (!sproc.has_value()) {
return;
@@ -323,7 +334,7 @@ void server_models::notify_sse(const std::string & event, const std::string & mo
}
void server_models::load_models() {
// Phase 1: load presets from all sources pure I/O, no lock needed
// Phase 1: load presets from all sources - pure I/O, no lock needed
// 1. cached models
common_presets cached_models = ctx_preset.load_from_cache();
SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
@@ -376,7 +387,7 @@ void server_models::load_models() {
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
};
// Helpers that read `mapping` must be called while holding the lock.
// Helpers that read `mapping` - must be called while holding the lock.
std::unordered_set<std::string> custom_names;
for (const auto & [name, preset] : custom_presets) custom_names.insert(name);
auto join_set = [](const std::set<std::string> & s) {
@@ -523,7 +534,7 @@ void server_models::load_models() {
}
}
// join outside the lock monitoring thread calls update_status (needs lock)
// join outside the lock - monitoring thread calls update_status (needs lock)
lk.unlock();
for (auto & th : threads_to_join) th.join();
lk.lock();
@@ -622,7 +633,7 @@ void server_models::load_models() {
apply_stop_timeout();
// clear reload flag before unlocking for autoload load() blocks on !is_reloading,
// clear reload flag before unlocking for autoload - load() blocks on !is_reloading,
// so clearing it here (while still locked) prevents a deadlock in the autoload calls below
is_reloading = false;
cv.notify_all();
@@ -815,17 +826,23 @@ void server_models::unload_lru() {
}
void server_models::load(const std::string & name) {
if (!has_model(name)) {
throw std::runtime_error("model name=" + name + " is not found");
load(name, load_options{});
}
void server_models::load(const std::string & name, const load_options & opts) {
if (!opts.custom_meta.has_value()) {
if (!has_model(name)) {
throw std::runtime_error("model name=" + name + " is not found");
}
unload_lru();
}
unload_lru();
std::unique_lock<std::mutex> lk(mutex);
// edge case: block until any in-progress reload has finished so we always load
// against the freshest preset and a consistent mapping state
cv.wait(lk, [this]() { return !is_reloading; });
auto meta = mapping[name].meta;
auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta;
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
SRV_INF("model %s is not ready\n", name.c_str());
return;
@@ -869,6 +886,12 @@ void server_models::load(const std::string & name) {
std::vector<std::string> child_env = base_env; // copy
child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) {
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
child_env.push_back("LLAMA_SERVER_CHILD_MODE=download");
child_env.push_back("LLAMA_ARG_HF_REPO=" + name);
}
SRV_INF("%s", "spawning server instance with args:\n");
for (const auto & arg : child_args) {
SRV_INF(" %s\n", arg.c_str());
@@ -886,13 +909,17 @@ void server_models::load(const std::string & name) {
if (result != 0) {
throw std::runtime_error("failed to spawn server instance");
}
inst.stdin_file = subprocess_stdin(&inst.subproc->get());
}
// start a thread to manage the child process
// captured variables are guaranteed to be destroyed only after the thread is joined
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
inst.th = std::thread([
this, name,
child_proc = inst.subproc,
port = inst.meta.port,
stop_timeout = inst.meta.stop_timeout,
child_mode = opts.mode
]() {
FILE * stdin_file = subprocess_stdin(&child_proc->get());
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
@@ -925,7 +952,7 @@ void server_models::load(const std::string & name) {
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
});
}
// child crashed or finished on its own skip graceful shutdown sequence
// child crashed or finished on its own, skip graceful shutdown sequence
if (child_proc->stopped.load(std::memory_order_acquire)) {
return;
}
@@ -973,10 +1000,14 @@ void server_models::load(const std::string & name) {
subprocess_destroy(&child_proc->get());
// update status and exit code
this->update_status(name, {
SERVER_MODEL_STATUS_UNLOADED,
exit_code
});
if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) {
// instance will be cleaned up on next load_models() call
} else {
this->update_status(name, {
SERVER_MODEL_STATUS_UNLOADED,
exit_code
});
}
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
});
@@ -984,7 +1015,7 @@ void server_models::load(const std::string & name) {
{
auto & old_instance = mapping[name];
// old process should have exited already, but just in case, we clean it up here
if (old_instance.subproc->is_alive()) {
if (old_instance.subproc && old_instance.subproc->is_alive()) {
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
old_instance.subproc->terminate(); // force kill
}
@@ -1001,92 +1032,13 @@ void server_models::load(const std::string & name) {
cv.notify_all();
}
// callback for model downloading functionality
struct server_models_download_res : public common_download_callback {
common_params_model model;
common_download_opts opts;
std::function<bool()> should_stop;
std::function<void(const common_download_progress & p)> on_progress;
bool is_ok = false;
bool run() {
try {
common_download_model(model, opts);
is_ok = true;
} catch (const std::exception & e) {
auto model_name = model.get_name();
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
is_ok = false;
}
return is_ok;
}
void on_start(const common_download_progress & p) override {
on_progress(p);
}
void on_update(const common_download_progress & p) override {
on_progress(p);
}
void on_done(const common_download_progress &, bool ok) override {
is_ok = ok;
}
bool is_cancelled() const override {
return should_stop();
}
};
void server_models::download(common_params_model && model, common_download_opts && opts) {
std::string name = model.get_name();
GGML_ASSERT(name == model.hf_repo);
std::unique_lock<std::mutex> lk(mutex);
if (mapping.find(name) != mapping.end()) {
throw std::runtime_error("model name=" + name + " already exists");
}
instance_t inst;
inst.meta.name = name;
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
inst.subproc = std::make_shared<server_subproc>();
auto dl = std::make_unique<server_models_download_res>();
dl->model = model; // copy
dl->opts = opts; // copy
dl->should_stop = [sp = inst.subproc]() {
return sp->stopped.load(std::memory_order_relaxed);
};
dl->on_progress = [this, name](const common_download_progress & p) {
update_download_progress(name, p, false);
};
inst.th = std::thread([this, dl = std::move(dl)]() {
dl->opts.callback = dl.get();
bool ok = dl->run();
auto model_name = dl->model.get_name();
SRV_INF("download finished for model name=%s with status=%s\n",
model_name.c_str(), ok ? "success" : "failure");
update_download_progress(model_name, {}, true, ok);
// need_reload is set inside update_download_progress under the mutex;
// the next load_models() call will clean up this instance
});
mapping[name] = std::move(inst);
notify_sse("status_update", name, {
{"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)},
});
cv.notify_all();
}
void server_models::unload(const std::string & name) {
std::unique_lock<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
SRV_INF("cancelling download for model name=%s\n", name.c_str());
it->second.subproc->stopped.store(true, std::memory_order_relaxed);
it->second.subproc->request_exit();
// for convenience, we wait the status change here
wait(lk, name, [](const server_model_meta & new_meta) {
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
@@ -1198,37 +1150,65 @@ void server_models::update_download_progress(const std::string & name, const com
}
bool server_models::remove(const std::string & name) {
auto meta = get_meta(name);
// do everything under one lock acquisition; avoid get_meta() /
// unload() because they can trigger load_models() which erases
// transient DOWNLOADING / DOWNLOADED entries as a side-effect
std::unique_lock<std::mutex> lk(mutex);
if (!meta.has_value()) {
auto it = mapping.find(name);
if (it == mapping.end()) {
throw std::runtime_error("model name=" + name + " is not found");
}
if (meta->source != SERVER_MODEL_SOURCE_CACHE) {
if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) {
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
}
unload(name); // cancel download or stop running instance
{
std::unique_lock<std::mutex> lk(mutex);
// a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED
wait(lk, name, [](const server_model_meta & new_meta) {
return new_meta.status == SERVER_MODEL_STATUS_UNLOADED
|| new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
});
// join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no
// longer acquires this mutex, so joining while holding it is safe
if (mapping[name].th.joinable()) {
mapping[name].th.join();
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
// cancel in-flight download
SRV_INF("cancelling download for model name=%s\n", name.c_str());
it->second.subproc->request_exit();
} else if (it->second.meta.is_running()) {
// stop running instance
SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name);
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
it->second.subproc->terminate();
}
// remove the model from disk (hold lock to prevent concurrent load)
bool ok = common_download_remove(name);
if (ok) {
mapping.erase(name);
}
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed");
notify_sse("model_remove", name, {});
return ok;
cv_stop.notify_all();
}
// wait until the monitoring thread finishes
wait(lk, name, [](const server_model_meta & meta) {
return meta.status == SERVER_MODEL_STATUS_UNLOADED
|| meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
});
// re-find after wait - load_models() may have erased the entry during the wait
it = mapping.find(name);
if (it == mapping.end()) {
// load_models() already joined the thread and erased the entry;
// we just need to clean up the cached files on disk
lk.unlock();
bool ok = common_download_remove(name);
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
notify_sse("model_remove", name, {});
return true;
}
// join before erasing - thread no longer acquires this mutex
if (it->second.th.joinable()) {
it->second.th.join();
}
// remove from disk (best-effort: cancelled downloads may have no cached files)
bool ok = common_download_remove(name);
mapping.erase(name);
if (!ok) {
SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str());
}
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
notify_sse("model_remove", name, {});
return true;
}
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
@@ -1243,7 +1223,9 @@ void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string &
return predicate(it->second.meta);
}
return false;
// model was removed from mapping by another code path (e.g. load_models()).
// nothing left to wait for - tell the caller to proceed.
return true;
});
}
@@ -1328,6 +1310,31 @@ void server_models::handle_child_state(const std::string & name, const std::stri
}
switch (state) {
case SERVER_STATE_DOWNLOADING:
{
std::string result = json_value(payload, "result", std::string());
std::string url = json_value(payload, "url", std::string());
auto request_exit = [&]() {
std::lock_guard<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
return it->second.subproc->request_exit();
}
};
if (result == "download_finished") {
update_download_progress(name, {}, true, true);
request_exit();
} else if (result == "download_failed") {
update_download_progress(name, {}, true, false);
request_exit();
} else if (!url.empty()) {
common_download_progress p;
p.url = url;
p.downloaded = json_value(payload, "downloaded", (size_t)0);
p.total = json_value(payload, "total", (size_t)0);
update_download_progress(name, p, false);
}
} break;
case SERVER_STATE_LOADING:
{
update_status(name, {
@@ -1366,6 +1373,90 @@ bool server_child::is_child() {
return router_port != nullptr;
}
// Determine the mode of this child process from the LLAMA_SERVER_CHILD_MODE
// environment variable: the exact value "download" selects
// SERVER_CHILD_MODE_DOWNLOAD; an unset variable or any other value selects
// SERVER_CHILD_MODE_NORMAL.
server_child_mode server_child::get_mode() {
    const char * raw_mode = std::getenv("LLAMA_SERVER_CHILD_MODE");
    const std::string requested(raw_mode != nullptr ? raw_mode : "");
    return requested == "download" ? SERVER_CHILD_MODE_DOWNLOAD
                                   : SERVER_CHILD_MODE_NORMAL;
}
struct server_download_state : public common_download_callback {
server_child * self;
std::function<bool()> should_stop;
std::atomic<int64_t> last_progress_time{0}; // multiple files downloading in different threads
bool is_ok = false;
server_download_state(server_child * s) : self(s) {}
bool run(common_params & params) {
try {
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
is_ok = true;
} catch (const std::exception & e) {
auto model_name = params.model.get_name();
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
is_ok = false;
}
return is_ok;
}
void on_progress(const common_download_progress & p) {
json data = {
{"url", p.url},
{"downloaded", p.downloaded},
{"total", p.total},
};
self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data);
}
void on_start(const common_download_progress & p) override {
on_progress(p);
}
void on_update(const common_download_progress & p) override {
int64_t now = ggml_time_ms();
// throttle progress updates to avoid flooding logs
if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) {
on_progress(p);
last_progress_time.store(now, std::memory_order_relaxed);
}
}
void on_done(const common_download_progress & p, bool) override {
on_progress(p);
}
bool is_cancelled() const override {
return should_stop ? should_stop() : false;
}
};
// Entry point of a child process running in download mode.
// Starts the stdin monitoring thread, runs the download, reports
// "download_finished" or "download_failed" to the router, then waits for the
// monitoring thread to end before returning.
// Returns 0 whether or not the download succeeded.
// NOTE(review): the `is_finished_downloading` member is documented as "set by
// run_download" but is not written here - confirm whether it is still needed.
int server_child::run_download(common_params & params) {
    auto cancel_flag = std::make_shared<std::atomic<bool>>(false);

    // the stdin monitor flips the flag when the router asks us to stop
    std::thread stdin_monitor = setup([cancel_flag](int) {
        cancel_flag->store(true, std::memory_order_relaxed);
    });

    server_download_state state(this);
    state.should_stop = [cancel_flag]() {
        return cancel_flag->load(std::memory_order_relaxed);
    };

    const bool succeeded = state.run(params);
    const char * outcome = succeeded ? "download_finished" : "download_failed";
    notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), {
        {"result", outcome},
    });

    // router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result
    if (stdin_monitor.joinable()) {
        stdin_monitor.join();
    }

    SRV_INF("download completed %s\n", succeeded ? "successfully" : "with errors");
    return 0;
}
std::thread server_child::setup(const std::function<void(int)> & shutdown_handler) {
// setup thread for monitoring stdin
return std::thread([shutdown_handler]() {
@@ -1639,7 +1730,7 @@ void server_models_routes::init_routes() {
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
return res;
}
if (!model->is_running()) {
if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) {
res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST));
return res;
}
@@ -1680,8 +1771,9 @@ void server_models_routes::init_routes() {
model.hf_repo = name;
opts.bearer_token = params.hf_token;
opts.download_mmproj = true;
opts.download_mtp = true;
// note: we only check main model, no need sidecar here
opts.download_mmproj = false;
opts.download_mtp = false;
// first, only check if the model is valid and can be downloaded
opts.skip_download = true;
@@ -1702,10 +1794,21 @@ void server_models_routes::init_routes() {
throw std::invalid_argument("model validation failed, unable to download");
}
// reject if model already exists
if (models.has_model(name)) {
throw std::invalid_argument("model '" + name + "' already exists");
}
// then, proceed with the actual download
opts.skip_download = false;
SRV_INF("starting download for model '%s'\n", name.c_str());
models.download(std::move(model), std::move(opts));
{
server_models::load_options load_opts;
load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD;
load_opts.custom_meta = server_model_meta{};
load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE;
load_opts.custom_meta->name = name;
models.load(name, load_opts);
}
res_ok(res, {{"success", true}});
return res;
@@ -1719,10 +1822,7 @@ void server_models_routes::init_routes() {
throw std::invalid_argument("model must be a non-empty string");
}
bool ok = models.remove(name);
if (!ok) {
throw std::runtime_error("failed to remove model '" + name + "'");
}
models.remove(name); // throws on error
res_ok(res, {{"success", true}});
return res;
+15 -5
View File
@@ -40,6 +40,11 @@ enum server_model_source {
SERVER_MODEL_SOURCE_CACHE,
};
// Operating mode of a child llama-server process spawned by the router.
// The router chooses it through server_models::load_options::mode; for the
// download mode it sets LLAMA_SERVER_CHILD_MODE=download in the child's
// environment, which the child reads back in server_child::get_mode().
enum server_child_mode {
SERVER_CHILD_MODE_NORMAL, // load the model and run normally
SERVER_CHILD_MODE_DOWNLOAD, // download the model and exit
};
static std::string server_model_status_to_string(server_model_status status) {
switch (status) {
case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
@@ -105,7 +110,6 @@ private:
std::shared_ptr<server_subproc> subproc; // shared between main thread and monitoring thread
std::thread th;
server_model_meta meta;
FILE * stdin_file = nullptr;
};
std::mutex mutex;
@@ -161,16 +165,19 @@ public:
// return a copy of all model metadata (thread-safe)
std::vector<server_model_meta> get_all_meta();
// Options for load(name, opts). Default-constructed options give the same
// behavior as load(name).
struct load_options {
// mode the spawned child process should run in
server_child_mode mode = SERVER_CHILD_MODE_NORMAL;
// used for spawning a downloading child process
// when set, load() uses this metadata instead of the existing mapping entry
// and does not require the model name to be known already
std::optional<server_model_meta> custom_meta = std::nullopt;
};
// load and unload model instances
// these functions are thread-safe
void load(const std::string & name);
void load(const std::string & name, const load_options & opts);
void unload(const std::string & name);
void unload_all();
// download a new model, progress is reported via SSE
// to stop the download, call unload()
void download(common_params_model && model, common_download_opts && opts);
struct update_status_args {
server_model_status status;
int exit_code = 0; // only valid if status == UNLOADED
@@ -213,9 +220,12 @@ public:
struct server_child {
// serializes the notify_to_router writes
std::mutex mtx_stdout;
std::atomic<bool> is_finished_downloading = false; // set by run_download
// return true if the current process is a child server instance
bool is_child();
server_child_mode get_mode();
int run_download(common_params & params);
// register the shutdown_handler to be called by the router
// return the monitoring thread (to be joined by the caller)
+6 -3
View File
@@ -591,10 +591,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
output.push_back(json {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "fc_" + tool_call.id},
{"call_id", "call_" + tool_call.id},
{"name", tool_call.name},
});
}
@@ -690,10 +691,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
const json output_item = {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "fc_" + tool_call.id},
{"call_id", "call_" + tool_call.id},
{"name", tool_call.name}
};
server_sent_events.push_back(json {
@@ -1277,8 +1279,9 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
{"data", json {
{"type", "response.output_item.added"},
{"item", json {
{"id", "fc_" + diff.tool_call_delta.id},
{"arguments", ""},
{"call_id", "fc_" + diff.tool_call_delta.id},
{"call_id", "call_" + diff.tool_call_delta.id},
{"name", diff.tool_call_delta.name},
{"type", "function_call"},
{"status", "in_progress"},
+12 -1
View File
@@ -134,6 +134,7 @@ int llama_server(int argc, char ** argv) {
//
// register API routes
server_child child; // only used in non-router mode
server_routes routes(params, ctx_server);
server_tools tools;
@@ -254,11 +255,21 @@ int llama_server(int argc, char ** argv) {
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
}
//
// Handle downloading model
//
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
return child.run_download(params);
} else if (!is_router_server) {
// single-model mode (NOT spawned by router)
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
}
//
// Start the server
//
server_child child; // only used in non-router mode
std::function<void()> clean_up;
if (is_router_server) {
+28 -7
View File
@@ -257,14 +257,25 @@ def test_router_reload_models():
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
MODEL_DOWNLOAD_TIMEOUT = 300
MODEL_DOWNLOAD_TIMEOUT = 30
def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event):
"""Collect /models/sse events into `collected` until `stop` is set."""
def _listen_sse(
server: ServerProcess, collected: list, stop: threading.Event, ready: threading.Event | None = None
):
"""Collect /models/sse events into `collected` until `stop` is set.
When `ready` is provided, it is set once the streaming response is open,
i.e. the server has accepted the connection and registered us as a
subscriber. Callers that trigger one-shot events (e.g. download_finished)
must wait on `ready` before acting, otherwise the event can be broadcast
before this client is subscribed and be lost.
"""
url = f"http://{server.server_host}:{server.server_port}/models/sse"
try:
with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp:
if ready is not None:
ready.set()
for line_bytes in resp.iter_lines():
if stop.is_set():
break
@@ -294,11 +305,17 @@ def test_router_download_model():
sse_events: list = []
stop = threading.Event()
sse_ready = threading.Event()
sse_thread = threading.Thread(
target=_listen_sse, args=(server, sse_events, stop), daemon=True
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
)
sse_thread.start()
# wait for the SSE client to be subscribed before triggering the download,
# otherwise the one-shot download_finished event can be broadcast before
# this client is registered and be lost
assert sse_ready.wait(10), "SSE client failed to connect"
# Trigger the download
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
assert res.status_code == 200
@@ -328,13 +345,17 @@ def test_router_delete_model():
# Ensure the model exists (download it if needed)
if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False):
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
assert res.status_code == 200
sse_events: list = []
stop = threading.Event()
sse_ready = threading.Event()
threading.Thread(
target=_listen_sse, args=(server, sse_events, stop), daemon=True
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
).start()
# subscribe before triggering the download so the one-shot
# download_finished event is not lost (see test_router_download_model)
assert sse_ready.wait(10), "SSE client failed to connect"
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
assert res.status_code == 200
finished = _wait_for_sse_event(
sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT
)
+9 -1
View File
@@ -545,7 +545,8 @@ class ModelsStore {
* 1. Model from active conversation's last assistant response (if loaded)
* 2. Model from active conversation's last assistant response (if not loaded)
* 3. First loaded model (not from active conversation)
* 4. First available model
* 4. A favorite model
* 5. First available model
*/
async ensureFirstModelSelected(): Promise<void> {
if (this.selectedModelName) return;
@@ -574,6 +575,13 @@ class ModelsStore {
return;
}
// Try loading a favorite model
const favorite = this.favoriteModelIds.values().next()?.value
if (favorite) {
await this.selectModelById(favorite);
return;
}
// Fall back to the first available model
await this.selectModelById(availableModels[0].id);
}