Compare commits

..

1 Commits

Author SHA1 Message Date
Xuan-Son Nguyen 7c082bc417 server: fix report progress for loading spec models, add "stages" list (#24870)
* server: fix report progress for loading spec models, add "stages" list

* improve

* nits

* nits 2
2026-06-21 17:36:52 +02:00
2 changed files with 50 additions and 29 deletions
+6 -2
View File
@@ -1863,11 +1863,15 @@ Example events:
"data": {
"status": "loading",
"progress": {
"stage": "fit_params",
"value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value"
"stages": ["text_model", "spec_model", "mmproj_model"],
"current": "text_model",
"value": 0.5
}
}
}
// note for "loading" status:
// - subsequent events will follow the same order of "stages" list
// - mmap is may report incorrect progress on some platforms; if you need exact progress, use --no-mmap
{
"model": "...",
+44 -27
View File
@@ -962,6 +962,7 @@ private:
struct load_progress_data {
server_context_impl * ctx;
std::string stage;
std::vector<std::string> stages;
int64_t t_last_load_progress_ms = 0;
load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {}
};
@@ -982,7 +983,8 @@ private:
}
if (d->ctx->callback_state) {
d->ctx->callback_state(SERVER_STATE_LOADING, {
{"stage", d->stage},
{"stages", d->stages},
{"current", d->stage},
{"value", progress},
});
}
@@ -992,18 +994,42 @@ private:
// load the model and initialize llama_context
// this may also be called to resume from sleeping state
bool load_model(common_params & params) {
load_progress_data load_progress_text(this, "text_model");
load_progress_data load_progress_text (this, "text_model");
load_progress_data load_progress_mmproj(this, "mmproj_model");
load_progress_data load_progress_spec (this, "spec_model");
bool is_resume = sleeping;
SRV_INF("loading model '%s'\n", params.model.path.c_str());
const bool is_resume = sleeping;
params_base = params;
params_base.n_outputs_max = server_n_outputs_max(params_base);
const bool has_mmproj = !params.mmproj.path.empty();
const bool has_draft = params.speculative.has_dft();
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
const bool has_spec = has_draft || spec_mtp;
if (callback_state) {
std::vector<std::string> stages = {"text_model"};
if (has_spec) {
stages.push_back("spec_model");
}
if (has_mmproj) {
stages.push_back("mmproj_model");
}
load_progress_text.stages = stages;
load_progress_mmproj.stages = stages;
load_progress_spec.stages = stages;
// trigger 0% progress
load_progress_callback(0.0f, &load_progress_text);
}
SRV_INF("loading model '%s'\n", params.model.path.c_str());
std::string & mmproj_path = params_base.mmproj.path;
bool has_mmproj = !mmproj_path.empty();
mtmd_context_params mparams = mtmd_context_params_default();
if (has_mmproj) {
mparams.use_gpu = params_base.mmproj_use_gpu;
@@ -1050,16 +1076,7 @@ private:
// optionally reserve VRAM for the draft / MTP context before fitting the target model
if (params_base.fit_params) {
if (callback_state) {
callback_state(SERVER_STATE_LOADING, {{"stage", "fit_params"}});
}
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
const bool has_draft = params_base.speculative.has_dft();
if (has_draft || spec_mtp) {
if (has_spec) {
common_params params_dft = params_base;
bool measure_model_bytes = true;
@@ -1151,11 +1168,7 @@ private:
add_bos_token = llama_vocab_get_add_bos(vocab);
if (params_base.speculative.has_dft()) {
if (callback_state) {
callback_state(SERVER_STATE_LOADING, {{"stage", "spec_model"}});
}
if (has_draft) {
// TODO speculative: move to common/speculative.cpp?
const auto & params_spec = params_base.speculative.draft;
@@ -1178,6 +1191,10 @@ private:
auto mparams_dft = common_model_params_to_llama(params_dft);
// progress callback
mparams_dft.progress_callback = load_progress_callback;
mparams_dft.progress_callback_user_data = &load_progress_spec;
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
@@ -1186,10 +1203,6 @@ private:
auto cparams = common_context_params_to_llama(params_dft);
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
if (spec_mtp) {
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
}
@@ -1203,8 +1216,10 @@ private:
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
} else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) {
} else if (spec_mtp) {
// no new model load, so we simply report 0.0 and 1.0 progress
load_progress_callback(0.0f, &load_progress_spec);
SRV_INF("creating MTP draft context against the target model '%s'\n",
params_base.model.path.c_str());
@@ -1224,6 +1239,8 @@ private:
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
load_progress_callback(1.0f, &load_progress_spec);
}
if (has_mmproj) {