mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-22 21:57:40 +02:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9c0ac887f3 | |||
| 721354fbdf | |||
| 6ee0f65793 | |||
| 099b579acb | |||
| f8cc15f163 | |||
| 37957e8531 | |||
| d0f9d2e5ac |
+10
-4
@@ -396,7 +396,7 @@ static bool parse_bool_value(const std::string & value) {
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
@@ -408,6 +408,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
|
||||
if (callback) {
|
||||
opts.callback = callback;
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
// so we should not auto-discover mtp/mmproj siblings for them
|
||||
common_download_opts sub_opts = opts;
|
||||
@@ -584,8 +588,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
|
||||
}
|
||||
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
const bool skip_model_download =
|
||||
// server will call common_params_handle_models() later, so we skip it here
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
|
||||
// export_graph_ops loads only metadata
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
@@ -594,7 +601,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty()
|
||||
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
|
||||
+5
-1
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
@@ -133,7 +134,10 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex);
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
common_download_callback * callback = nullptr);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
@@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
|
||||
op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data,
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
#endif
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
|
||||
|
||||
@@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) {
|
||||
return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_abs(T x) {
|
||||
return sycl::fabs(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed
|
||||
} else {
|
||||
return sycl::fabs(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::expm1(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::expm1(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_elu(T x) {
|
||||
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
|
||||
return (x > static_cast<T>(0.f)) ? x : op_expm1(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
constexpr int ver = __INTEL_LLVM_COMPILER;
|
||||
#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000)
|
||||
return sycl::ext::oneapi::experimental::tanh(x);
|
||||
#else
|
||||
return static_cast<T>(sycl::tanh(static_cast<float>(x)));
|
||||
#endif
|
||||
} else {
|
||||
return sycl::tanh(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) {
|
||||
const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
|
||||
return static_cast<T>(0.5f) * x *
|
||||
(static_cast<T>(1.0f) +
|
||||
sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
op_tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::exp(x);
|
||||
} else {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_silu(T x) {
|
||||
return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return x / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
static __dpct_inline__ T op_erf(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::erf(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::erf(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_erf(T x) {
|
||||
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + op_erf(x * SQRT_2_INV));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
return sycl::tanh(x);
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_relu(T x) {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0));
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sigmoid(T x) {
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sqrt(T x) {
|
||||
return sycl::sqrt(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sqrt(x);
|
||||
} else {
|
||||
return sycl::sqrt(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sin(T x) {
|
||||
return sycl::sin(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sin(x);
|
||||
} else {
|
||||
return sycl::sin(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_cos(T x) {
|
||||
return sycl::cos(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::cos(x);
|
||||
} else {
|
||||
return sycl::cos(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardsigmoid(T x) {
|
||||
return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmin(
|
||||
static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(
|
||||
static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return sycl::fmin(static_cast<T>(1.0f),
|
||||
sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardswish(T x) {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
return sycl::expm1(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return x * sycl::ext::oneapi::experimental::fmin(static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) {
|
||||
if (x <= static_cast<T>(0)) {
|
||||
return neg_infinity<T>();
|
||||
}
|
||||
return sycl::log(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::log(x);
|
||||
} else {
|
||||
return sycl::log(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_softplus(T x) {
|
||||
const float xf = (float) x;
|
||||
const float ax = sycl::fabs(xf);
|
||||
const float ax = op_abs(xf);
|
||||
const float m = sycl::fmax(xf, 0.0f);
|
||||
const float y = m + sycl::log1p(sycl::exp(-ax));
|
||||
return (T) y;
|
||||
@@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) {
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
|
||||
T neg_slope_T = static_cast<T>(negative_slope);
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0)) +
|
||||
sycl::ext::oneapi::experimental::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_floor(T x) {
|
||||
return sycl::floor(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::floor(x);
|
||||
} else {
|
||||
return sycl::floor(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_ceil(T x) {
|
||||
return sycl::ceil(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::ceil(x);
|
||||
} else {
|
||||
return sycl::ceil(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_round(T x) {
|
||||
return sycl::round(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::round(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::round(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_trunc(T x) {
|
||||
return sycl::trunc(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::trunc(x);
|
||||
} else {
|
||||
return sycl::trunc(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename F>
|
||||
@@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
@@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step,
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
@@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_SYCL_HAS_BF16
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::ext::oneapi::bfloat16>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
@@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary(
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_generic_kernel(
|
||||
src, dst_ptr, k_elements,
|
||||
ne0, ne1, ne2, ne3,
|
||||
@@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
arange_kernel(dst_ptr, k, start, step, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
|
||||
});
|
||||
}, negative_slope);
|
||||
@@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
|
||||
});
|
||||
}, min_val, max_val);
|
||||
@@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp
|
||||
return out_glu;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
|
||||
const int64_t n, const int64_t o0, const int64_t o1,
|
||||
@@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x,
|
||||
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2813,8 +2813,6 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampler_softmax_impl(cur_p, true);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
|
||||
|
||||
@@ -360,9 +360,9 @@ int main(void) {
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.032727f, 0.241818f, 0.241818f}, 2.0f, 1.1f, 2, 5, {});
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
|
||||
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.428571f, 0.571429f}, 1.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 3.00f);
|
||||
|
||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
||||
|
||||
@@ -204,9 +204,9 @@ Instead of building everything from the ground up (like what most AI agents will
|
||||
|
||||
The flow for downloading a new model:
|
||||
- POST request comes in --> `post_router_models` --> validation
|
||||
- `server_models::download()` is called
|
||||
- Sets up a new thread `inst.th` and runs the download inside
|
||||
- If a stop request comes in, set `stop_download` to `true`
|
||||
- A new `llama-server` subprocess will be spawned with special `SERVER_CHILD_MODE_DOWNLOAD`
|
||||
- The child process runs the download and reports status back to the router via stdin/stdout
|
||||
- If a stop request comes in, the router asks the child process to stop (same mechanism as running a model in child process)
|
||||
- Otherwise, upon completion, we call `load_models()` to refresh the list of models
|
||||
|
||||
### Notable Related PRs
|
||||
|
||||
+12
-5
@@ -1230,8 +1230,6 @@ print(completion.choices[0].text)
|
||||
|
||||
Given a ChatML-formatted JSON description in `messages`, it returns the predicted completion. Both synchronous and streaming modes are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with the OpenAI API spec are being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
|
||||
|
||||
If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more.
|
||||
|
||||
*Options:*
|
||||
|
||||
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
|
||||
@@ -1250,9 +1248,18 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":
|
||||
|
||||
`parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template).
|
||||
|
||||
For multimodal input:
|
||||
- Content type `image_url` and `input_audio` are the same as OAI schema
|
||||
- Content type `input_video` is an extension from OAI schema. For now, it only accepts base64 input
|
||||
For multimodal input (typed content, `messages[i].content[j]`):
|
||||
- If `type == "image_url"`:
|
||||
- `image_url.url` can be a remote URL, base64 data (raw or URI-encoded via `data:image/...;base64`), or a path to a local file
|
||||
- Accepts formats supported by `stb_image` (jpeg, png, tga, bmp, gif, ...)
|
||||
- If `type == "input_audio"`:
|
||||
- Either `input_audio.data` or `input_audio.url` can be specified; the value can be a remote URL, raw base64, or a path to a local file
|
||||
- Accepts formats supported by `miniaudio` (mp3, wav, flac)
|
||||
- `input_audio.format` will be ignored, the file format will be determined automatically
|
||||
- If `type == "input_video"`:
|
||||
- Either `input_video.data` or `input_video.url` can be specified; the value can be a remote URL, raw base64, or a path to a local file
|
||||
- Accepts formats supported by `ffmpeg`
|
||||
- Note: for local files, make sure to set `--media-path`. The file path must be prefixed with `file://`
|
||||
|
||||
*Examples:*
|
||||
|
||||
|
||||
@@ -817,12 +817,21 @@ json oaicompat_completion_params_parse(const json & body) {
|
||||
return llama_params;
|
||||
}
|
||||
|
||||
// media_path always ends with '/', see arg.cpp
|
||||
// url can be
|
||||
// - http(s):// for remote files
|
||||
// - file:// for local files (only allowed if media_path is set)
|
||||
// - data: for base64 encoded data with uri scheme (e.g. data:image/png;base64,...)
|
||||
// - raw base64 encoded data
|
||||
static void handle_media(
|
||||
std::vector<raw_buffer> & out_files,
|
||||
json & media_obj,
|
||||
const std::string & media_path) {
|
||||
std::string url = json_value(media_obj, "url", std::string());
|
||||
const std::string & url,
|
||||
const std::string & media_path,
|
||||
bool accept_base64_uri) {
|
||||
if (!media_path.empty()) {
|
||||
// should already be enforced by arg.cpp, but checking just in case
|
||||
GGML_ASSERT(media_path.back() == DIRECTORY_SEPARATOR);
|
||||
}
|
||||
|
||||
if (string_starts_with(url, "http")) {
|
||||
// download remote image
|
||||
// TODO @ngxson : maybe make these params configurable
|
||||
@@ -858,20 +867,28 @@ static void handle_media(
|
||||
data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
out_files.push_back(data);
|
||||
|
||||
} else {
|
||||
} else if (accept_base64_uri && string_starts_with(url, "data:")) {
|
||||
// try to decode base64 image
|
||||
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
|
||||
if (parts.size() != 2) {
|
||||
throw std::runtime_error("Invalid url value");
|
||||
throw std::runtime_error("Invalid uri-encoded base64 value");
|
||||
} else if (!string_starts_with(parts[0], "data:image/")) {
|
||||
throw std::runtime_error("Invalid url format: " + parts[0]);
|
||||
throw std::runtime_error("Invalid uri format: " + parts[0]);
|
||||
} else if (!string_ends_with(parts[0], "base64")) {
|
||||
throw std::runtime_error("url must be base64 encoded");
|
||||
throw std::runtime_error("uri must be base64 encoded");
|
||||
} else {
|
||||
auto base64_data = parts[1];
|
||||
auto decoded_data = base64_decode(base64_data);
|
||||
out_files.push_back(decoded_data);
|
||||
}
|
||||
|
||||
} else {
|
||||
// try as raw base64 string
|
||||
auto decoded_data = base64_decode(url);
|
||||
if (decoded_data.empty()) {
|
||||
throw std::runtime_error("Invalid base64 value");
|
||||
}
|
||||
out_files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -957,14 +974,15 @@ json oaicompat_chat_params_parse(
|
||||
}
|
||||
|
||||
for (auto & p : content) {
|
||||
std::string type = json_value(p, "type", std::string());
|
||||
std::string type = json_value(p, "type", std::string());
|
||||
if (type == "image_url") {
|
||||
if (!opt.allow_image) {
|
||||
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json image_url = json_value(p, "image_url", json::object());
|
||||
handle_media(out_files, image_url, opt.media_path);
|
||||
std::string url = json_value(image_url, "url", std::string());
|
||||
handle_media(out_files, url, opt.media_path, true);
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
@@ -975,17 +993,11 @@ json oaicompat_chat_params_parse(
|
||||
throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json input_audio = json_value(p, "input_audio", json::object());
|
||||
std::string data = json_value(input_audio, "data", std::string());
|
||||
std::string format = json_value(input_audio, "format", std::string());
|
||||
// while we also support flac, we don't allow it here so we matches the OAI spec
|
||||
if (format != "wav" && format != "mp3") {
|
||||
throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'");
|
||||
}
|
||||
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
||||
out_files.push_back(decoded_data);
|
||||
|
||||
// TODO: add audio_url support by reusing handle_media()
|
||||
// note: don't need to validate "format", it's redundant
|
||||
json input_audio = json_value(p, "input_audio", json::object());
|
||||
std::string url = json_value(input_audio, "data",
|
||||
json_value(input_audio, "url", std::string()));
|
||||
handle_media(out_files, url, opt.media_path, false);
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
@@ -996,10 +1008,10 @@ json oaicompat_chat_params_parse(
|
||||
throw std::runtime_error("video input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json input_video = json_value(p, "input_video", json::object());
|
||||
std::string data = json_value(input_video, "data", std::string());
|
||||
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
||||
out_files.push_back(decoded_data);
|
||||
json input_video = json_value(p, "input_video", json::object());
|
||||
std::string url = json_value(input_video, "data",
|
||||
json_value(input_video, "url", std::string()));
|
||||
handle_media(out_files, url, opt.media_path, false);
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
|
||||
@@ -931,6 +931,8 @@ private:
|
||||
|
||||
bool sleeping = false;
|
||||
|
||||
int64_t t_last_load_progress_ms = 0;
|
||||
|
||||
void destroy() {
|
||||
spec.reset();
|
||||
ctx_dft.reset();
|
||||
@@ -1244,6 +1246,10 @@ private:
|
||||
}
|
||||
|
||||
if (has_mmproj) {
|
||||
if (callback_state) {
|
||||
callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}});
|
||||
}
|
||||
|
||||
if (!is_resume) {
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ struct server_context_meta {
|
||||
};
|
||||
|
||||
enum server_state {
|
||||
// SERVER_STATE_DOWNLOADING,
|
||||
SERVER_STATE_DOWNLOADING,
|
||||
SERVER_STATE_LOADING,
|
||||
SERVER_STATE_READY,
|
||||
SERVER_STATE_SLEEPING,
|
||||
@@ -61,6 +61,7 @@ enum server_state {
|
||||
|
||||
static std::string server_state_to_str(server_state state) {
|
||||
switch (state) {
|
||||
case SERVER_STATE_DOWNLOADING: return "downloading";
|
||||
case SERVER_STATE_LOADING: return "loading";
|
||||
case SERVER_STATE_READY: return "ready";
|
||||
case SERVER_STATE_SLEEPING: return "sleeping";
|
||||
@@ -69,6 +70,7 @@ static std::string server_state_to_str(server_state state) {
|
||||
}
|
||||
|
||||
static server_state server_state_from_str(const std::string & str) {
|
||||
if (str == "downloading") return SERVER_STATE_DOWNLOADING;
|
||||
if (str == "loading") return SERVER_STATE_LOADING;
|
||||
if (str == "ready") return SERVER_STATE_READY;
|
||||
if (str == "sleeping") return SERVER_STATE_SLEEPING;
|
||||
|
||||
+230
-130
@@ -64,6 +64,17 @@ struct server_subproc {
|
||||
return sproc.has_value() && subprocess_alive(&sproc.value());
|
||||
}
|
||||
|
||||
void request_exit() {
|
||||
if (sproc.has_value()) {
|
||||
FILE * stdin_file = subprocess_stdin(&sproc.value());
|
||||
if (stdin_file) {
|
||||
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
|
||||
fflush(stdin_file);
|
||||
}
|
||||
}
|
||||
stopped.store(true, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void terminate() {
|
||||
if (!sproc.has_value()) {
|
||||
return;
|
||||
@@ -323,7 +334,7 @@ void server_models::notify_sse(const std::string & event, const std::string & mo
|
||||
}
|
||||
|
||||
void server_models::load_models() {
|
||||
// Phase 1: load presets from all sources — pure I/O, no lock needed
|
||||
// Phase 1: load presets from all sources - pure I/O, no lock needed
|
||||
// 1. cached models
|
||||
common_presets cached_models = ctx_preset.load_from_cache();
|
||||
SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
|
||||
@@ -376,7 +387,7 @@ void server_models::load_models() {
|
||||
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
|
||||
};
|
||||
|
||||
// Helpers that read `mapping` — must be called while holding the lock.
|
||||
// Helpers that read `mapping` - must be called while holding the lock.
|
||||
std::unordered_set<std::string> custom_names;
|
||||
for (const auto & [name, preset] : custom_presets) custom_names.insert(name);
|
||||
auto join_set = [](const std::set<std::string> & s) {
|
||||
@@ -523,7 +534,7 @@ void server_models::load_models() {
|
||||
}
|
||||
}
|
||||
|
||||
// join outside the lock — monitoring thread calls update_status (needs lock)
|
||||
// join outside the lock - monitoring thread calls update_status (needs lock)
|
||||
lk.unlock();
|
||||
for (auto & th : threads_to_join) th.join();
|
||||
lk.lock();
|
||||
@@ -622,7 +633,7 @@ void server_models::load_models() {
|
||||
|
||||
apply_stop_timeout();
|
||||
|
||||
// clear reload flag before unlocking for autoload — load() blocks on !is_reloading,
|
||||
// clear reload flag before unlocking for autoload - load() blocks on !is_reloading,
|
||||
// so clearing it here (while still locked) prevents a deadlock in the autoload calls below
|
||||
is_reloading = false;
|
||||
cv.notify_all();
|
||||
@@ -815,17 +826,23 @@ void server_models::unload_lru() {
|
||||
}
|
||||
|
||||
void server_models::load(const std::string & name) {
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
load(name, load_options{});
|
||||
}
|
||||
|
||||
void server_models::load(const std::string & name, const load_options & opts) {
|
||||
if (!opts.custom_meta.has_value()) {
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
unload_lru();
|
||||
}
|
||||
unload_lru();
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// edge case: block until any in-progress reload has finished so we always load
|
||||
// against the freshest preset and a consistent mapping state
|
||||
cv.wait(lk, [this]() { return !is_reloading; });
|
||||
|
||||
auto meta = mapping[name].meta;
|
||||
auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta;
|
||||
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
|
||||
SRV_INF("model %s is not ready\n", name.c_str());
|
||||
return;
|
||||
@@ -869,6 +886,12 @@ void server_models::load(const std::string & name) {
|
||||
std::vector<std::string> child_env = base_env; // copy
|
||||
child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
|
||||
|
||||
if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
child_env.push_back("LLAMA_SERVER_CHILD_MODE=download");
|
||||
child_env.push_back("LLAMA_ARG_HF_REPO=" + name);
|
||||
}
|
||||
|
||||
SRV_INF("%s", "spawning server instance with args:\n");
|
||||
for (const auto & arg : child_args) {
|
||||
SRV_INF(" %s\n", arg.c_str());
|
||||
@@ -886,13 +909,17 @@ void server_models::load(const std::string & name) {
|
||||
if (result != 0) {
|
||||
throw std::runtime_error("failed to spawn server instance");
|
||||
}
|
||||
|
||||
inst.stdin_file = subprocess_stdin(&inst.subproc->get());
|
||||
}
|
||||
|
||||
// start a thread to manage the child process
|
||||
// captured variables are guaranteed to be destroyed only after the thread is joined
|
||||
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
|
||||
inst.th = std::thread([
|
||||
this, name,
|
||||
child_proc = inst.subproc,
|
||||
port = inst.meta.port,
|
||||
stop_timeout = inst.meta.stop_timeout,
|
||||
child_mode = opts.mode
|
||||
]() {
|
||||
FILE * stdin_file = subprocess_stdin(&child_proc->get());
|
||||
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
|
||||
|
||||
@@ -925,7 +952,7 @@ void server_models::load(const std::string & name) {
|
||||
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
|
||||
});
|
||||
}
|
||||
// child crashed or finished on its own — skip graceful shutdown sequence
|
||||
// child crashed or finished on its own, skip graceful shutdown sequence
|
||||
if (child_proc->stopped.load(std::memory_order_acquire)) {
|
||||
return;
|
||||
}
|
||||
@@ -973,10 +1000,14 @@ void server_models::load(const std::string & name) {
|
||||
subprocess_destroy(&child_proc->get());
|
||||
|
||||
// update status and exit code
|
||||
this->update_status(name, {
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
exit_code
|
||||
});
|
||||
if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
// instance will be cleaned up on next load_models() call
|
||||
} else {
|
||||
this->update_status(name, {
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
exit_code
|
||||
});
|
||||
}
|
||||
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
|
||||
});
|
||||
|
||||
@@ -984,7 +1015,7 @@ void server_models::load(const std::string & name) {
|
||||
{
|
||||
auto & old_instance = mapping[name];
|
||||
// old process should have exited already, but just in case, we clean it up here
|
||||
if (old_instance.subproc->is_alive()) {
|
||||
if (old_instance.subproc && old_instance.subproc->is_alive()) {
|
||||
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
|
||||
old_instance.subproc->terminate(); // force kill
|
||||
}
|
||||
@@ -1001,92 +1032,13 @@ void server_models::load(const std::string & name) {
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
// callback for model downloading functionality
|
||||
struct server_models_download_res : public common_download_callback {
|
||||
common_params_model model;
|
||||
common_download_opts opts;
|
||||
|
||||
std::function<bool()> should_stop;
|
||||
std::function<void(const common_download_progress & p)> on_progress;
|
||||
|
||||
bool is_ok = false;
|
||||
|
||||
bool run() {
|
||||
try {
|
||||
common_download_model(model, opts);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = model.get_name();
|
||||
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
|
||||
is_ok = false;
|
||||
}
|
||||
return is_ok;
|
||||
}
|
||||
void on_start(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_update(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_done(const common_download_progress &, bool ok) override {
|
||||
is_ok = ok;
|
||||
}
|
||||
bool is_cancelled() const override {
|
||||
return should_stop();
|
||||
}
|
||||
};
|
||||
|
||||
void server_models::download(common_params_model && model, common_download_opts && opts) {
|
||||
std::string name = model.get_name();
|
||||
GGML_ASSERT(name == model.hf_repo);
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
if (mapping.find(name) != mapping.end()) {
|
||||
throw std::runtime_error("model name=" + name + " already exists");
|
||||
}
|
||||
|
||||
instance_t inst;
|
||||
inst.meta.name = name;
|
||||
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
inst.subproc = std::make_shared<server_subproc>();
|
||||
|
||||
auto dl = std::make_unique<server_models_download_res>();
|
||||
dl->model = model; // copy
|
||||
dl->opts = opts; // copy
|
||||
|
||||
dl->should_stop = [sp = inst.subproc]() {
|
||||
return sp->stopped.load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
dl->on_progress = [this, name](const common_download_progress & p) {
|
||||
update_download_progress(name, p, false);
|
||||
};
|
||||
|
||||
inst.th = std::thread([this, dl = std::move(dl)]() {
|
||||
dl->opts.callback = dl.get();
|
||||
bool ok = dl->run();
|
||||
auto model_name = dl->model.get_name();
|
||||
SRV_INF("download finished for model name=%s with status=%s\n",
|
||||
model_name.c_str(), ok ? "success" : "failure");
|
||||
update_download_progress(model_name, {}, true, ok);
|
||||
// need_reload is set inside update_download_progress under the mutex;
|
||||
// the next load_models() call will clean up this instance
|
||||
});
|
||||
|
||||
mapping[name] = std::move(inst);
|
||||
notify_sse("status_update", name, {
|
||||
{"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)},
|
||||
});
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
void server_models::unload(const std::string & name) {
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->stopped.store(true, std::memory_order_relaxed);
|
||||
it->second.subproc->request_exit();
|
||||
// for convenience, we wait the status change here
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
@@ -1198,37 +1150,65 @@ void server_models::update_download_progress(const std::string & name, const com
|
||||
}
|
||||
|
||||
bool server_models::remove(const std::string & name) {
|
||||
auto meta = get_meta(name);
|
||||
// do everything under one lock acquisition; avoid get_meta() /
|
||||
// unload() because they can trigger load_models() which erases
|
||||
// transient DOWNLOADING / DOWNLOADED entries as a side-effect
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
|
||||
if (!meta.has_value()) {
|
||||
auto it = mapping.find(name);
|
||||
if (it == mapping.end()) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
if (meta->source != SERVER_MODEL_SOURCE_CACHE) {
|
||||
if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) {
|
||||
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
|
||||
}
|
||||
|
||||
unload(name); // cancel download or stop running instance
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status == SERVER_MODEL_STATUS_UNLOADED
|
||||
|| new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
});
|
||||
// join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no
|
||||
// longer acquires this mutex, so joining while holding it is safe
|
||||
if (mapping[name].th.joinable()) {
|
||||
mapping[name].th.join();
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
// cancel in-flight download
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->request_exit();
|
||||
} else if (it->second.meta.is_running()) {
|
||||
// stop running instance
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
stopping_models.insert(name);
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
|
||||
it->second.subproc->terminate();
|
||||
}
|
||||
// remove the model from disk (hold lock to prevent concurrent load)
|
||||
bool ok = common_download_remove(name);
|
||||
if (ok) {
|
||||
mapping.erase(name);
|
||||
}
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed");
|
||||
notify_sse("model_remove", name, {});
|
||||
return ok;
|
||||
cv_stop.notify_all();
|
||||
}
|
||||
|
||||
// wait until the monitoring thread finishes
|
||||
wait(lk, name, [](const server_model_meta & meta) {
|
||||
return meta.status == SERVER_MODEL_STATUS_UNLOADED
|
||||
|| meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
});
|
||||
|
||||
// re-find after wait - load_models() may have erased the entry during the wait
|
||||
it = mapping.find(name);
|
||||
if (it == mapping.end()) {
|
||||
// load_models() already joined the thread and erased the entry;
|
||||
// we just need to clean up the cached files on disk
|
||||
lk.unlock();
|
||||
bool ok = common_download_remove(name);
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
||||
notify_sse("model_remove", name, {});
|
||||
return true;
|
||||
}
|
||||
|
||||
// join before erasing - thread no longer acquires this mutex
|
||||
if (it->second.th.joinable()) {
|
||||
it->second.th.join();
|
||||
}
|
||||
|
||||
// remove from disk (best-effort: cancelled downloads may have no cached files)
|
||||
bool ok = common_download_remove(name);
|
||||
mapping.erase(name);
|
||||
if (!ok) {
|
||||
SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str());
|
||||
}
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
||||
notify_sse("model_remove", name, {});
|
||||
return true;
|
||||
}
|
||||
|
||||
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
|
||||
@@ -1243,7 +1223,9 @@ void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string &
|
||||
return predicate(it->second.meta);
|
||||
|
||||
}
|
||||
return false;
|
||||
// model was removed from mapping by another code path (e.g. load_models()).
|
||||
// nothing left to wait for - tell the caller to proceed.
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1328,6 +1310,31 @@ void server_models::handle_child_state(const std::string & name, const std::stri
|
||||
}
|
||||
|
||||
switch (state) {
|
||||
case SERVER_STATE_DOWNLOADING:
|
||||
{
|
||||
std::string result = json_value(payload, "result", std::string());
|
||||
std::string url = json_value(payload, "url", std::string());
|
||||
auto request_exit = [&]() {
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
return it->second.subproc->request_exit();
|
||||
}
|
||||
};
|
||||
if (result == "download_finished") {
|
||||
update_download_progress(name, {}, true, true);
|
||||
request_exit();
|
||||
} else if (result == "download_failed") {
|
||||
update_download_progress(name, {}, true, false);
|
||||
request_exit();
|
||||
} else if (!url.empty()) {
|
||||
common_download_progress p;
|
||||
p.url = url;
|
||||
p.downloaded = json_value(payload, "downloaded", (size_t)0);
|
||||
p.total = json_value(payload, "total", (size_t)0);
|
||||
update_download_progress(name, p, false);
|
||||
}
|
||||
} break;
|
||||
case SERVER_STATE_LOADING:
|
||||
{
|
||||
update_status(name, {
|
||||
@@ -1366,6 +1373,90 @@ bool server_child::is_child() {
|
||||
return router_port != nullptr;
|
||||
}
|
||||
|
||||
server_child_mode server_child::get_mode() {
|
||||
const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE");
|
||||
std::string mode_str(mode ? mode : "");
|
||||
if (mode_str == "download") {
|
||||
return SERVER_CHILD_MODE_DOWNLOAD;
|
||||
} else {
|
||||
return SERVER_CHILD_MODE_NORMAL;
|
||||
}
|
||||
}
|
||||
|
||||
struct server_download_state : public common_download_callback {
|
||||
server_child * self;
|
||||
std::function<bool()> should_stop;
|
||||
std::atomic<int64_t> last_progress_time{0}; // multiple files downloading in different threads
|
||||
bool is_ok = false;
|
||||
|
||||
server_download_state(server_child * s) : self(s) {}
|
||||
|
||||
bool run(common_params & params) {
|
||||
try {
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = params.model.get_name();
|
||||
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
|
||||
is_ok = false;
|
||||
}
|
||||
return is_ok;
|
||||
}
|
||||
void on_progress(const common_download_progress & p) {
|
||||
json data = {
|
||||
{"url", p.url},
|
||||
{"downloaded", p.downloaded},
|
||||
{"total", p.total},
|
||||
};
|
||||
self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data);
|
||||
}
|
||||
void on_start(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_update(const common_download_progress & p) override {
|
||||
int64_t now = ggml_time_ms();
|
||||
// throttle progress updates to avoid flooding logs
|
||||
if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) {
|
||||
on_progress(p);
|
||||
last_progress_time.store(now, std::memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
void on_done(const common_download_progress & p, bool) override {
|
||||
on_progress(p);
|
||||
}
|
||||
bool is_cancelled() const override {
|
||||
return should_stop ? should_stop() : false;
|
||||
}
|
||||
};
|
||||
|
||||
int server_child::run_download(common_params & params) {
|
||||
auto cancelled = std::make_shared<std::atomic<bool>>(false);
|
||||
|
||||
// monitor stdin for cancellation command from the router
|
||||
std::thread signal_thread = setup([cancelled](int) {
|
||||
cancelled->store(true, std::memory_order_relaxed);
|
||||
});
|
||||
|
||||
server_download_state dl(this);
|
||||
dl.should_stop = [cancelled]() {
|
||||
return cancelled->load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
bool ok = dl.run(params);
|
||||
|
||||
notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), {
|
||||
{"result", ok ? "download_finished" : "download_failed"},
|
||||
});
|
||||
|
||||
// router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result
|
||||
if (signal_thread.joinable()) {
|
||||
signal_thread.join();
|
||||
}
|
||||
|
||||
SRV_INF("download completed %s\n", ok ? "successfully" : "with errors");
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::thread server_child::setup(const std::function<void(int)> & shutdown_handler) {
|
||||
// setup thread for monitoring stdin
|
||||
return std::thread([shutdown_handler]() {
|
||||
@@ -1639,7 +1730,7 @@ void server_models_routes::init_routes() {
|
||||
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
if (!model->is_running()) {
|
||||
if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
@@ -1680,8 +1771,9 @@ void server_models_routes::init_routes() {
|
||||
|
||||
model.hf_repo = name;
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.download_mmproj = true;
|
||||
opts.download_mtp = true;
|
||||
// note: we only check main model, no need sidecar here
|
||||
opts.download_mmproj = false;
|
||||
opts.download_mtp = false;
|
||||
|
||||
// first, only check if the model is valid and can be downloaded
|
||||
opts.skip_download = true;
|
||||
@@ -1702,10 +1794,21 @@ void server_models_routes::init_routes() {
|
||||
throw std::invalid_argument("model validation failed, unable to download");
|
||||
}
|
||||
|
||||
// reject if model already exists
|
||||
if (models.has_model(name)) {
|
||||
throw std::invalid_argument("model '" + name + "' already exists");
|
||||
}
|
||||
|
||||
// then, proceed with the actual download
|
||||
opts.skip_download = false;
|
||||
SRV_INF("starting download for model '%s'\n", name.c_str());
|
||||
models.download(std::move(model), std::move(opts));
|
||||
{
|
||||
server_models::load_options load_opts;
|
||||
load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD;
|
||||
load_opts.custom_meta = server_model_meta{};
|
||||
load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE;
|
||||
load_opts.custom_meta->name = name;
|
||||
models.load(name, load_opts);
|
||||
}
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
@@ -1719,10 +1822,7 @@ void server_models_routes::init_routes() {
|
||||
throw std::invalid_argument("model must be a non-empty string");
|
||||
}
|
||||
|
||||
bool ok = models.remove(name);
|
||||
if (!ok) {
|
||||
throw std::runtime_error("failed to remove model '" + name + "'");
|
||||
}
|
||||
models.remove(name); // throws on error
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
|
||||
@@ -40,6 +40,11 @@ enum server_model_source {
|
||||
SERVER_MODEL_SOURCE_CACHE,
|
||||
};
|
||||
|
||||
enum server_child_mode {
|
||||
SERVER_CHILD_MODE_NORMAL, // load the model and run normally
|
||||
SERVER_CHILD_MODE_DOWNLOAD, // download the model and exit
|
||||
};
|
||||
|
||||
static std::string server_model_status_to_string(server_model_status status) {
|
||||
switch (status) {
|
||||
case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
|
||||
@@ -105,7 +110,6 @@ private:
|
||||
std::shared_ptr<server_subproc> subproc; // shared between main thread and monitoring thread
|
||||
std::thread th;
|
||||
server_model_meta meta;
|
||||
FILE * stdin_file = nullptr;
|
||||
};
|
||||
|
||||
std::mutex mutex;
|
||||
@@ -161,16 +165,19 @@ public:
|
||||
// return a copy of all model metadata (thread-safe)
|
||||
std::vector<server_model_meta> get_all_meta();
|
||||
|
||||
struct load_options {
|
||||
server_child_mode mode = SERVER_CHILD_MODE_NORMAL;
|
||||
// used for spawning a downloading child process
|
||||
std::optional<server_model_meta> custom_meta = std::nullopt;
|
||||
};
|
||||
|
||||
// load and unload model instances
|
||||
// these functions are thread-safe
|
||||
void load(const std::string & name);
|
||||
void load(const std::string & name, const load_options & opts);
|
||||
void unload(const std::string & name);
|
||||
void unload_all();
|
||||
|
||||
// download a new model, progress is reported via SSE
|
||||
// to stop the download, call unload()
|
||||
void download(common_params_model && model, common_download_opts && opts);
|
||||
|
||||
struct update_status_args {
|
||||
server_model_status status;
|
||||
int exit_code = 0; // only valid if status == UNLOADED
|
||||
@@ -213,9 +220,12 @@ public:
|
||||
struct server_child {
|
||||
// serializes the notify_to_router writes
|
||||
std::mutex mtx_stdout;
|
||||
std::atomic<bool> is_finished_downloading = false; // set by run_download
|
||||
|
||||
// return true if the current process is a child server instance
|
||||
bool is_child();
|
||||
server_child_mode get_mode();
|
||||
int run_download(common_params & params);
|
||||
|
||||
// register the shutdown_handler to be called by the router
|
||||
// return the monitoring thread (to be joined by the caller)
|
||||
|
||||
@@ -569,9 +569,13 @@ struct server_tool_edit_file : server_tool {
|
||||
}
|
||||
int n = (int) lines.size();
|
||||
if (e.line_start == -1) {
|
||||
// -1 means end of file; line_end is ignored — normalize to point past last line
|
||||
e.line_start = n + 1;
|
||||
e.line_end = n + 1;
|
||||
// -1 targets end of file -> valid for append only; line_end is ignored
|
||||
if (e.mode != "append") {
|
||||
return {{"error", "line_start -1 (end of file) is only valid for append mode"}};
|
||||
}
|
||||
// append at end of file: insert position is the current line count
|
||||
e.line_start = n;
|
||||
e.line_end = n;
|
||||
} else {
|
||||
if (e.line_start < 1 || e.line_end < e.line_start) {
|
||||
return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}};
|
||||
@@ -612,8 +616,8 @@ struct server_tool_edit_file : server_tool {
|
||||
} else if (e.mode == "delete") {
|
||||
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
|
||||
} else { // append
|
||||
// idx_end + 1 may equal lines.size() when line_start == -1 (end of file)
|
||||
lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end());
|
||||
// insert after idx_end; idx_end + 1 == lines.size() for end-of-file append
|
||||
lines.insert(lines.begin() + (idx_end + 1), new_lines.begin(), new_lines.end());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+12
-1
@@ -134,6 +134,7 @@ int llama_server(int argc, char ** argv) {
|
||||
//
|
||||
|
||||
// register API routes
|
||||
server_child child; // only used in non-router mode
|
||||
server_routes routes(params, ctx_server);
|
||||
server_tools tools;
|
||||
|
||||
@@ -254,11 +255,21 @@ int llama_server(int argc, char ** argv) {
|
||||
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
|
||||
}
|
||||
|
||||
//
|
||||
// Handle downloading model
|
||||
//
|
||||
|
||||
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
return child.run_download(params);
|
||||
} else if (!is_router_server) {
|
||||
// single-model mode (NOT spawned by router)
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
//
|
||||
|
||||
server_child child; // only used in non-router mode
|
||||
std::function<void()> clean_up;
|
||||
|
||||
if (is_router_server) {
|
||||
|
||||
@@ -257,14 +257,25 @@ def test_router_reload_models():
|
||||
|
||||
|
||||
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
|
||||
MODEL_DOWNLOAD_TIMEOUT = 300
|
||||
MODEL_DOWNLOAD_TIMEOUT = 30
|
||||
|
||||
|
||||
def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event):
|
||||
"""Collect /models/sse events into `collected` until `stop` is set."""
|
||||
def _listen_sse(
|
||||
server: ServerProcess, collected: list, stop: threading.Event, ready: threading.Event | None = None
|
||||
):
|
||||
"""Collect /models/sse events into `collected` until `stop` is set.
|
||||
|
||||
When `ready` is provided, it is set once the streaming response is open,
|
||||
i.e. the server has accepted the connection and registered us as a
|
||||
subscriber. Callers that trigger one-shot events (e.g. download_finished)
|
||||
must wait on `ready` before acting, otherwise the event can be broadcast
|
||||
before this client is subscribed and be lost.
|
||||
"""
|
||||
url = f"http://{server.server_host}:{server.server_port}/models/sse"
|
||||
try:
|
||||
with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp:
|
||||
if ready is not None:
|
||||
ready.set()
|
||||
for line_bytes in resp.iter_lines():
|
||||
if stop.is_set():
|
||||
break
|
||||
@@ -294,11 +305,17 @@ def test_router_download_model():
|
||||
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
sse_ready = threading.Event()
|
||||
sse_thread = threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
|
||||
)
|
||||
sse_thread.start()
|
||||
|
||||
# wait for the SSE client to be subscribed before triggering the download,
|
||||
# otherwise the one-shot download_finished event can be broadcast before
|
||||
# this client is registered and be lost
|
||||
assert sse_ready.wait(10), "SSE client failed to connect"
|
||||
|
||||
# Trigger the download
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
@@ -328,13 +345,17 @@ def test_router_delete_model():
|
||||
|
||||
# Ensure the model exists (download it if needed)
|
||||
if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False):
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
sse_ready = threading.Event()
|
||||
threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
|
||||
).start()
|
||||
# subscribe before triggering the download so the one-shot
|
||||
# download_finished event is not lost (see test_router_download_model)
|
||||
assert sse_ready.wait(10), "SSE client failed to connect"
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
finished = _wait_for_sse_event(
|
||||
sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT
|
||||
)
|
||||
|
||||
Vendored
+10
@@ -19,6 +19,10 @@ import type {
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiModelDataEntry,
|
||||
ApiModelLoadStage,
|
||||
ApiModelsSseProgress,
|
||||
ApiModelsSseData,
|
||||
ApiModelsSseEvent,
|
||||
ApiModelListResponse,
|
||||
ApiProcessingState,
|
||||
ApiRouterModelMeta,
|
||||
@@ -52,6 +56,7 @@ import type {
|
||||
// Model types
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
ModelLoadProgress,
|
||||
// Settings types
|
||||
SettingsChatServiceOptions,
|
||||
SettingsConfigValue,
|
||||
@@ -83,6 +88,10 @@ declare global {
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiModelDataEntry,
|
||||
ApiModelLoadStage,
|
||||
ApiModelsSseProgress,
|
||||
ApiModelsSseData,
|
||||
ApiModelsSseEvent,
|
||||
ApiModelListResponse,
|
||||
ApiProcessingState,
|
||||
ApiRouterModelMeta,
|
||||
@@ -120,6 +129,7 @@ declare global {
|
||||
// Model types
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
ModelLoadProgress,
|
||||
// Settings types
|
||||
SettingsChatServiceOptions,
|
||||
SettingsConfigValue,
|
||||
|
||||
+12
-3
@@ -10,7 +10,7 @@
|
||||
import { getMessageEditContext } from '$lib/contexts';
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
|
||||
import { copyToClipboard, deriveAgenticSections } from '$lib/utils';
|
||||
import { copyToClipboard, deriveAgenticSections, modelLoadProgressText } from '$lib/utils';
|
||||
import { AgenticSectionType } from '$lib/enums';
|
||||
import { REASONING_TAGS } from '$lib/constants/agentic';
|
||||
import { tick } from 'svelte';
|
||||
@@ -185,6 +185,13 @@
|
||||
let hasNoContent = $derived(!message?.content?.trim());
|
||||
let isActivelyProcessing = $derived(isCurrentlyLoading || isStreaming);
|
||||
|
||||
// during a router auto-load the message has no model yet, so target the selected one
|
||||
let loadTargetModel = $derived(message.model ?? modelsStore.selectedModelName);
|
||||
let modelLoadProgress = $derived(
|
||||
isRouter && loadTargetModel ? modelsStore.getLoadProgress(loadTargetModel) : null
|
||||
);
|
||||
let modelLoadingText = $derived(modelLoadProgressText(modelLoadProgress));
|
||||
|
||||
let showProcessingInfoTop = $derived(
|
||||
message?.role === MessageRole.ASSISTANT &&
|
||||
isActivelyProcessing &&
|
||||
@@ -220,7 +227,8 @@
|
||||
<div class="mt-6 w-full max-w-[48rem]" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{processingState.getPromptProgressText() ??
|
||||
{modelLoadingText ??
|
||||
processingState.getPromptProgressText() ??
|
||||
processingState.getProcessingMessage() ??
|
||||
'Processing...'}
|
||||
</span>
|
||||
@@ -252,7 +260,8 @@
|
||||
<div class="mt-4 w-full max-w-[48rem]" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{processingState.getPromptProgressText() ??
|
||||
{modelLoadingText ??
|
||||
processingState.getPromptProgressText() ??
|
||||
processingState.getProcessingMessage() ??
|
||||
'Processing...'}
|
||||
</span>
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
import type { ModelOption } from '$lib/types/models';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
import { modelLoadFraction, modelLoadProgressText } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
option: ModelOption;
|
||||
@@ -50,11 +51,15 @@
|
||||
(serverStatus === ServerModelStatus.LOADED || isSleeping) && !isOperationInProgress
|
||||
);
|
||||
let isLoading = $derived(serverStatus === ServerModelStatus.LOADING || isOperationInProgress);
|
||||
|
||||
let loadProgress = $derived(isLoading ? modelsStore.getLoadProgress(option.model) : null);
|
||||
let loadPercent = $derived(Math.round(modelLoadFraction(loadProgress) * 100));
|
||||
let loadTitle = $derived(modelLoadProgressText(loadProgress));
|
||||
</script>
|
||||
|
||||
<div
|
||||
class={[
|
||||
'group flex w-full items-center gap-2 rounded-sm p-2 text-left text-sm transition focus:outline-none',
|
||||
'group relative flex w-full items-center gap-2 rounded-sm p-2 text-left text-sm transition focus:outline-none',
|
||||
'cursor-pointer hover:bg-muted focus:bg-muted',
|
||||
(isSelected || isHighlighted) && 'bg-accent text-accent-foreground',
|
||||
!(isSelected || isHighlighted) && 'hover:bg-accent hover:text-accent-foreground',
|
||||
@@ -62,6 +67,7 @@
|
||||
]}
|
||||
role="option"
|
||||
aria-selected={isSelected || isHighlighted}
|
||||
title={loadTitle}
|
||||
tabindex="0"
|
||||
onclick={() => onSelect(option.id)}
|
||||
onmouseenter={onMouseEnter}
|
||||
@@ -188,4 +194,15 @@
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<div
|
||||
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-primary transition-[width] duration-200 ease-out"
|
||||
style="width: {loadPercent}%"
|
||||
></div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
export const API_MODELS = {
|
||||
LIST: '/v1/models',
|
||||
LOAD: '/models/load',
|
||||
UNLOAD: '/models/unload'
|
||||
UNLOAD: '/models/unload',
|
||||
SSE: '/models/sse'
|
||||
};
|
||||
|
||||
// chat completion routes, the control route drives realtime inference (e.g. end reasoning)
|
||||
|
||||
@@ -37,6 +37,8 @@ export * from './mcp-form';
|
||||
export * from './mcp-resource';
|
||||
export * from './message-export';
|
||||
export * from './model-id';
|
||||
export * from './model-loading';
|
||||
export * from './sse';
|
||||
export * from './precision';
|
||||
export * from './processing-info';
|
||||
export * from './pwa';
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
/**
|
||||
* Labels shown while a model loads, keyed by the stage reported on /models/sse.
|
||||
*/
|
||||
export const MODEL_LOAD_STAGE_LABELS: Record<ApiModelLoadStage, string> = {
|
||||
text_model: 'Loading weights',
|
||||
spec_model: 'Loading draft',
|
||||
mmproj_model: 'Loading projector'
|
||||
};
|
||||
|
||||
/**
|
||||
* Share of the bar reserved for each load phase after text_model.
|
||||
* text_model fills the rest, so a plain model reaches 100% on its own.
|
||||
*/
|
||||
export const MODEL_LOAD_TAIL_SHARE = 0.1;
|
||||
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
* Server-sent events wire format, shared by the chat stream and the
|
||||
* /models/sse status feed (text/event-stream).
|
||||
*/
|
||||
|
||||
// blank line between two events
|
||||
export const SSE_RECORD_SEPARATOR = '\n\n';
|
||||
|
||||
// line break inside an event
|
||||
export const SSE_LINE_SEPARATOR = '\n';
|
||||
|
||||
// data field prefix, the value follows after an optional space
|
||||
export const SSE_DATA_PREFIX = 'data:';
|
||||
|
||||
// end-of-stream marker on the chat completion stream
|
||||
export const SSE_DONE_MARKER = '[DONE]';
|
||||
@@ -54,7 +54,7 @@ export {
|
||||
|
||||
export { ModelModality } from './model.enums';
|
||||
|
||||
export { ServerRole, ServerModelStatus } from './server.enums';
|
||||
export { ServerRole, ServerModelStatus, ServerModelsSseEventType } from './server.enums';
|
||||
|
||||
export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings.enums';
|
||||
|
||||
|
||||
@@ -19,3 +19,17 @@ export enum ServerModelStatus {
|
||||
SLEEPING = 'sleeping',
|
||||
FAILED = 'failed'
|
||||
}
|
||||
|
||||
/**
|
||||
* /models/sse event type enum - discriminates the records broadcast on the
|
||||
* model status feed in ROUTER mode. Matches the event names emitted by
|
||||
* tools/server/server-models.cpp from the C++ server.
|
||||
*/
|
||||
export enum ServerModelsSseEventType {
|
||||
STATUS_CHANGE = 'status_change',
|
||||
MODEL_STATUS = 'model_status',
|
||||
STATUS_UPDATE = 'status_update',
|
||||
MODELS_RELOAD = 'models_reload',
|
||||
MODEL_REMOVE = 'model_remove',
|
||||
DOWNLOAD_PROGRESS = 'download_progress'
|
||||
}
|
||||
|
||||
@@ -10,7 +10,10 @@ import {
|
||||
SETTINGS_KEYS,
|
||||
API_CHAT,
|
||||
API_SLOTS,
|
||||
CONTROL_ACTION
|
||||
CONTROL_ACTION,
|
||||
SSE_LINE_SEPARATOR,
|
||||
SSE_DATA_PREFIX,
|
||||
SSE_DONE_MARKER
|
||||
} from '$lib/constants';
|
||||
import {
|
||||
AttachmentType,
|
||||
@@ -18,8 +21,7 @@ import {
|
||||
FileTypeAudio,
|
||||
MessageRole,
|
||||
MimeTypeAudio,
|
||||
ReasoningFormat,
|
||||
UrlProtocol
|
||||
ReasoningFormat
|
||||
} from '$lib/enums';
|
||||
import type {
|
||||
ApiChatMessageContentPart,
|
||||
@@ -642,15 +644,15 @@ export class ChatService {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
chunk += decoder.decode(value, { stream: true });
|
||||
const lines = chunk.split('\n');
|
||||
const lines = chunk.split(SSE_LINE_SEPARATOR);
|
||||
chunk = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
if (line.startsWith(UrlProtocol.DATA)) {
|
||||
const data = line.slice(6);
|
||||
if (data === '[DONE]') {
|
||||
if (line.startsWith(SSE_DATA_PREFIX)) {
|
||||
const data = line.slice(SSE_DATA_PREFIX.length).trim();
|
||||
if (data === SSE_DONE_MARKER) {
|
||||
streamFinished = true;
|
||||
|
||||
continue;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { base } from '$app/paths';
|
||||
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { ServerModelStatus, ModelModality } from '$lib/enums';
|
||||
import { ServerModelStatus, ServerModelsSseEventType, ModelModality } from '$lib/enums';
|
||||
import { ModelsService } from '$lib/services/models.service';
|
||||
import { PropsService } from '$lib/services/props.service';
|
||||
import { serverStore, isRouterMode } from '$lib/stores/server.svelte';
|
||||
@@ -8,11 +9,15 @@ import {
|
||||
detectThinkingSupport,
|
||||
detectThinkingSupportWithReason
|
||||
} from '$lib/utils/chat-template-thinking-detector';
|
||||
import { TTLCache } from '$lib/utils';
|
||||
import { TTLCache, getAuthHeaders } from '$lib/utils';
|
||||
import {
|
||||
MODEL_PROPS_CACHE_TTL_MS,
|
||||
MODEL_PROPS_CACHE_MAX_ENTRIES,
|
||||
FAVORITE_MODELS_LOCALSTORAGE_KEY
|
||||
FAVORITE_MODELS_LOCALSTORAGE_KEY,
|
||||
API_MODELS,
|
||||
SSE_RECORD_SEPARATOR,
|
||||
SSE_LINE_SEPARATOR,
|
||||
SSE_DATA_PREFIX
|
||||
} from '$lib/constants';
|
||||
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
@@ -55,6 +60,15 @@ class ModelsStore {
|
||||
private modelUsage = $state<Map<string, SvelteSet<string>>>(new Map());
|
||||
private modelLoadingStates = new SvelteMap<string, boolean>();
|
||||
|
||||
// /models/sse feed state, the single source of truth for status and load progress
|
||||
private statusAbort: AbortController | null = null;
|
||||
private statusReaderActive = false;
|
||||
private loadProgress = new SvelteMap<string, ModelLoadProgress>();
|
||||
private statusWaiters = new Map<
|
||||
string,
|
||||
{ target: ServerModelStatus; resolve: () => void; reject: (e: Error) => void }
|
||||
>();
|
||||
|
||||
favoriteModelIds = $state<Set<string>>(this.loadFavoritesFromStorage());
|
||||
|
||||
/**
|
||||
@@ -531,7 +545,8 @@ class ModelsStore {
|
||||
* 1. Model from active conversation's last assistant response (if loaded)
|
||||
* 2. Model from active conversation's last assistant response (if not loaded)
|
||||
* 3. First loaded model (not from active conversation)
|
||||
* 4. First available model
|
||||
* 4. A favorite model
|
||||
* 5. First available model
|
||||
*/
|
||||
async ensureFirstModelSelected(): Promise<void> {
|
||||
if (this.selectedModelName) return;
|
||||
@@ -560,6 +575,13 @@ class ModelsStore {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try loading a favorite model
|
||||
const favorite = this.favoriteModelIds.values().next()?.value
|
||||
if (favorite) {
|
||||
await this.selectModelById(favorite);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to the first available model
|
||||
await this.selectModelById(availableModels[0].id);
|
||||
}
|
||||
@@ -626,49 +648,218 @@ class ModelsStore {
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* WORKAROUND: Polling for model status after load/unload operations.
|
||||
*
|
||||
* Currently, `/models/load` and `/models/unload` return success before
|
||||
* the operation actually completes on the server.
|
||||
*
|
||||
* TODO: Remove polling once llama-server properly waits for the operation
|
||||
* to complete before returning success.
|
||||
*/
|
||||
|
||||
private static readonly STATUS_POLL_INTERVAL = 500;
|
||||
// reconnect delay after the feed drops or the server is not ready yet
|
||||
private static readonly SSE_RECONNECT_MS = 1000;
|
||||
|
||||
/**
|
||||
* Poll for expected model status after load/unload operation.
|
||||
* Keeps polling until the model reaches the expected status or fails.
|
||||
* Open the /models/sse feed and keep it live with auto reconnect.
|
||||
* Idempotent and router mode only. The feed drives status and progress,
|
||||
* so it replaces any post-operation polling.
|
||||
*/
|
||||
private async pollForModelStatus(
|
||||
modelId: string,
|
||||
expectedStatus: ServerModelStatus
|
||||
): Promise<void> {
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
await this.fetchRouterModels();
|
||||
subscribeStatus(): void {
|
||||
if (this.statusReaderActive) return;
|
||||
if (!isRouterMode()) return;
|
||||
|
||||
const currentStatus = this.getModelStatus(modelId);
|
||||
if (currentStatus === expectedStatus) return;
|
||||
this.statusReaderActive = true;
|
||||
this.statusAbort = new AbortController();
|
||||
void this.runStatusReader(this.statusAbort.signal);
|
||||
}
|
||||
|
||||
if (currentStatus === ServerModelStatus.FAILED) {
|
||||
throw new Error(
|
||||
`Model failed to ${expectedStatus === ServerModelStatus.LOADED ? 'load' : 'unload'}`
|
||||
);
|
||||
/**
 * Stop following the /models/sse feed: mark the reader inactive, abort the
 * in-flight request if there is one, and forget all recorded load progress.
 */
unsubscribeStatus(): void {
	const controller = this.statusAbort;

	this.statusReaderActive = false;
	this.statusAbort = null;
	controller?.abort();

	this.loadProgress.clear();
}
|
||||
|
||||
/**
 * Look up the load progress recorded for `modelId`.
 * Returns null when no progress entry exists for that model.
 */
getLoadProgress(modelId: string): ModelLoadProgress | null {
	const progress = this.loadProgress.get(modelId);
	return progress === undefined ? null : progress;
}
|
||||
|
||||
/**
|
||||
* Read the feed and reconnect until unsubscribed. Splits the byte stream
|
||||
* into SSE records on the blank line boundary.
|
||||
*/
|
||||
private async runStatusReader(signal: AbortSignal): Promise<void> {
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
while (!signal.aborted) {
|
||||
try {
|
||||
const response = await fetch(`${base}${API_MODELS.SSE}`, {
|
||||
headers: getAuthHeaders(),
|
||||
signal
|
||||
});
|
||||
|
||||
if (response.ok && response.body) {
|
||||
const reader = response.body.getReader();
|
||||
let buffer = '';
|
||||
|
||||
while (!signal.aborted) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
let boundary = buffer.indexOf(SSE_RECORD_SEPARATOR);
|
||||
while (boundary !== -1) {
|
||||
this.handleStatusRecord(buffer.slice(0, boundary));
|
||||
buffer = buffer.slice(boundary + SSE_RECORD_SEPARATOR.length);
|
||||
boundary = buffer.indexOf(SSE_RECORD_SEPARATOR);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// network drop or abort falls through to the reconnect delay
|
||||
}
|
||||
|
||||
if (
|
||||
expectedStatus === ServerModelStatus.LOADED &&
|
||||
currentStatus === ServerModelStatus.UNLOADED &&
|
||||
attempt > 2
|
||||
) {
|
||||
throw new Error('Model was unloaded unexpectedly during loading');
|
||||
}
|
||||
if (signal.aborted) return;
|
||||
|
||||
attempt++;
|
||||
await new Promise((resolve) => setTimeout(resolve, ModelsStore.STATUS_POLL_INTERVAL));
|
||||
await new Promise((resolve) => setTimeout(resolve, ModelsStore.SSE_RECONNECT_MS));
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Parse one SSE record and forward it to applyStatusEvent.
 *
 * Only lines that start with the data prefix are used: their values are
 * trimmed, joined with the line separator and parsed as one JSON envelope
 * that carries its own model, event and data fields. Records with no data
 * lines, with malformed JSON, or whose JSON is not an object are dropped.
 */
private handleStatusRecord(record: string): void {
	const payload = record
		.split(SSE_LINE_SEPARATOR)
		.filter((line) => line.startsWith(SSE_DATA_PREFIX))
		.map((line) => line.slice(SSE_DATA_PREFIX.length).trim())
		.join(SSE_LINE_SEPARATOR);

	if (payload.length === 0) return;

	let envelope: unknown;
	try {
		envelope = JSON.parse(payload);
	} catch {
		return;
	}

	// JSON.parse accepts non-object documents; a bare `null` would make the
	// field access in applyStatusEvent throw, so such payloads stop here
	if (typeof envelope !== 'object' || envelope === null) return;

	this.applyStatusEvent(envelope as ApiModelsSseEvent);
}
|
||||
|
||||
/**
 * Dispatch one feed record on its event kind:
 * - status_change / model_status / status_update go to applyModelStatus
 * - models_reload starts a router model list refresh (not awaited)
 * - model_remove drops the row for event.model
 * - download_progress and any other kind are ignored here
 */
private applyStatusEvent(event: ApiModelsSseEvent): void {
	const kind = event.event;

	const carriesStatus =
		kind === ServerModelsSseEventType.STATUS_CHANGE ||
		kind === ServerModelsSseEventType.MODEL_STATUS ||
		kind === ServerModelsSseEventType.STATUS_UPDATE;

	if (carriesStatus) {
		this.applyModelStatus(event);
		return;
	}

	if (kind === ServerModelsSseEventType.MODELS_RELOAD) {
		void this.fetchRouterModels();
		return;
	}

	if (kind === ServerModelsSseEventType.MODEL_REMOVE) {
		this.removeRouterModel(event.model);
	}
}
|
||||
|
||||
/**
 * Apply a status envelope to the store.
 *
 * Envelopes without a model or without a status are ignored. Otherwise the
 * model row is updated, load progress is recorded while loading (when the
 * envelope carries one) and cleared for every other status, modalities are
 * refreshed once loaded, and the pending awaiter for the model is rejected
 * on failure or settled against the reported status.
 */
private applyModelStatus(event: ApiModelsSseEvent): void {
	const { model, data } = event;
	if (!model || !data?.status) return;

	const status = data.status;
	this.setRouterModelStatus(model, status);

	if (status !== ServerModelStatus.LOADING) {
		this.loadProgress.delete(model);
	} else if (data.progress) {
		this.loadProgress.set(model, data.progress);
	}

	if (status === ServerModelStatus.LOADED) {
		void this.updateModelModalities(model);
	}

	// an unload that reports a non-zero exit code counts as a failure
	const exitCode = data.exit_code ?? 0;
	const unloadedWithError = status === ServerModelStatus.UNLOADED && exitCode !== 0;

	if (status === ServerModelStatus.FAILED || unloadedWithError) {
		this.rejectStatus(model, new Error(`Model failed: ${this.toDisplayName(model)}`));
		return;
	}

	this.settleStatus(model, status);
}
|
||||
|
||||
/**
 * Remove the row for `modelId` from the router model list, clear its load
 * progress and reject its pending awaiter. Does nothing when the model is
 * not in the list.
 */
private removeRouterModel(modelId: string): void {
	const isListed = this.routerModels.some((entry) => entry.id === modelId);
	if (!isListed) return;

	this.routerModels = this.routerModels.filter((entry) => entry.id !== modelId);
	this.loadProgress.delete(modelId);

	const reason = new Error(`Model removed: ${this.toDisplayName(modelId)}`);
	this.rejectStatus(modelId, reason);
}
|
||||
|
||||
/**
 * Set the status value of one model row. The list is replaced by a new
 * array holding a new row object so the assignment triggers reactivity.
 * Unknown models and unchanged statuses leave the list untouched.
 */
private setRouterModelStatus(modelId: string, status: ServerModelStatus): void {
	const position = this.routerModels.findIndex((entry) => entry.id === modelId);
	if (position < 0) return;

	const existing = this.routerModels[position];
	if (existing.status.value === status) return;

	const updated = { ...existing, status: { ...existing.status, value: status } };
	this.routerModels = this.routerModels.map((entry, i) => (i === position ? updated : entry));
}
|
||||
|
||||
/**
 * Create a promise that settles once the feed reports `target` for
 * `modelId` (see settleStatus / rejectStatus).
 *
 * Exactly one awaiter is stored per model id; registering again for the
 * same model replaces the stored entry.
 */
private waitForStatus(modelId: string, target: ServerModelStatus): Promise<void> {
	let resolve!: () => void;
	let reject!: (e: Error) => void;

	const pending = new Promise<void>((onResolve, onReject) => {
		resolve = onResolve;
		reject = onReject;
	});

	this.statusWaiters.set(modelId, { target, resolve, reject });
	return pending;
}
|
||||
|
||||
/**
 * Resolve the awaiter registered for `modelId` if `status` is the status it
 * waits for, and remove it. Any other status leaves the awaiter pending.
 */
private settleStatus(modelId: string, status: ServerModelStatus): void {
	const waiter = this.statusWaiters.get(modelId);
	if (!waiter || waiter.target !== status) return;

	this.statusWaiters.delete(modelId);
	waiter.resolve();
}
|
||||
|
||||
/**
 * Reject the awaiter registered for `modelId` with `error` and remove it.
 * Does nothing when no awaiter is registered.
 */
private rejectStatus(modelId: string, error: Error): void {
	const waiter = this.statusWaiters.get(modelId);
	if (!waiter) return;

	this.statusWaiters.delete(modelId);
	waiter.reject(error);
}
|
||||
|
||||
@@ -679,12 +870,18 @@ class ModelsStore {
|
||||
this.modelLoadingStates.set(modelId, true);
|
||||
this.error = null;
|
||||
|
||||
// the feed drives completion, so it must be live before the request
|
||||
this.subscribeStatus();
|
||||
|
||||
const reachedLoaded = this.waitForStatus(modelId, ServerModelStatus.LOADED);
|
||||
reachedLoaded.catch(() => {});
|
||||
|
||||
try {
|
||||
await ModelsService.load(modelId);
|
||||
await this.pollForModelStatus(modelId, ServerModelStatus.LOADED);
|
||||
await this.updateModelModalities(modelId);
|
||||
await reachedLoaded;
|
||||
toast.success(`Model loaded: ${this.toDisplayName(modelId)}`);
|
||||
} catch (error) {
|
||||
this.rejectStatus(modelId, error instanceof Error ? error : new Error('load failed'));
|
||||
this.error = error instanceof Error ? error.message : 'Failed to load model';
|
||||
toast.error(`Failed to load model: ${this.toDisplayName(modelId)}`);
|
||||
throw error;
|
||||
@@ -700,11 +897,17 @@ class ModelsStore {
|
||||
this.modelLoadingStates.set(modelId, true);
|
||||
this.error = null;
|
||||
|
||||
this.subscribeStatus();
|
||||
|
||||
const reachedUnloaded = this.waitForStatus(modelId, ServerModelStatus.UNLOADED);
|
||||
reachedUnloaded.catch(() => {});
|
||||
|
||||
try {
|
||||
await ModelsService.unload(modelId);
|
||||
await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED);
|
||||
await reachedUnloaded;
|
||||
toast.info(`Model unloaded: ${this.toDisplayName(modelId)}`);
|
||||
} catch (error) {
|
||||
this.rejectStatus(modelId, error instanceof Error ? error : new Error('unload failed'));
|
||||
this.error = error instanceof Error ? error.message : 'Failed to unload model';
|
||||
toast.error(`Failed to unload model: ${this.toDisplayName(modelId)}`);
|
||||
throw error;
|
||||
@@ -783,6 +986,9 @@ class ModelsStore {
|
||||
}
|
||||
|
||||
clear(): void {
|
||||
this.unsubscribeStatus();
|
||||
this.statusWaiters.forEach((waiter) => waiter.reject(new Error('Models store cleared')));
|
||||
this.statusWaiters.clear();
|
||||
this.models = [];
|
||||
this.routerModels = [];
|
||||
this.loading = false;
|
||||
|
||||
Vendored
+47
-1
@@ -1,4 +1,10 @@
|
||||
import type { ContentPartType, FileTypeAudio, ServerModelStatus, ServerRole } from '$lib/enums';
|
||||
import type {
|
||||
ContentPartType,
|
||||
FileTypeAudio,
|
||||
ServerModelStatus,
|
||||
ServerModelsSseEventType,
|
||||
ServerRole
|
||||
} from '$lib/enums';
|
||||
import type { ChatMessagePromptProgress, ChatRole } from './chat';
|
||||
|
||||
export type AudioInputFormat = FileTypeAudio.WAV | FileTypeAudio.MP3;
|
||||
@@ -96,6 +102,46 @@ export interface ApiModelDataEntry {
|
||||
meta?: Record<string, unknown> | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load stage reported by the /models/sse feed, in load order.
|
||||
*/
|
||||
export type ApiModelLoadStage = 'text_model' | 'spec_model' | 'mmproj_model';
|
||||
|
||||
/**
|
||||
* Load progress snapshot: the full ordered stage plan, the active stage,
|
||||
* and its fractional value (0.0 -> 1.0).
|
||||
*/
|
||||
export interface ApiModelsSseProgress {
|
||||
stages: ApiModelLoadStage[];
|
||||
current: ApiModelLoadStage;
|
||||
value: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Status payload carried by a /models/sse envelope.
|
||||
* exit_code appears on unload.
|
||||
*/
|
||||
export interface ApiModelsSseData {
|
||||
status: ServerModelStatus;
|
||||
progress?: ApiModelsSseProgress;
|
||||
exit_code?: number;
|
||||
}
|
||||
|
||||
/**
 * One /models/sse record.
 *
 * `event` discriminates the kind multiplexed on the feed: only the status_*
 * events carry a status payload, models_reload signals a full list refresh,
 * model_remove drops a row, download_* drive download UI.
 * `model` names the target instance.
 * `data` carries the status payload and is only set on events that have one,
 * so consumers must check it before reading status fields.
 */
export interface ApiModelsSseEvent {
	model: string;
	event: ServerModelsSseEventType;
	data?: ApiModelsSseData;
}
|
||||
|
||||
export interface ApiModelDetails {
|
||||
name: string;
|
||||
model: string;
|
||||
|
||||
@@ -11,6 +11,10 @@ export type {
|
||||
ApiChatMessageData,
|
||||
ApiModelStatus,
|
||||
ApiModelDataEntry,
|
||||
ApiModelLoadStage,
|
||||
ApiModelsSseProgress,
|
||||
ApiModelsSseData,
|
||||
ApiModelsSseEvent,
|
||||
ApiModelDetails,
|
||||
ApiModelListResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
@@ -70,7 +74,12 @@ export type {
|
||||
} from './database';
|
||||
|
||||
// Model types
|
||||
export type { ModelModalities, ModelOption, ModalityCapabilities } from './models';
|
||||
export type {
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
ModelLoadProgress,
|
||||
ModalityCapabilities
|
||||
} from './models';
|
||||
|
||||
// Settings types
|
||||
export type {
|
||||
|
||||
Vendored
+12
-1
@@ -1,4 +1,4 @@
|
||||
import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api';
|
||||
import type { ApiModelDataEntry, ApiModelDetails, ApiModelLoadStage } from '$lib/types/api';
|
||||
|
||||
export interface ModelModalities {
|
||||
vision: boolean;
|
||||
@@ -20,6 +20,17 @@ export interface ModelOption {
|
||||
tags?: string[];
|
||||
}
|
||||
|
||||
/**
 * Ephemeral UI-only load progress for one model instance.
 * Lives only while a load runs, driven by the /models/sse feed.
 * Holds the ordered stage plan, the active stage and that stage's
 * fractional value (0.0 -> 1.0).
 */
|
||||
export interface ModelLoadProgress {
|
||||
stages: ApiModelLoadStage[];
|
||||
current: ApiModelLoadStage;
|
||||
value: number;
|
||||
}
|
||||
|
||||
export interface ParsedModelId {
|
||||
raw: string;
|
||||
orgName: string | null;
|
||||
|
||||
@@ -44,6 +44,9 @@ export { buildProxiedUrl, buildProxiedHeaders } from './cors-proxy';
|
||||
// URL utilities
|
||||
export { extractRootDomain, sanitizeExternalUrl } from './url';
|
||||
|
||||
// Progress helpers
|
||||
export { modelLoadFraction, modelLoadProgressText } from './progress';
|
||||
|
||||
// Conversation utilities
|
||||
export { createMessageCountMap, getMessageCount } from './conversation-utils';
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Model load progress helpers for the /models/sse surfaces
|
||||
* (selector row and chat message).
|
||||
*/
|
||||
|
||||
import { MODEL_LOAD_STAGE_LABELS, MODEL_LOAD_TAIL_SHARE } from '$lib/constants';
|
||||
|
||||
/**
 * Resolve the display label for a model load stage from
 * MODEL_LOAD_STAGE_LABELS.
 */
export function modelLoadStageLabel(stage: ApiModelLoadStage): string {
	const label = MODEL_LOAD_STAGE_LABELS[stage];
	return label;
}
|
||||
|
||||
/**
 * Overall load fraction (0.0 -> 1.0) across the declared stage plan.
 *
 * The first stage fills [0, 1 - tail], where tail is MODEL_LOAD_TAIL_SHARE
 * times the number of later stages; each later stage owns one tail slice.
 * A `current` stage that is not in `stages` is treated like the first one.
 *
 * @param progress load progress snapshot, or null when nothing is loading
 * @returns a value clamped to [0, 1]; 0 when `progress` is null
 */
export function modelLoadFraction(progress: ModelLoadProgress | null): number {
	if (!progress) return 0;

	const { stages, current, value } = progress;

	// value is expected in 0.0 -> 1.0; NaN or out-of-range input would
	// otherwise leak into the rendered percent and the bar width
	const stageValue = Number.isFinite(value) ? Math.min(Math.max(value, 0), 1) : 0;

	const tailCount = Math.max(stages.length - 1, 0);
	// never let the tail slices push the first stage's share below zero
	const textCeiling = Math.max(1 - tailCount * MODEL_LOAD_TAIL_SHARE, 0);
	const idx = stages.indexOf(current);

	const fraction =
		idx <= 0
			? stageValue * textCeiling
			: textCeiling + (idx - 1 + stageValue) * MODEL_LOAD_TAIL_SHARE;

	return Math.min(Math.max(fraction, 0), 1);
}
|
||||
|
||||
/**
 * Build the one-line progress text: the active stage's label followed by
 * the overall load percent, rounded to an integer.
 * Returns null when there is no progress to describe.
 */
export function modelLoadProgressText(progress: ModelLoadProgress | null): string | null {
	if (!progress) return null;

	const stageLabel = modelLoadStageLabel(progress.current);
	const percent = Math.round(modelLoadFraction(progress) * 100);

	return stageLabel + ' ' + percent + '%';
}
|
||||
@@ -230,6 +230,20 @@
|
||||
}
|
||||
});
|
||||
|
||||
// Live model status and load progress via the /models/sse feed (router mode)
|
||||
$effect(() => {
|
||||
if (!browser) return;
|
||||
if (!isRouterMode()) return;
|
||||
|
||||
untrack(() => {
|
||||
modelsStore.subscribeStatus();
|
||||
});
|
||||
|
||||
return () => {
|
||||
modelsStore.unsubscribeStatus();
|
||||
};
|
||||
});
|
||||
|
||||
// Background MCP server health checks on app load
|
||||
// Fetch enabled servers from settings and run health checks in background
|
||||
$effect(() => {
|
||||
|
||||
Reference in New Issue
Block a user