mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-18 19:57:46 +02:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f3e1828164 | |||
| 2e88c49c90 | |||
| 0843245cb1 | |||
| 8d2e580632 | |||
| 4b4d13ae72 |
@@ -997,3 +997,87 @@ std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool common_download_remove(const std::string & hf_repo_with_tag) {
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
auto [repo_id, tag] = common_download_split_repo_tag(hf_repo_with_tag);
|
||||
|
||||
if (tag.empty()) {
|
||||
return hf_cache::remove_cached_repo(repo_id);
|
||||
}
|
||||
|
||||
std::string tag_upper = tag;
|
||||
for (char & c : tag_upper) {
|
||||
c = (char) std::toupper((unsigned char) c);
|
||||
}
|
||||
|
||||
auto files = hf_cache::get_cached_files(repo_id);
|
||||
if (files.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// collect snapshot entries whose tag matches
|
||||
std::vector<fs::path> to_remove;
|
||||
for (const auto & f : files) {
|
||||
auto split = get_gguf_split_info(f.path);
|
||||
if (split.tag == tag_upper) {
|
||||
to_remove.emplace_back(f.local_path);
|
||||
}
|
||||
}
|
||||
|
||||
if (to_remove.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// resolve blob paths from symlinks before deleting snapshot entries
|
||||
std::vector<fs::path> blobs_to_check;
|
||||
for (const auto & p : to_remove) {
|
||||
std::error_code ec;
|
||||
if (fs::is_symlink(p, ec)) {
|
||||
auto target = fs::read_symlink(p, ec);
|
||||
if (!ec) {
|
||||
blobs_to_check.push_back((p.parent_path() / target).lexically_normal());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// remove snapshot entries
|
||||
for (const auto & p : to_remove) {
|
||||
std::error_code ec;
|
||||
fs::remove(p, ec);
|
||||
if (ec) {
|
||||
LOG_WRN("%s: failed to remove %s: %s\n", __func__, p.string().c_str(), ec.message().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (blobs_to_check.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// collect blobs still referenced by remaining snapshot entries
|
||||
std::unordered_set<std::string> still_referenced;
|
||||
for (const auto & f : hf_cache::get_cached_files(repo_id)) {
|
||||
fs::path p(f.local_path);
|
||||
std::error_code ec;
|
||||
if (fs::is_symlink(p, ec)) {
|
||||
auto target = fs::read_symlink(p, ec);
|
||||
if (!ec) {
|
||||
still_referenced.insert((p.parent_path() / target).lexically_normal().string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// remove orphaned blobs
|
||||
for (const auto & blob : blobs_to_check) {
|
||||
if (still_referenced.find(blob.string()) == still_referenced.end()) {
|
||||
std::error_code ec;
|
||||
fs::remove(blob, ec);
|
||||
if (ec) {
|
||||
LOG_WRN("%s: failed to remove blob %s: %s\n", __func__, blob.string().c_str(), ec.message().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -115,3 +115,10 @@ int common_download_file_single(const std::string & url,
|
||||
// resolve and download model from Docker registry
|
||||
// return local path to downloaded model file
|
||||
std::string common_docker_resolve_model(const std::string & docker);
|
||||
|
||||
// Remove a cached model from disk
|
||||
// input format: "user/model" or "user/model:tag"
|
||||
// - if tag is omitted, removes the entire repo cache directory
|
||||
// - if tag is present, removes only files matching that tag (and orphaned blobs)
|
||||
// returns true if anything was removed
|
||||
bool common_download_remove(const std::string & hf_repo_with_tag);
|
||||
|
||||
@@ -495,4 +495,19 @@ std::string finalize_file(const hf_file & file) {
|
||||
return file.final_path;
|
||||
}
|
||||
|
||||
bool remove_cached_repo(const std::string & repo_id) {
|
||||
if (!is_valid_repo_id(repo_id)) {
|
||||
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
|
||||
return false;
|
||||
}
|
||||
fs::path repo_path = get_repo_path(repo_id);
|
||||
std::error_code ec;
|
||||
auto removed = fs::remove_all(repo_path, ec);
|
||||
if (ec) {
|
||||
LOG_ERR("%s: failed to remove repo cache %s: %s\n", __func__, repo_path.string().c_str(), ec.message().c_str());
|
||||
return false;
|
||||
}
|
||||
return removed > 0;
|
||||
}
|
||||
|
||||
} // namespace hf_cache
|
||||
|
||||
@@ -29,4 +29,7 @@ hf_files get_cached_files(const std::string & repo_id = {});
|
||||
// Create snapshot path (link or move/copy) and return it
|
||||
std::string finalize_file(const hf_file & file);
|
||||
|
||||
// Remove the entire cached directory for a repo, returns true if removed
|
||||
bool remove_cached_repo(const std::string & repo_id);
|
||||
|
||||
} // namespace hf_cache
|
||||
|
||||
@@ -438,7 +438,14 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
|
||||
ggml_add_cpu_backend_variant(power9 POWER9 VSX)
|
||||
ggml_add_cpu_backend_variant(power10 POWER10 VSX)
|
||||
ggml_add_cpu_backend_variant(power11 POWER11 VSX)
|
||||
# POWER11 backend: only if compiler supports -mcpu=power11
|
||||
check_cxx_compiler_flag("-mcpu=power11" GGML_CXX_SUPPORTS_POWER11)
|
||||
if (GGML_CXX_SUPPORTS_POWER11)
|
||||
message(STATUS "Compiler supports -mcpu=power11, enabling POWER11 backend")
|
||||
ggml_add_cpu_backend_variant(power11 POWER11 VSX)
|
||||
else()
|
||||
message(STATUS "Skipping POWER11 backend: compiler does not support -mcpu=power11")
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
|
||||
@@ -389,7 +389,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M_UPPER}")
|
||||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||
|
||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||
if (EXTRACTED_NUMBER EQUAL 10 OR EXTRACTED_NUMBER EQUAL 11)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10)
|
||||
elseif (EXTRACTED_NUMBER EQUAL 9)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9)
|
||||
|
||||
@@ -66,7 +66,6 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml
|
||||
const char * op_str = "undefined";
|
||||
switch (op) {
|
||||
case GGML_OP_ADD_ID: op_str = "add_id"; break;
|
||||
case GGML_OP_CONCAT: op_str = "concat"; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
@@ -211,6 +210,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_concat(ggml_metal_library_t lib, ggml_type tsrc) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_concat_%s", ggml_type_name(tsrc));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
@@ -1689,7 +1703,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_ROPE);
|
||||
assert(op->op == GGML_OP_ROPE || op->op == GGML_OP_ROPE_BACK);
|
||||
|
||||
const bool is_back = op->op == GGML_OP_ROPE_BACK;
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
@@ -1713,13 +1729,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_
|
||||
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
|
||||
}
|
||||
|
||||
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
|
||||
snprintf(name, 256, "%s_imrope=%d_is_back=%d", base, is_imrope ? 1 : 0, is_back ? 1 : 0);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
|
||||
ggml_metal_cv_set_bool(cv, is_back, FC_ROPE + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
|
||||
@@ -115,6 +115,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_concat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -1123,13 +1123,24 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return true;
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
// kernel_concat copies one float-sized value per element.
|
||||
// Other scalar types need a type-generic copy kernel first.
|
||||
const enum ggml_type src0_type = op->src[0]->type;
|
||||
const enum ggml_type src1_type = op->src[1]->type;
|
||||
return src0_type == src1_type &&
|
||||
src0_type == op->type &&
|
||||
(src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32);
|
||||
if (src0_type != src1_type || src0_type != op->type) {
|
||||
return false;
|
||||
}
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_I64:
|
||||
return true;
|
||||
case GGML_TYPE_BF16:
|
||||
return has_bfloat;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
@@ -1173,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_RMS_NORM:
|
||||
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
return true;
|
||||
case GGML_OP_IM2COL:
|
||||
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
|
||||
|
||||
@@ -375,6 +375,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
n_fuse = ggml_metal_op_norm(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
{
|
||||
n_fuse = ggml_metal_op_rope(ctx, idx);
|
||||
} break;
|
||||
@@ -556,7 +557,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
||||
/*.dim =*/ dim,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_concat(lib, op->type);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
|
||||
@@ -4358,6 +4358,7 @@ template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_
|
||||
#endif
|
||||
|
||||
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
|
||||
constant bool FC_rope_is_back [[function_constant(FC_ROPE + 1)]];
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
@@ -4381,6 +4382,9 @@ static void rope_yarn(
|
||||
}
|
||||
*cos_theta = cos(theta) * mscale;
|
||||
*sin_theta = sin(theta) * mscale;
|
||||
if (FC_rope_is_back) {
|
||||
*sin_theta *= -1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
@@ -7513,14 +7517,15 @@ template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<
|
||||
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_concat(
|
||||
constant ggml_metal_kargs_concat & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
constant ggml_metal_kargs_concat & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
@@ -7533,21 +7538,31 @@ kernel void kernel_concat(
|
||||
int o[4] = {0, 0, 0, 0};
|
||||
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
|
||||
|
||||
device const float * x;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
device const T * x;
|
||||
|
||||
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
|
||||
x = (device const T *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
|
||||
} else {
|
||||
x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
|
||||
x = (device const T *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
|
||||
}
|
||||
|
||||
device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
device T * y = (device T *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
*y = *x;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_concat<float>) kernel_concat_t;
|
||||
|
||||
template [[host_name("kernel_concat_f32")]] kernel kernel_concat_t kernel_concat<float>;
|
||||
template [[host_name("kernel_concat_f16")]] kernel kernel_concat_t kernel_concat<half>;
|
||||
template [[host_name("kernel_concat_bf16")]] kernel kernel_concat_t kernel_concat<bfloat>;
|
||||
template [[host_name("kernel_concat_i8")]] kernel kernel_concat_t kernel_concat<char>;
|
||||
template [[host_name("kernel_concat_i16")]] kernel kernel_concat_t kernel_concat<short>;
|
||||
template [[host_name("kernel_concat_i32")]] kernel kernel_concat_t kernel_concat<int>;
|
||||
template [[host_name("kernel_concat_i64")]] kernel kernel_concat_t kernel_concat<long>;
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q2_K_f32_impl(
|
||||
args_t args,
|
||||
|
||||
@@ -1105,6 +1105,8 @@ bool mtmd_image_preprocessor_internvl::preprocess(const clip_image_u8 & img, cli
|
||||
img_u8_to_f32(*imgs[i], *res, hparams.image_mean, hparams.image_std);
|
||||
output.entries.push_back(std::move(res));
|
||||
}
|
||||
output.grid_x = inst.grid_size.width;
|
||||
output.grid_y = inst.grid_size.height;
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1558,3 +1560,22 @@ bool mtmd_image_preprocessor_youtuvl::preprocess(const clip_image_u8 & img, clip
|
||||
output.entries.push_back(std::move(img_f32));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool mtmd_image_preprocessor_granite::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
|
||||
// call super class preprocessor
|
||||
bool ok = mtmd_image_preprocessor_llava_uhd::preprocess(img, output);
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
if (output.entries.size() == 1) {
|
||||
// Single-tile (overview only): append one newline row.
|
||||
output.entries[0]->add_newline = true;
|
||||
} else {
|
||||
// Multi-tile: overview gets no newline, grid tiles get one.
|
||||
output.entries[0]->add_newline = false;
|
||||
for (size_t i = 1; i < output.entries.size(); ++i) {
|
||||
output.entries[i]->add_newline = true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -197,3 +197,9 @@ struct mtmd_image_preprocessor_youtuvl : mtmd_image_preprocessor {
|
||||
mtmd_image_preprocessor_youtuvl(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
};
|
||||
|
||||
// similar to llava_uhd, but has add_newline
|
||||
struct mtmd_image_preprocessor_granite : mtmd_image_preprocessor_llava_uhd {
|
||||
mtmd_image_preprocessor_granite(const clip_ctx * ctx) : mtmd_image_preprocessor_llava_uhd(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
};
|
||||
|
||||
+6
-25
@@ -639,7 +639,7 @@ struct mtmd_context {
|
||||
{
|
||||
img_beg = "<image>";
|
||||
img_end = "";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_llava_uhd>(ctx_v);
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_granite>(ctx_v);
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error(string_format("%s: unexpected vision projector type %d\n", __func__, proj));
|
||||
@@ -1033,7 +1033,10 @@ struct mtmd_tokenizer {
|
||||
int32_t add_media(std::vector<const mtmd_bitmap *> & bitmaps) {
|
||||
GGML_ASSERT(!bitmaps.empty());
|
||||
|
||||
if (!bitmaps[0]->is_audio) {
|
||||
// note: only one type of media is supported per call, caller should enforce this
|
||||
const bool is_vision = !bitmaps[0]->is_audio;
|
||||
|
||||
if (is_vision) {
|
||||
// handle image
|
||||
|
||||
if (!ctx->ctx_v) {
|
||||
@@ -1085,31 +1088,9 @@ struct mtmd_tokenizer {
|
||||
batch_f32.grid_y = tmp_batch.grid_y;
|
||||
}
|
||||
|
||||
// Annotate llava-next style tiles so clip_n_output_tokens accounts
|
||||
// for per-tile newline injection.
|
||||
if (ctx->proj_type_v() == PROJECTOR_TYPE_GRANITE4_VISION) {
|
||||
if (batch_f32.entries.size() == 1) {
|
||||
// Single-tile (overview only): append one newline row.
|
||||
batch_f32.entries[0]->add_newline = true;
|
||||
} else {
|
||||
// Multi-tile: overview gets no newline, grid tiles get one.
|
||||
batch_f32.entries[0]->add_newline = false;
|
||||
for (size_t i = 1; i < batch_f32.entries.size(); ++i) {
|
||||
batch_f32.entries[i]->add_newline = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle llava-uhd style preprocessing
|
||||
const bool has_tiling_grid = batch_f32.grid_x > 0 && batch_f32.grid_y > 0;
|
||||
if (
|
||||
ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5
|
||||
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6
|
||||
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
|
||||
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3
|
||||
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_STEP3VL
|
||||
|| (ctx->slice_tmpl == MTMD_SLICE_TMPL_LFM2 && has_tiling_grid)
|
||||
) {
|
||||
if (has_tiling_grid) {
|
||||
// [QWEN_VIDEO] we do not support "frame merging" for llama-uhd style, so no batching for now
|
||||
GGML_ASSERT(bitmaps.size() == 1);
|
||||
|
||||
|
||||
@@ -180,6 +180,24 @@ That requires `JSON.stringify` when formatted to message content:
|
||||
}
|
||||
```
|
||||
|
||||
### Model management API (router mode)
|
||||
|
||||
Model management API was added via PR [#23976](https://github.com/ggml-org/llama.cpp/pull/23976)
|
||||
|
||||
The main goal of this API is to allow downloading models and/or removing models from the web UI. It relies on the model cache infrastructure under the hood to manage the list of models dynamically.
|
||||
|
||||
Instead of building everything from the ground up (like what most AI agents will do when you ask them to implement a similar feature), we built on top of existing, already well-engineered components inside the codebase:
|
||||
- Model cache infrastructure as mentioned above (`common/download.h`)
|
||||
- Server response queue (`server-queue.h`). We use this feature to broadcast events to SSE clients.
|
||||
- Server router thread management (`server-models.h`). We re-use the same thread model that is used for managing subprocess life cycle, except that we don't create a new subprocess, but launch the download right inside the thread.
|
||||
|
||||
The flow for downloading a new model:
|
||||
- POST request comes in --> `post_router_models` --> validation
|
||||
- `server_models::download()` is called
|
||||
- Sets up a new thread `inst.th` and runs the download inside
|
||||
- If a stop request comes in, set `stop_download` to `true`
|
||||
- Otherwise, upon completion, we call `load_models()` to refresh the list of models
|
||||
|
||||
### Notable Related PRs
|
||||
|
||||
- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443
|
||||
|
||||
@@ -1778,6 +1778,20 @@ The `status` object can be:
|
||||
}
|
||||
```
|
||||
|
||||
Note: for "downloading" state, there can be multiple files be downloading in parallel
|
||||
|
||||
```json
|
||||
"status": {
|
||||
"value": "downloading",
|
||||
"progress": {
|
||||
"https://...model.gguf": {
|
||||
"done": 195963406,
|
||||
"total": 219307424
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### POST `/models/load`: Load a model
|
||||
|
||||
Load a model
|
||||
@@ -1820,6 +1834,107 @@ Response:
|
||||
}
|
||||
```
|
||||
|
||||
### GET `/models/sse`: Real-time events
|
||||
|
||||
Example events:
|
||||
|
||||
```js
|
||||
{
|
||||
"model": "...",
|
||||
"event": "model_status",
|
||||
"data": {
|
||||
"status": "loading"
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
"model": "...",
|
||||
"event": "download_progress",
|
||||
"data": {
|
||||
// note: there can be multiple files being downloaded in parallel
|
||||
"https://...model.gguf": {
|
||||
"done": 195963406,
|
||||
"total": 219307424
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
"model": "...",
|
||||
"event": "download_finished",
|
||||
"data": {
|
||||
"status": "loading"
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
"model": "...",
|
||||
"event": "model_remove"
|
||||
}
|
||||
|
||||
// special event: reload of the list of all models
|
||||
{
|
||||
"model": "*",
|
||||
"event": "models_reload"
|
||||
}
|
||||
```
|
||||
|
||||
### POST `/models`: Download new model
|
||||
|
||||
Trigger a new download (non-blocking), the progress can be tracked via SSE endpoint `/models/sse`
|
||||
|
||||
To cancel model downloading, send an event to `/models/unload`
|
||||
|
||||
Download procedure:
|
||||
- Send POST request to `/models`
|
||||
- Subscribe to `/models/sse` for updates
|
||||
- On downloading completed, you will receive either `download_finished` or `download_failed` event
|
||||
- Call GET `/models` to trigger model list update. If the download success, you should see the new model in the list
|
||||
|
||||
Payload:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M",
|
||||
}
|
||||
```
|
||||
|
||||
Response (download is started in the background):
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true
|
||||
}
|
||||
```
|
||||
|
||||
Response (error, cannot start the download):
|
||||
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"code": 400,
|
||||
"message": "model validation failed, unable to download",
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### DELETE `/models`: Delete a model from cache
|
||||
|
||||
IMPORTANT: only model stored in cache can be deleted. You cannot delete models in a preset.
|
||||
|
||||
Model name must be passed via query param: `?model={name}`
|
||||
|
||||
If delete success, it will send an SSE event of type `model_remove`
|
||||
|
||||
Response:
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true
|
||||
}
|
||||
```
|
||||
|
||||
## API errors
|
||||
|
||||
`llama-server` returns errors in the same format as OAI: https://github.com/openai/openai-openapi
|
||||
|
||||
@@ -588,6 +588,23 @@ void server_http_context::post(const std::string & path, const server_http_conte
|
||||
});
|
||||
}
|
||||
|
||||
void server_http_context::del(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
handlers.emplace(path, handler);
|
||||
pimpl->srv->Delete(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
|
||||
get_params(req),
|
||||
get_headers(req),
|
||||
req.path,
|
||||
build_query_string(req),
|
||||
req.body,
|
||||
{},
|
||||
req.is_connection_closed
|
||||
});
|
||||
server_http_res_ptr response = handler(*request);
|
||||
process_handler_response(std::move(request), response, res);
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE)
|
||||
// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements
|
||||
|
||||
@@ -86,6 +86,7 @@ struct server_http_context {
|
||||
|
||||
void get(const std::string & path, const handler_t & handler) const;
|
||||
void post(const std::string & path, const handler_t & handler) const;
|
||||
void del(const std::string & path, const handler_t & handler) const;
|
||||
|
||||
// Register the Google Cloud Platform (Vertex AI) compat (AIP_PREDICT_ROUTE env var, or /predict)
|
||||
// Must be called AFTER all other API routes are registered
|
||||
|
||||
+398
-35
@@ -9,6 +9,7 @@
|
||||
#include <sheredom/subprocess.h>
|
||||
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
@@ -51,6 +52,21 @@ extern char **environ;
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/17862
|
||||
#define CHILD_ADDR "127.0.0.1"
|
||||
|
||||
struct server_subproc {
|
||||
std::optional<subprocess_s> sproc; // empty while in DOWNLOADING state
|
||||
std::atomic<bool> stop_download{false}; // flag to signal download cancellation
|
||||
|
||||
subprocess_s & get() {
|
||||
GGML_ASSERT(sproc.has_value() && "subprocess not initialized");
|
||||
return sproc.value();
|
||||
}
|
||||
|
||||
bool is_alive() {
|
||||
return sproc.has_value() && subprocess_alive(&sproc.value());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static std::filesystem::path get_server_exec_path() {
|
||||
#if defined(_WIN32)
|
||||
wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
|
||||
@@ -272,12 +288,25 @@ void server_models::add_model(server_model_meta && meta) {
|
||||
meta.update_caps();
|
||||
std::string name = meta.name;
|
||||
mapping[name] = instance_t{
|
||||
/* subproc */ std::make_shared<subprocess_s>(),
|
||||
/* subproc */ std::make_shared<server_subproc>(),
|
||||
/* th */ std::thread(),
|
||||
/* meta */ std::move(meta)
|
||||
};
|
||||
}
|
||||
|
||||
void server_models::notify_sse(const std::string & event, const std::string & model_id, const json & data) {
|
||||
std::unique_ptr<server_task_result_router> result = std::make_unique<server_task_result_router>();
|
||||
result->data = {
|
||||
{"model", model_id},
|
||||
{"event", event},
|
||||
};
|
||||
if (!data.is_null()) {
|
||||
result->data["data"] = data;
|
||||
}
|
||||
SRV_DBG("notifying SSE clients about event '%s' for model '%s': %s\n", event.c_str(), model_id.c_str(), safe_json_to_str(result->data).c_str());
|
||||
sse.broadcast(std::move(result));
|
||||
}
|
||||
|
||||
void server_models::load_models() {
|
||||
// Phase 1: load presets from all sources — pure I/O, no lock needed
|
||||
// 1. cached models
|
||||
@@ -304,19 +333,27 @@ void server_models::load_models() {
|
||||
|
||||
// note: if a model exists in both cached and local, local takes precedence
|
||||
common_presets final_presets;
|
||||
for (const auto & [name, preset] : cached_models) final_presets[name] = preset;
|
||||
for (const auto & [name, preset] : local_models) final_presets[name] = preset;
|
||||
std::unordered_map<std::string, server_model_source> source_map;
|
||||
for (const auto & [name, preset] : cached_models) {
|
||||
final_presets[name] = preset;
|
||||
source_map[name] = SERVER_MODEL_SOURCE_CACHE;
|
||||
}
|
||||
for (const auto & [name, preset] : local_models) {
|
||||
final_presets[name] = preset;
|
||||
source_map[name] = SERVER_MODEL_SOURCE_MODELS_DIR;
|
||||
}
|
||||
for (const auto & [name, custom] : custom_presets) {
|
||||
if (final_presets.find(name) != final_presets.end()) {
|
||||
final_presets[name].merge(custom);
|
||||
} else {
|
||||
final_presets[name] = custom;
|
||||
}
|
||||
source_map[name] = SERVER_MODEL_SOURCE_PRESET;
|
||||
}
|
||||
// server base preset from CLI args takes highest precedence
|
||||
for (auto & [name, preset] : final_presets) {
|
||||
preset.merge(base_preset);
|
||||
}
|
||||
|
||||
auto get_source = [&](const std::string & name) {
|
||||
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
|
||||
};
|
||||
|
||||
// Helpers that read `mapping` — must be called while holding the lock.
|
||||
std::unordered_set<std::string> custom_names;
|
||||
@@ -366,12 +403,15 @@ void server_models::load_models() {
|
||||
// (unload, load) or when joining threads (the monitoring thread calls update_status
|
||||
// which locks the mutex, so joining while holding it would deadlock).
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
|
||||
need_reload = false;
|
||||
bool is_first_load = mapping.empty();
|
||||
|
||||
if (is_first_load) {
|
||||
// FIRST LOAD: add all models, then unlock for autoloading
|
||||
for (const auto & [name, preset] : final_presets) {
|
||||
server_model_meta meta{
|
||||
/* source */ get_source(name),
|
||||
/* preset */ preset,
|
||||
/* name */ name,
|
||||
/* aliases */ {},
|
||||
@@ -384,7 +424,7 @@ void server_models::load_models() {
|
||||
/* exit_code */ 0,
|
||||
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
|
||||
/* multimodal */ mtmd_caps{false, false},
|
||||
/* need_download */ false,
|
||||
// /* need_download */ false,
|
||||
};
|
||||
add_model(std::move(meta));
|
||||
}
|
||||
@@ -453,6 +493,9 @@ void server_models::load_models() {
|
||||
}
|
||||
}
|
||||
for (auto & [name, inst] : mapping) {
|
||||
if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
continue; // downloading models are not from config sources, leave them alone
|
||||
}
|
||||
if (final_presets.find(name) == final_presets.end() && !inst.meta.is_running() && inst.th.joinable()) {
|
||||
threads_to_join.push_back(std::move(inst.th));
|
||||
}
|
||||
@@ -465,7 +508,15 @@ void server_models::load_models() {
|
||||
|
||||
// erase models no longer in any source
|
||||
for (auto it = mapping.begin(); it != mapping.end(); ) {
|
||||
if (final_presets.find(it->first) == final_presets.end()) {
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
++it; // download thread is still busy, skip
|
||||
} else if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADED) {
|
||||
// download finished, safe to erase
|
||||
if (it->second.th.joinable()) {
|
||||
it->second.th.join();
|
||||
}
|
||||
it = mapping.erase(it);
|
||||
} else if (final_presets.find(it->first) == final_presets.end()) {
|
||||
SRV_INF("(reload) removing model name=%s (no longer in source)\n", it->first.c_str());
|
||||
GGML_ASSERT(!it->second.th.joinable()); // must have been joined above
|
||||
it = mapping.erase(it);
|
||||
@@ -526,6 +577,7 @@ void server_models::load_models() {
|
||||
for (const auto & [name, preset] : final_presets) {
|
||||
if (mapping.find(name) == mapping.end()) {
|
||||
server_model_meta meta{
|
||||
/* source */ get_source(name),
|
||||
/* preset */ preset,
|
||||
/* name */ name,
|
||||
/* aliases */ {},
|
||||
@@ -538,7 +590,7 @@ void server_models::load_models() {
|
||||
/* exit_code */ 0,
|
||||
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
|
||||
/* multimodal */ mtmd_caps{false, false},
|
||||
/* need_download */ false,
|
||||
// /* need_download */ false,
|
||||
};
|
||||
add_model(std::move(meta));
|
||||
newly_added.push_back(name);
|
||||
@@ -571,6 +623,8 @@ void server_models::load_models() {
|
||||
SRV_INF("(reload) loading new model %s\n", name.c_str());
|
||||
load(name);
|
||||
}
|
||||
|
||||
notify_sse("models_reload", "*");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -597,7 +651,13 @@ bool server_models::has_model(const std::string & name) {
|
||||
}
|
||||
|
||||
std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
if (need_reload) {
|
||||
lk.unlock();
|
||||
load_models();
|
||||
lk.lock();
|
||||
}
|
||||
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
return it->second.meta;
|
||||
@@ -683,7 +743,13 @@ static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & ve
|
||||
}
|
||||
|
||||
std::vector<server_model_meta> server_models::get_all_meta() {
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
if (need_reload) {
|
||||
lk.unlock();
|
||||
load_models();
|
||||
lk.lock();
|
||||
}
|
||||
|
||||
std::vector<server_model_meta> result;
|
||||
result.reserve(mapping.size());
|
||||
for (const auto & [name, inst] : mapping) {
|
||||
@@ -770,7 +836,7 @@ void server_models::load(const std::string & name) {
|
||||
throw std::runtime_error("failed to get a port number");
|
||||
}
|
||||
|
||||
inst.subproc = std::make_shared<subprocess_s>();
|
||||
inst.subproc = std::make_shared<server_subproc>();
|
||||
{
|
||||
SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
|
||||
|
||||
@@ -792,19 +858,20 @@ void server_models::load(const std::string & name) {
|
||||
// TODO @ngxson : maybe separate stdout and stderr in the future
|
||||
// so that we can use stdout for commands and stderr for logging
|
||||
int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
|
||||
int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
|
||||
inst.subproc->sproc.emplace();
|
||||
int result = subprocess_create_ex(argv.data(), options, envp.data(), &inst.subproc->get());
|
||||
if (result != 0) {
|
||||
throw std::runtime_error("failed to spawn server instance");
|
||||
}
|
||||
|
||||
inst.stdin_file = subprocess_stdin(inst.subproc.get());
|
||||
inst.stdin_file = subprocess_stdin(&inst.subproc->get());
|
||||
}
|
||||
|
||||
// start a thread to manage the child process
|
||||
// captured variables are guaranteed to be destroyed only after the thread is joined
|
||||
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
|
||||
FILE * stdin_file = subprocess_stdin(child_proc.get());
|
||||
FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr
|
||||
FILE * stdin_file = subprocess_stdin(&child_proc->get());
|
||||
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
|
||||
|
||||
std::thread log_thread([&]() {
|
||||
// read stdout/stderr and forward to main server log
|
||||
@@ -834,14 +901,14 @@ void server_models::load(const std::string & name) {
|
||||
return this->stopping_models.find(name) != this->stopping_models.end();
|
||||
};
|
||||
auto should_wake = [&]() {
|
||||
return is_stopping() || !subprocess_alive(child_proc.get());
|
||||
return is_stopping() || !child_proc->is_alive();
|
||||
};
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(this->mutex);
|
||||
this->cv_stop.wait(lk, should_wake);
|
||||
}
|
||||
// child may have already exited (e.g. crashed) — skip shutdown sequence
|
||||
if (!subprocess_alive(child_proc.get())) {
|
||||
if (!child_proc->is_alive()) {
|
||||
return;
|
||||
}
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
@@ -859,7 +926,7 @@ void server_models::load(const std::string & name) {
|
||||
if (elapsed >= stop_timeout * 1000) {
|
||||
// timeout, force kill
|
||||
SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
|
||||
subprocess_terminate(child_proc.get());
|
||||
subprocess_terminate(&child_proc->get());
|
||||
return;
|
||||
}
|
||||
this->cv_stop.wait_for(lk, std::chrono::seconds(1));
|
||||
@@ -884,8 +951,8 @@ void server_models::load(const std::string & name) {
|
||||
|
||||
// get the exit code
|
||||
int exit_code = 0;
|
||||
subprocess_join(child_proc.get(), &exit_code);
|
||||
subprocess_destroy(child_proc.get());
|
||||
subprocess_join(&child_proc->get(), &exit_code);
|
||||
subprocess_destroy(&child_proc->get());
|
||||
|
||||
// update status and exit code
|
||||
this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
|
||||
@@ -896,30 +963,118 @@ void server_models::load(const std::string & name) {
|
||||
{
|
||||
auto & old_instance = mapping[name];
|
||||
// old process should have exited already, but just in case, we clean it up here
|
||||
if (subprocess_alive(old_instance.subproc.get())) {
|
||||
if (old_instance.subproc->is_alive()) {
|
||||
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
|
||||
subprocess_terminate(old_instance.subproc.get()); // force kill
|
||||
subprocess_terminate(&old_instance.subproc->get()); // force kill
|
||||
}
|
||||
if (old_instance.th.joinable()) {
|
||||
old_instance.th.join();
|
||||
}
|
||||
}
|
||||
|
||||
notify_sse("model_status", name, {
|
||||
{"status", server_model_status_to_string(inst.meta.status)},
|
||||
});
|
||||
|
||||
mapping[name] = std::move(inst);
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
// callback for model downloading functionality
|
||||
struct server_models_download_res : public common_download_callback {
|
||||
common_params_model model;
|
||||
common_download_opts opts;
|
||||
|
||||
std::function<bool()> should_stop;
|
||||
std::function<void(const common_download_progress & p)> on_progress;
|
||||
|
||||
bool is_ok = false;
|
||||
|
||||
bool run() {
|
||||
try {
|
||||
common_download_model(model, opts);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("download failed for model name=%s: %s\n", model.name.c_str(), e.what());
|
||||
is_ok = false;
|
||||
}
|
||||
return is_ok;
|
||||
}
|
||||
void on_start(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_update(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_done(const common_download_progress &, bool ok) override {
|
||||
is_ok = ok;
|
||||
}
|
||||
bool is_cancelled() const override {
|
||||
return should_stop();
|
||||
}
|
||||
};
|
||||
|
||||
void server_models::download(common_params_model && model, common_download_opts && opts) {
|
||||
std::string name = model.name;
|
||||
GGML_ASSERT(name == model.hf_repo);
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
if (mapping.find(name) != mapping.end()) {
|
||||
throw std::runtime_error("model name=" + name + " already exists");
|
||||
}
|
||||
|
||||
instance_t inst;
|
||||
inst.meta.name = name;
|
||||
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
inst.subproc = std::make_shared<server_subproc>();
|
||||
|
||||
auto dl = std::make_unique<server_models_download_res>();
|
||||
dl->model = model; // copy
|
||||
dl->opts = opts; // copy
|
||||
|
||||
dl->should_stop = [sp = inst.subproc]() {
|
||||
return sp->stop_download.load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
dl->on_progress = [this, name](const common_download_progress & p) {
|
||||
update_download_progress(name, p, false);
|
||||
};
|
||||
|
||||
inst.th = std::thread([this, dl = std::move(dl)]() {
|
||||
dl->opts.callback = dl.get();
|
||||
bool ok = dl->run();
|
||||
SRV_INF("download finished for model name=%s with status=%s\n",
|
||||
dl->model.name.c_str(), ok ? "success" : "failure");
|
||||
update_download_progress(dl->model.name, {}, true, ok);
|
||||
// need_reload is set inside update_download_progress under the mutex;
|
||||
// the next load_models() call will clean up this instance
|
||||
});
|
||||
|
||||
mapping[name] = std::move(inst);
|
||||
notify_sse("status_update", name, {
|
||||
{"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)},
|
||||
});
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
void server_models::unload(const std::string & name) {
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
if (it->second.meta.is_running()) {
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->stop_download.store(true, std::memory_order_relaxed);
|
||||
// for convenience, we wait the status change here
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
});
|
||||
} else if (it->second.meta.is_running()) {
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
stopping_models.insert(name);
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
|
||||
// special case: if model is in loading state, unloading means force-killing it
|
||||
SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str());
|
||||
subprocess_terminate(it->second.subproc.get());
|
||||
subprocess_terminate(&it->second.subproc->get());
|
||||
}
|
||||
cv_stop.notify_all();
|
||||
// status change will be handled by the managing thread
|
||||
@@ -934,7 +1089,10 @@ void server_models::unload_all() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
for (auto & [name, inst] : mapping) {
|
||||
if (inst.meta.is_running()) {
|
||||
if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
inst.subproc->stop_download.store(true, std::memory_order_relaxed);
|
||||
} else if (inst.meta.is_running()) {
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
stopping_models.insert(name);
|
||||
cv_stop.notify_all();
|
||||
@@ -959,6 +1117,17 @@ void server_models::update_status(const std::string & name, server_model_status
|
||||
meta.status = status;
|
||||
meta.exit_code = exit_code;
|
||||
}
|
||||
// broadcast status change to SSE
|
||||
{
|
||||
json data = {
|
||||
{"status", server_model_status_to_string(status)},
|
||||
};
|
||||
if (status == SERVER_MODEL_STATUS_UNLOADED) {
|
||||
data["exit_code"] = exit_code;
|
||||
}
|
||||
// note: notify_sse doesn't acquire the lock, so no deadlock here
|
||||
notify_sse("status_change", name, data);
|
||||
}
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
@@ -985,12 +1154,82 @@ void server_models::update_loaded_info(const std::string & name, std::string & r
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
void server_models::wait_until_loading_finished(const std::string & name) {
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
cv.wait(lk, [this, &name]() {
|
||||
void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) {
|
||||
json curr;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
|
||||
if (done) {
|
||||
// mark the instance to be erased on next load_models() call
|
||||
it->second.meta.status = SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
need_reload = true;
|
||||
} else {
|
||||
json & info = it->second.meta.loaded_info;
|
||||
if (!info.contains("progress")) {
|
||||
info["progress"] = json{};
|
||||
}
|
||||
info["progress"][progress.url] = {
|
||||
{"done", progress.downloaded},
|
||||
{"total", progress.total},
|
||||
};
|
||||
curr = it->second.meta.loaded_info; // copy
|
||||
}
|
||||
}
|
||||
}
|
||||
if (done) {
|
||||
cv.notify_all(); // notify in case unload() is waiting for download to be cancelled
|
||||
notify_sse(ok ? "download_finished" : "download_failed", name, {});
|
||||
} else {
|
||||
notify_sse("download_progress", name, curr);
|
||||
}
|
||||
}
|
||||
|
||||
bool server_models::remove(const std::string & name) {
|
||||
auto meta = get_meta(name);
|
||||
|
||||
if (!meta.has_value()) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
if (meta->source != SERVER_MODEL_SOURCE_CACHE) {
|
||||
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
|
||||
}
|
||||
|
||||
unload(name); // cancel download or stop running instance
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status == SERVER_MODEL_STATUS_UNLOADED
|
||||
|| new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
});
|
||||
// join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no
|
||||
// longer acquires this mutex, so joining while holding it is safe
|
||||
if (mapping[name].th.joinable()) {
|
||||
mapping[name].th.join();
|
||||
}
|
||||
// remove the model from disk (hold lock to prevent concurrent load)
|
||||
bool ok = common_download_remove(name);
|
||||
if (ok) {
|
||||
mapping.erase(name);
|
||||
}
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed");
|
||||
notify_sse("model_remove", name, {});
|
||||
return ok;
|
||||
}
|
||||
}
|
||||
|
||||
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
wait(lk, name, predicate);
|
||||
}
|
||||
|
||||
void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
|
||||
cv.wait(lk, [this, &name, &predicate]() {
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
return predicate(it->second.meta);
|
||||
|
||||
}
|
||||
return false;
|
||||
});
|
||||
@@ -1014,10 +1253,15 @@ bool server_models::ensure_model_ready(const std::string & name) {
|
||||
|
||||
// wait for loading to complete
|
||||
SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
|
||||
wait_until_loading_finished(name);
|
||||
wait(name, [&meta](const server_model_meta & new_meta) {
|
||||
if (new_meta.status != SERVER_MODEL_STATUS_LOADING) {
|
||||
meta = new_meta; // update meta for final check after wait
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
// check final status
|
||||
meta = get_meta(name);
|
||||
if (!meta.has_value() || meta->is_failed()) {
|
||||
throw std::runtime_error("model name=" + name + " failed to load");
|
||||
}
|
||||
@@ -1111,6 +1355,42 @@ void server_models::notify_router_sleeping_state(bool is_sleeping) {
|
||||
// server_models_routes
|
||||
//
|
||||
|
||||
// RAII wrapper similar to server_response_reader, but doesn't use server_queue
|
||||
static std::atomic<int> sse_client_id_counter = 0;
|
||||
struct server_models_sse_client {
|
||||
server_response & queue_results;
|
||||
int client_id;
|
||||
server_models_sse_client(server_response & q)
|
||||
: queue_results(q), client_id(sse_client_id_counter.fetch_add(1, std::memory_order_relaxed)) {
|
||||
SRV_DBG("new SSE client connected, assigned client_id=%d\n", client_id);
|
||||
queue_results.add_waiting_task_id(client_id);
|
||||
}
|
||||
~server_models_sse_client() {
|
||||
SRV_DBG("SSE client disconnected, removing client_id=%d\n", client_id);
|
||||
queue_results.remove_waiting_task_id(client_id);
|
||||
}
|
||||
|
||||
// return nullptr if should_stop() is true before receiving a result
|
||||
// note: if one error is received, it will stop further processing and return error result
|
||||
server_task_result_ptr next(const std::function<bool()> & should_stop) {
|
||||
while (true) {
|
||||
static const int http_polling_seconds = 1; // check should_stop every 1 second
|
||||
server_task_result_ptr result = queue_results.recv_with_timeout({client_id}, http_polling_seconds);
|
||||
if (result == nullptr) {
|
||||
// timeout, check stop condition
|
||||
if (should_stop()) {
|
||||
return nullptr;
|
||||
}
|
||||
// continue waiting otherwise
|
||||
} else {
|
||||
SRV_DBG("recv result for client_id=%d: %s\n", client_id, safe_json_to_str(result->to_json()).c_str());
|
||||
return result;
|
||||
}
|
||||
}
|
||||
// should not reach here
|
||||
}
|
||||
};
|
||||
|
||||
static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
|
||||
res->status = 200;
|
||||
res->data = safe_json_to_str(response_data);
|
||||
@@ -1274,7 +1554,9 @@ void server_models_routes::init_routes() {
|
||||
{"created", t}, // for OAI-compat
|
||||
{"status", status},
|
||||
{"architecture", architecture},
|
||||
{"need_download", meta.need_download},
|
||||
{"source", server_model_source_to_string(meta.source)},
|
||||
{"can_remove", meta.source == SERVER_MODEL_SOURCE_CACHE},
|
||||
// {"need_download", meta.need_download},
|
||||
// TODO: add other fields, may require reading GGUF metadata
|
||||
};
|
||||
|
||||
@@ -1312,6 +1594,87 @@ void server_models_routes::init_routes() {
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
};
|
||||
|
||||
this->get_router_models_sse = [this](const server_http_req & req) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 200;
|
||||
res->content_type = "text/event-stream";
|
||||
auto sse_client = std::make_shared<server_models_sse_client>(models.sse);
|
||||
res->next = [this, sse_client, &req](std::string & output) -> bool {
|
||||
auto result = sse_client->next([&]() {
|
||||
return stopping.load(std::memory_order_relaxed) || req.should_stop();
|
||||
});
|
||||
if (result == nullptr) {
|
||||
return false; // client disconnected or should_stop
|
||||
}
|
||||
output = "data: " + safe_json_to_str(result->to_json()) + "\n\n";
|
||||
return true; // listen for the next event
|
||||
};
|
||||
return res;
|
||||
};
|
||||
|
||||
this->post_router_models = [this](const server_http_req & req) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
|
||||
json body = json::parse(req.body);
|
||||
std::string name = json_value(body, "model", std::string());
|
||||
if (name.empty()) {
|
||||
throw std::invalid_argument("model must be a non-empty string");
|
||||
}
|
||||
|
||||
common_params_model model;
|
||||
common_download_opts opts;
|
||||
|
||||
model.name = name;
|
||||
model.hf_repo = name;
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.download_mmproj = true;
|
||||
opts.download_mtp = true;
|
||||
|
||||
// first, only check if the model is valid and can be downloaded
|
||||
opts.skip_download = true;
|
||||
bool ok = false;
|
||||
try {
|
||||
auto validation = common_download_model(model, opts);
|
||||
ok = !validation.model_path.empty();
|
||||
} catch (const common_skip_download_exception &) {
|
||||
// model is valid and will be downloaded
|
||||
ok = true;
|
||||
} catch (...) {
|
||||
SRV_ERR("unknown error while validating model '%s'\n", name.c_str());
|
||||
// other exceptions will be handled by the outer ex_wrapper()
|
||||
throw;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
throw std::invalid_argument("model validation failed, unable to download");
|
||||
}
|
||||
|
||||
// then, proceed with the actual download
|
||||
opts.skip_download = false;
|
||||
SRV_INF("starting download for model '%s'\n", name.c_str());
|
||||
models.download(std::move(model), std::move(opts));
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
};
|
||||
|
||||
this->del_router_models = [this](const server_http_req & req) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
|
||||
std::string name = req.get_param("model");
|
||||
if (name.empty()) {
|
||||
throw std::invalid_argument("model must be a non-empty string");
|
||||
}
|
||||
|
||||
bool ok = models.remove(name);
|
||||
if (!ok) {
|
||||
throw std::runtime_error("failed to remove model '" + name + "'");
|
||||
}
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
#include "preset.h"
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
@@ -14,6 +16,8 @@
|
||||
/**
|
||||
* state diagram:
|
||||
*
|
||||
* DOWNLOADING ──► DOWNLOADED ──► (replaced by new instance)
|
||||
*
|
||||
* UNLOADED ──► LOADING ──► LOADED ◄──── SLEEPING
|
||||
* ▲ │ │ ▲
|
||||
* └───failed───┘ │ │
|
||||
@@ -22,39 +26,43 @@
|
||||
*/
|
||||
enum server_model_status {
|
||||
// TODO: also add downloading state when the logic is added
|
||||
SERVER_MODEL_STATUS_DOWNLOADING,
|
||||
SERVER_MODEL_STATUS_DOWNLOADED,
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
SERVER_MODEL_STATUS_LOADING,
|
||||
SERVER_MODEL_STATUS_LOADED,
|
||||
SERVER_MODEL_STATUS_SLEEPING
|
||||
};
|
||||
|
||||
static server_model_status server_model_status_from_string(const std::string & status_str) {
|
||||
if (status_str == "unloaded") {
|
||||
return SERVER_MODEL_STATUS_UNLOADED;
|
||||
}
|
||||
if (status_str == "loading") {
|
||||
return SERVER_MODEL_STATUS_LOADING;
|
||||
}
|
||||
if (status_str == "loaded") {
|
||||
return SERVER_MODEL_STATUS_LOADED;
|
||||
}
|
||||
if (status_str == "sleeping") {
|
||||
return SERVER_MODEL_STATUS_SLEEPING;
|
||||
}
|
||||
throw std::runtime_error("invalid server model status");
|
||||
}
|
||||
enum server_model_source {
|
||||
SERVER_MODEL_SOURCE_PRESET,
|
||||
SERVER_MODEL_SOURCE_MODELS_DIR,
|
||||
SERVER_MODEL_SOURCE_CACHE,
|
||||
};
|
||||
|
||||
static std::string server_model_status_to_string(server_model_status status) {
|
||||
switch (status) {
|
||||
case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
|
||||
case SERVER_MODEL_STATUS_LOADING: return "loading";
|
||||
case SERVER_MODEL_STATUS_LOADED: return "loaded";
|
||||
case SERVER_MODEL_STATUS_SLEEPING: return "sleeping";
|
||||
default: return "unknown";
|
||||
case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
|
||||
case SERVER_MODEL_STATUS_DOWNLOADED: return "downloaded";
|
||||
case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
|
||||
case SERVER_MODEL_STATUS_LOADING: return "loading";
|
||||
case SERVER_MODEL_STATUS_LOADED: return "loaded";
|
||||
case SERVER_MODEL_STATUS_SLEEPING: return "sleeping";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
static std::string server_model_source_to_string(server_model_source source) {
|
||||
switch (source) {
|
||||
case SERVER_MODEL_SOURCE_PRESET: return "preset";
|
||||
case SERVER_MODEL_SOURCE_MODELS_DIR: return "models_dir";
|
||||
case SERVER_MODEL_SOURCE_CACHE: return "cache";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
struct server_model_meta {
|
||||
server_model_source source = SERVER_MODEL_SOURCE_CACHE;
|
||||
common_preset preset;
|
||||
std::string name;
|
||||
std::set<std::string> aliases; // additional names that resolve to this model
|
||||
@@ -63,11 +71,11 @@ struct server_model_meta {
|
||||
server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
|
||||
int64_t last_used = 0; // for LRU unloading
|
||||
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
|
||||
json loaded_info; // info to be reflected via /v1/models endpoint
|
||||
json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info
|
||||
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
|
||||
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
|
||||
mtmd_caps multimodal; // multimodal capabilities
|
||||
bool need_download = false; // whether the model needs to be downloaded before loading
|
||||
// bool need_download = false; // whether the model needs to be downloaded before loading // TODO @ngxson: implement this
|
||||
|
||||
bool is_ready() const {
|
||||
return status == SERVER_MODEL_STATUS_LOADED;
|
||||
@@ -85,12 +93,15 @@ struct server_model_meta {
|
||||
void update_caps();
|
||||
};
|
||||
|
||||
struct subprocess_s;
|
||||
struct server_models_routes;
|
||||
struct server_subproc; // defined in server-models.cpp
|
||||
|
||||
struct server_models {
|
||||
friend struct server_models_routes;
|
||||
|
||||
private:
|
||||
struct instance_t {
|
||||
std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
|
||||
std::shared_ptr<server_subproc> subproc; // shared between main thread and monitoring thread
|
||||
std::thread th;
|
||||
server_model_meta meta;
|
||||
FILE * stdin_file = nullptr;
|
||||
@@ -107,6 +118,9 @@ private:
|
||||
// set to true while load_models() is executing a reload; load() will wait until clear
|
||||
bool is_reloading = false;
|
||||
|
||||
// if true, the next get_meta() will trigger a reload of model list
|
||||
bool need_reload = false;
|
||||
|
||||
common_preset_context ctx_preset;
|
||||
|
||||
common_params base_params;
|
||||
@@ -122,9 +136,14 @@ private:
|
||||
// not thread-safe, caller must hold mutex
|
||||
void add_model(server_model_meta && meta);
|
||||
|
||||
// notify SSE clients
|
||||
void notify_sse(const std::string & event, const std::string & model_id, const json & data = nullptr);
|
||||
|
||||
public:
|
||||
server_models(const common_params & params, int argc, char ** argv);
|
||||
|
||||
server_response sse; // for real-time updates via SSE endpoint
|
||||
|
||||
// (re-)load the list of models from various sources and prepare the metadata mapping
|
||||
// - if this is called the first time, simply populate the metadata
|
||||
// - if this is called subsequently (e.g. when refreshing from disk):
|
||||
@@ -147,13 +166,24 @@ public:
|
||||
void unload(const std::string & name);
|
||||
void unload_all();
|
||||
|
||||
// download a new model, progress is reported via SSE
|
||||
// to stop the download, call unload()
|
||||
void download(common_params_model && model, common_download_opts && opts);
|
||||
|
||||
// update the status of a model instance (thread-safe)
|
||||
void update_status(const std::string & name, server_model_status status, int exit_code);
|
||||
void update_loaded_info(const std::string & name, std::string & raw_info);
|
||||
void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true);
|
||||
|
||||
// remove a cache model from disk and update the list (thread-safe)
|
||||
// note: only cache models can be removed; returns false if the model doesn't exist or is not a cache model
|
||||
bool remove(const std::string & name);
|
||||
|
||||
// wait until the model instance is fully loaded (thread-safe)
|
||||
// note: predicate is called while holding the lock
|
||||
// return when the model no longer in "loading" state
|
||||
void wait_until_loading_finished(const std::string & name);
|
||||
void wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate);
|
||||
void wait(std::unique_lock<std::mutex> & lk, const std::string & name, std::function<bool(const server_model_meta &)> predicate);
|
||||
|
||||
// ensure the model is in ready state (thread-safe)
|
||||
// return false if model is ready
|
||||
@@ -176,8 +206,9 @@ public:
|
||||
|
||||
struct server_models_routes {
|
||||
common_params params;
|
||||
json ui_settings = json::object(); // Primary: new name
|
||||
json webui_settings = json::object(); // Deprecated: use ui_settings (kept for compat)
|
||||
json ui_settings = json::object(); // Primary: new name
|
||||
json webui_settings = json::object(); // Deprecated: use ui_settings (kept for compat)
|
||||
std::atomic<bool> stopping = false; // for graceful disconnecting SSE clients during shutdown
|
||||
server_models models;
|
||||
server_models_routes(const common_params & params, int argc, char ** argv)
|
||||
: params(params), models(params, argc, argv) {
|
||||
@@ -206,6 +237,10 @@ struct server_models_routes {
|
||||
server_http_context::handler_t get_router_models;
|
||||
server_http_context::handler_t post_router_models_load;
|
||||
server_http_context::handler_t post_router_models_unload;
|
||||
// management API
|
||||
server_http_context::handler_t get_router_models_sse;
|
||||
server_http_context::handler_t post_router_models;
|
||||
server_http_context::handler_t del_router_models;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -331,6 +331,17 @@ void server_response::send(server_task_result_ptr && result) {
|
||||
}
|
||||
}
|
||||
|
||||
void server_response::broadcast(server_task_result_ptr && result) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
for (const auto & id_task : waiting_task_ids) {
|
||||
RES_DBG("task id = %d pushed to result queue\n", id_task);
|
||||
server_task_result_ptr res_copy(result->clone());
|
||||
res_copy->id = id_task; // override id with target task id
|
||||
queue_results.emplace_back(std::move(res_copy));
|
||||
}
|
||||
condition_results.notify_all();
|
||||
}
|
||||
|
||||
void server_response::terminate() {
|
||||
running = false;
|
||||
condition_results.notify_all();
|
||||
|
||||
@@ -154,11 +154,15 @@ public:
|
||||
// Send a new result to a waiting id_task
|
||||
void send(server_task_result_ptr && result);
|
||||
|
||||
// broadcast a new result to all waiting tasks
|
||||
// (used by router mode)
|
||||
void broadcast(server_task_result_ptr && result);
|
||||
|
||||
// terminate the waiting loop
|
||||
void terminate();
|
||||
};
|
||||
|
||||
// utility class to make working with server_queue and server_response easier
|
||||
// RAII wrapper to make working with server_queue and server_response easier
|
||||
// it provides a generator-like API for server responses
|
||||
// support pooling connection state and aggregating multiple results
|
||||
struct server_response_reader {
|
||||
|
||||
@@ -312,6 +312,9 @@ struct server_task_result {
|
||||
}
|
||||
virtual json to_json() = 0;
|
||||
virtual ~server_task_result() = default;
|
||||
virtual server_task_result * clone() const {
|
||||
GGML_ABORT("not implemented for this task type");
|
||||
}
|
||||
};
|
||||
|
||||
// using shared_ptr for polymorphism of server_task_result
|
||||
@@ -649,3 +652,12 @@ struct server_prompt_cache {
|
||||
|
||||
void update();
|
||||
};
|
||||
|
||||
// used exclusively by router mode
|
||||
struct server_task_result_router : server_task_result {
|
||||
json data;
|
||||
virtual json to_json() override { return data; }
|
||||
virtual server_task_result * clone() const override {
|
||||
return new server_task_result_router(*this);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -174,8 +174,11 @@ int llama_server(int argc, char ** argv) {
|
||||
routes.get_props = models_routes->get_router_props;
|
||||
routes.get_models = models_routes->get_router_models;
|
||||
|
||||
ctx_http.post("/models", ex_wrapper(models_routes->post_router_models));
|
||||
ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load));
|
||||
ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
|
||||
ctx_http.get ("/models/sse", ex_wrapper(models_routes->get_router_models_sse));
|
||||
ctx_http.del ("/models", ex_wrapper(models_routes->del_router_models));
|
||||
}
|
||||
|
||||
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
|
||||
@@ -261,6 +264,7 @@ int llama_server(int argc, char ** argv) {
|
||||
clean_up = [&models_routes]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
if (models_routes.has_value()) {
|
||||
models_routes->stopping.store(true); // maybe redundant, but just to be safe
|
||||
models_routes->models.unload_all();
|
||||
}
|
||||
llama_backend_free();
|
||||
@@ -274,6 +278,10 @@ int llama_server(int argc, char ** argv) {
|
||||
ctx_http.is_ready.store(true);
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
if (models_routes.has_value()) {
|
||||
// important to disconnect any SSE clients
|
||||
models_routes->stopping.store(true);
|
||||
}
|
||||
ctx_http.stop();
|
||||
};
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
@@ -253,3 +254,98 @@ def test_router_reload_models():
|
||||
assert "model-reload-c" in ids, "newly added model should appear"
|
||||
finally:
|
||||
os.remove(preset_path)
|
||||
|
||||
|
||||
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
|
||||
MODEL_DOWNLOAD_TIMEOUT = 300
|
||||
|
||||
|
||||
def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event):
|
||||
"""Collect /models/sse events into `collected` until `stop` is set."""
|
||||
url = f"http://{server.server_host}:{server.server_port}/models/sse"
|
||||
try:
|
||||
with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp:
|
||||
for line_bytes in resp.iter_lines():
|
||||
if stop.is_set():
|
||||
break
|
||||
line = line_bytes.decode("utf-8")
|
||||
if line.startswith("data: "):
|
||||
collected.append(json.loads(line[6:]))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _wait_for_sse_event(collected: list, event_type: str, model: str, timeout: int) -> bool:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
if any(e.get("event") == event_type and e.get("model") == model for e in collected):
|
||||
return True
|
||||
time.sleep(0.5)
|
||||
return False
|
||||
|
||||
|
||||
def test_router_download_model():
|
||||
"""Case 1: download a model, verify SSE events and GET /models."""
|
||||
global server
|
||||
server.start()
|
||||
|
||||
# Ensure the model is not present before we start
|
||||
server.make_request("DELETE", f"/models?model={MODEL_DOWNLOAD_ID}")
|
||||
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
sse_thread = threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
)
|
||||
sse_thread.start()
|
||||
|
||||
# Trigger the download
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
assert res.body.get("success") is True
|
||||
|
||||
# Wait for download_finished SSE event
|
||||
finished = _wait_for_sse_event(
|
||||
sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT
|
||||
)
|
||||
stop.set()
|
||||
|
||||
assert finished, "Never received download_finished SSE event"
|
||||
assert any(
|
||||
e.get("event") == "download_progress" and e.get("model") == MODEL_DOWNLOAD_ID
|
||||
for e in sse_events
|
||||
), "No download_progress events received"
|
||||
|
||||
# Model should now appear in GET /models
|
||||
ids = _get_model_ids(is_reload=False)
|
||||
assert MODEL_DOWNLOAD_ID in ids, f"{MODEL_DOWNLOAD_ID} not found in /models after download"
|
||||
|
||||
|
||||
def test_router_delete_model():
|
||||
"""Case 2: delete the downloaded model, verify it disappears from GET /models."""
|
||||
global server
|
||||
server.start()
|
||||
|
||||
# Ensure the model exists (download it if needed)
|
||||
if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False):
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
).start()
|
||||
finished = _wait_for_sse_event(
|
||||
sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT
|
||||
)
|
||||
stop.set()
|
||||
assert finished, "Model did not finish downloading before delete test"
|
||||
|
||||
# Delete the model
|
||||
del_res = server.make_request("DELETE", f"/models?model={MODEL_DOWNLOAD_ID}")
|
||||
assert del_res.status_code == 200
|
||||
assert del_res.body.get("success") is True
|
||||
|
||||
# Model should no longer appear in GET /models
|
||||
ids = _get_model_ids(is_reload=False)
|
||||
assert MODEL_DOWNLOAD_ID not in ids, f"{MODEL_DOWNLOAD_ID} still present after deletion"
|
||||
|
||||
@@ -340,6 +340,9 @@ class ServerProcess:
|
||||
elif method == "POST":
|
||||
response = requests.post(url, headers=headers, json=data, timeout=timeout)
|
||||
parse_body = True
|
||||
elif method == "DELETE":
|
||||
response = requests.delete(url, headers=headers, timeout=timeout)
|
||||
parse_body = True
|
||||
elif method == "OPTIONS":
|
||||
response = requests.options(url, headers=headers, timeout=timeout)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user