mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
mtmd: fit_params now take into account mmproj (#21489)
* mtmd: fit_params now take into account mmproj * rename alloc_compute_meta to reserve_compute_meta * rm unused functions * add ggml_backend_dev_t support * add debug log
This commit is contained in:
@@ -746,6 +746,46 @@ private:
|
||||
|
||||
params_base = params;
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
bool has_mmproj = !mmproj_path.empty();
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
if (has_mmproj) {
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
mparams.flash_attn_type = params_base.flash_attn_type;
|
||||
mparams.warmup = params_base.warmup;
|
||||
mparams.image_min_tokens = params_base.image_min_tokens;
|
||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||
mparams.media_marker = get_media_marker();
|
||||
}
|
||||
|
||||
// optionally get the memory usage of mmproj
|
||||
if (has_mmproj && params_base.fit_params) {
|
||||
auto mmproj_mem = mtmd_get_memory_usage(mmproj_path.c_str(), mparams);
|
||||
if (!mmproj_mem.empty()) {
|
||||
size_t total = 0;
|
||||
for (auto & [dev, size] : mmproj_mem) {
|
||||
total += size;
|
||||
}
|
||||
SRV_INF("[mtmd] estimated memory usage of mmproj is %.2f MiB\n", total / (1024.0 * 1024.0));
|
||||
GGML_ASSERT(!params_base.fit_params_target.empty());
|
||||
for (auto & [dev, size] : mmproj_mem) {
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||
if (ggml_backend_dev_get(i) == dev) {
|
||||
if (i < params_base.fit_params_target.size()) {
|
||||
SRV_DBG("[mtmd] adding %.2f MiB to fit_params_target for device %s\n", size / (1024.0 * 1024.0), ggml_backend_dev_name(dev));
|
||||
params_base.fit_params_target[i] += size;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
SRV_ERR("%s", "[mtmd] failed to get memory usage of mmproj\n");
|
||||
}
|
||||
}
|
||||
|
||||
llama_init = common_init_from_params(params_base);
|
||||
|
||||
model_tgt = llama_init->model();
|
||||
@@ -830,18 +870,10 @@ private:
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
}
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
if (!mmproj_path.empty()) {
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
mparams.flash_attn_type = params_base.flash_attn_type;
|
||||
mparams.warmup = params_base.warmup;
|
||||
mparams.image_min_tokens = params_base.image_min_tokens;
|
||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||
mparams.media_marker = get_media_marker();
|
||||
if (has_mmproj) {
|
||||
if (!is_resume) {
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
}
|
||||
|
||||
mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams);
|
||||
if (mctx == nullptr) {
|
||||
|
||||
Reference in New Issue
Block a user