Compare commits

..

12 Commits

Author SHA1 Message Date
o7si 32eddaf2ea cmake : fix ui build with read-only source (#24752) 2026-06-18 18:59:18 +02:00
Xuan-Son Nguyen 060ce1bf72 mtmd: refactor llava-uhd overview image handling (always use ov_img_first) (#24769)
* add dedicated "overview" for mtmd_image_preproc_out

* corrections

* correct (again)

* nits

* nits (2)
2026-06-18 18:53:49 +02:00
Max Krasnyansky d2c67959b3 hexagon: support for op-trace (fine-grain tracing of HVX/HMX/DMA events) (#24592)
* hex-optrace: add support for optrace and instrument matmul and flash-atten code

* hex-trace: improve trace event and prefetto generator

* hex-trace: add new script dedicated to handling traces, specifically perfetto traces

* hex-trace: add --head/--tail options to profile and trace tools

* hex-trace: fix whitespaces

* hex-trace: fix flake8 warnings

* hex-trace: fix flake8 warnings

* hmx-fa: restore q_tiles clearing

* hex-profile: remove circular dep in includes

* hex-trace: simplify trace sizing check

* hex-profile: sort events in the summary by name
2026-06-18 08:35:02 -07:00
Kangjia Gao 7b6c5a2aed docs: fix export-lora --lora-scaled syntax [no release] (#24703)
Assisted-by: Codex
2026-06-18 16:46:17 +02:00
Xuan-Son Nguyen fe7c8b2414 server: (router) fix stopping_thread potentially hang (#24728)
* server: (router) fix stopping_thread potentially hang

* fix windows build
2026-06-18 15:41:09 +02:00
Xuan-Son Nguyen e1efd0991d server: add "schema" and validation (#24150)
* wip

* working

* correct some limits

* add field name to error message
2026-06-18 15:40:58 +02:00
Aarni Koskela 08023072ef server : add last-5-seconds generation speed display (#24291)
* server : add last-5-seconds generation speed display

* cont : clean-up

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-18 14:02:20 +02:00
Amos Wong 20832179e2 ui: provide touch accessible model selection UI (#24604)
* ui : add model selector storybook stories

Covers list, favorites, single-model, all status states
(loading/loaded/sleeping/failed/idle), and selection states.

* ui : improve model selector mobile UX with hover media queries

Use @media (hover:none) to show action buttons directly on touch
devices and color-code them by model status (amber=sleeping,
green=loaded, muted=idle). Status dots hidden on touch. Desktop
hover behavior unchanged.
2026-06-18 13:14:20 +02:00
Anuj Attri 10786217e9 server : return HTTP 400 on invalid grammar (#24144) (#24154)
Throw on grammar parse failure so the server returns HTTP 400
instead of silently dropping the constraint.
Add a regression test for the invalid-grammar response.

Fixes #24144
2026-06-18 12:49:14 +02:00
Xuan-Son Nguyen 552258c535 server: (router) rework -hf preset repo (#24739)
* server: temporary remove HF remote preset

* rework remove preset.ini support

* rm unused get_remote_preset_whitelist()

* print warning

* add docs

* rm stray file
2026-06-18 12:45:23 +02:00
Xuan-Son Nguyen 968c43891a server: fix router args not being forwarded to child instances (#24760) 2026-06-18 12:15:46 +02:00
Xuan-Son Nguyen 24bba7b98e mtmd: refactor preprocessor, add mtmd_image_preproc_out (#24736)
* add mtmd_image_preproc_out

* add dev docs

* remove unused clip API

* rm unused clip_image_f32_batch::grid

* change preprocess() call signature
2026-06-18 12:04:39 +02:00
47 changed files with 2750 additions and 1299 deletions
+34 -78
View File
@@ -285,58 +285,15 @@ static std::string clean_file_name(const std::string & fname) {
return clean_fname;
}
static bool common_params_handle_remote_preset(common_params & params, llama_example ex) {
GGML_ASSERT(!params.model.hf_repo.empty());
// the returned hf_repo is without tag
auto [hf_repo, hf_tag] = common_download_split_repo_tag(params.model.hf_repo);
// "latest" tag (default if not specified) is translated to "default" preset
if (hf_tag == "latest") {
hf_tag = "default";
}
std::string model_endpoint = common_get_model_endpoint();
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
// prepare local path for caching
auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
auto preset_path = fs_get_cache_file(preset_fname);
common_download_opts opts;
opts.bearer_token = params.hf_token;
opts.offline = params.offline;
LOG_TRC("%s: looking for remote preset at %s\n", __func__, preset_url.c_str());
const int status = common_download_file_single(preset_url, preset_path, opts);
const bool has_preset = status >= 200 && status < 400;
// remote preset is optional, so we don't error out if not found
if (has_preset) {
LOG_TRC("%s: applying remote preset from %s\n", __func__, preset_url.c_str());
common_preset_context ctx(ex, /* only_remote_allowed */ true);
common_preset global;
auto remote_presets = ctx.load_from_ini(preset_path, global);
remote_presets = ctx.cascade(global, remote_presets);
if (remote_presets.find(hf_tag) != remote_presets.end()) {
common_preset preset = remote_presets.at(hf_tag);
LOG_INF("\n%s", preset.to_ini().c_str()); // to_ini already added trailing newline
preset.apply_to_params(params);
} else {
throw std::runtime_error("Remote preset.ini does not contain [" + std::string(hf_tag) + "] section");
}
} else {
LOG_TRC("%s: no remote preset found, skipping\n", __func__);
}
return has_preset;
}
struct handle_model_result {
bool found_mmproj = false;
common_params_model mmproj;
bool found_mtp = false;
common_params_model mtp;
bool found_preset = false;
std::string preset_path;
};
static handle_model_result common_params_handle_model(struct common_params_model & model,
@@ -355,6 +312,12 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts hf_opts = opts;
auto download_result = common_download_model(model, hf_opts);
if (!download_result.preset_path.empty()) {
result.found_preset = true;
result.preset_path = download_result.preset_path;
return result; // skip everything else if preset.ini is used
}
if (download_result.model_path.empty()) {
throw std::runtime_error("failed to download model from Hugging Face");
}
@@ -454,6 +417,17 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
try {
auto res = common_params_handle_model(params.model, opts);
if (res.found_preset) {
if (!params.models_preset.empty()) {
throw std::invalid_argument("cannot use both --models-preset and -hf with a preset.ini file");
}
// if HF repo is a preset repo, we simply run server in router mode with the preset.ini file
params.models_preset_hf = params.model.hf_repo; // only for showing a warning
params.models_preset = res.preset_path;
params.model = common_params_model{}; // make sure to clear model, so server starts in router mode
return true;
}
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -601,30 +575,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// parse the first time to get -hf option (used for remote preset)
parse_cli_args();
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
// maybe handle remote preset
if (!params.model.hf_repo.empty() && !skip_model_download) {
std::string cli_hf_repo = params.model.hf_repo;
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
// special case: if hf_repo explicitly set by preset, we need to preserve it (ignore CLI value)
// this is useful when we have one HF repo pointing to other HF repos (one model - multiple GGUFs)
std::string preset_hf_repo = params.model.hf_repo;
bool preset_has_hf_repo = preset_hf_repo != cli_hf_repo;
if (has_preset) {
// re-parse CLI args to override preset values
parse_cli_args();
}
// preserve hf_repo from preset if needed
if (preset_has_hf_repo) {
params.model.hf_repo = preset_hf_repo;
}
}
postprocess_cpu_params(params.cpuparams, nullptr);
postprocess_cpu_params(params.cpuparams_batch, &params.cpuparams);
@@ -635,15 +585,21 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}
// handle model and download
if (!skip_model_download) {
common_params_handle_models(params, ctx_arg.ex);
}
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
throw std::invalid_argument("error: --model is required\n");
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex);
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty()
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
&& !params.usage
&& !params.completion) {
throw std::invalid_argument("error: --model is required\n");
}
}
if (params.escape) {
+5 -4
View File
@@ -642,10 +642,11 @@ struct common_params {
std::vector<std::string> server_tools;
// router server configs
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
int models_max = 4; // maximum number of models to load simultaneously
bool models_autoload = true; // automatically load models when requested via the router server
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
int models_max = 4; // maximum number of models to load simultaneously
bool models_autoload = true; // automatically load models when requested via the router server
std::string models_preset_hf = ""; // show a warning about remote presets on router loaded (if not empty)
bool log_json = false;
+36 -17
View File
@@ -696,6 +696,7 @@ struct hf_plan {
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
hf_cache::hf_file mtp;
hf_cache::hf_file preset; // if set, only this file is downloaded
};
static hf_plan get_hf_plan(const common_params_model & model,
@@ -717,6 +718,14 @@ static hf_plan get_hf_plan(const common_params_model & model,
return plan;
}
// if preset.ini exists in the repo root, download only that file
for (const auto & f : all) {
if (f.path == "preset.ini") {
plan.preset = f;
return plan;
}
}
hf_cache::hf_file primary;
if (!model.hf_file.empty()) {
@@ -794,14 +803,19 @@ common_download_model_result common_download_model(const common_params_model &
if (is_hf) {
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else {
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
}
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
@@ -835,17 +849,22 @@ common_download_model_result common_download_model(const common_params_model &
}
if (is_hf) {
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.primary.final_path;
if (!hf.preset.path.empty()) {
// if preset.ini is used, do not set other paths
result.preset_path = hf_cache::finalize_file(hf.preset);
} else {
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.primary.final_path;
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
}
}
} else {
result.model_path = model.path;
+1
View File
@@ -63,6 +63,7 @@ struct common_download_model_result {
std::string model_path;
std::string mmproj_path;
std::string mtp_path;
std::string preset_path;
};
// throw if the file is missing or invalid (e.g. ETag check failed)
+1 -49
View File
@@ -16,48 +16,6 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos);
}
// only allow a subset of args for remote presets for security reasons
// do not add more args unless absolutely necessary
// args that output to files are strictly prohibited
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
static const std::set<std::string> allowed_options = {
"model-url",
"hf-repo",
"hf-repo-draft",
"hf-repo-v", // vocoder
"hf-file-v", // vocoder
"mmproj-url",
"pooling",
"jinja",
"batch-size",
"ubatch-size",
"cache-reuse",
"chat-template-kwargs",
"mmap",
// note: sampling params are automatically allowed by default
// negated args will be added automatically if the positive arg is specified above
};
std::set<std::string> allowed_keys;
for (const auto & it : key_to_opt) {
const std::string & key = it.first;
const common_arg & opt = it.second;
if (allowed_options.find(key) != allowed_options.end() || opt.is_sampling) {
allowed_keys.insert(key);
// also add variant keys (args without leading dashes and env vars)
for (const auto & arg : opt.get_args()) {
allowed_keys.insert(rm_leading_dashes(arg));
}
for (const auto & env : opt.get_env()) {
allowed_keys.insert(env);
}
}
}
return allowed_keys;
}
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args;
@@ -300,16 +258,10 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
return value;
}
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
common_preset_context::common_preset_context(llama_example ex)
: ctx_params(common_params_parser_init(default_params, ex)) {
common_params_add_preset_options(ctx_params.options);
key_to_opt = get_map_key_opt(ctx_params);
// setup allowed keys if only_remote_allowed is true
if (only_remote_allowed) {
filter_allowed_keys = true;
allowed_keys = get_remote_preset_whitelist(key_to_opt);
}
}
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
+1 -1
View File
@@ -60,7 +60,7 @@ struct common_preset_context {
std::set<std::string> allowed_keys;
// if only_remote_allowed is true, only accept whitelisted keys
common_preset_context(llama_example ex, bool only_remote_allowed = false);
common_preset_context(llama_example ex);
// load presets from INI file
common_presets load_from_ini(const std::string & path, common_preset & global) const;
+3
View File
@@ -259,6 +259,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
}
if (!grmr && !grammar_str.empty()) {
throw std::runtime_error("failed to parse grammar");
}
// Compute prefill tokens from the generation prompt
std::vector<llama_token> prefill_tokens;
+3 -2
View File
@@ -1,10 +1,11 @@
# Multimodal
llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools support this feature:
- [llama-mtmd-cli](../tools/mtmd/README.md)
- [llama-cli](../tools/cli/README.md)
- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API
- [llama-mtmd-cli](../tools/mtmd/README.md), for testing and development
Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality.
Currently, we support **image**, **audio** and **video** input.
To enable it, you can use one of the 2 methods below:
+36 -38
View File
@@ -8,55 +8,53 @@ The INI preset feature, introduced in [PR#17859](https://github.com/ggml-org/lla
When running multiple models on the server (router mode), INI preset files can be used to configure model-specific parameters. Please refer to the [server documentation](../tools/server/README.md) for more details.
### Using a Remote Preset
### Using a Hugging Face Preset
> [!NOTE]
> [!IMPORTANT]
>
> This feature is currently only supported via the `-hf` option.
> Please only use presets that you can trust! Unknown presets may be unsafe
For GGUF models hosted on Hugging Face, you can include a `preset.ini` file in the root directory of the repository to define specific configurations for that model.
You can push your preset to Hugging Face Hub and share with other users by:
1. Creating an empty model repository on Hugging Face
2. Creating a `preset.ini` file in the root directory of the repository
Example:
Example of a `preset.ini`:
```ini
hf-repo-draft = username/my-draft-model-GGUF
temp = 0.5
top-k = 20
top-p = 0.95
[*]
ctx-size = 0
mmap = 1
kv-unified = 1
parallel = 4
spec-default = 1
[Qwen3.5-4B]
hf = unsloth/Qwen3.5-4B-GGUF:Q4_K_M
ctx-size = 262144
batch-size = 2048
ubatch-size = 2048
top-p = 1.0
top-k = 0
min-p = 0.01
temp = 1.0
[gpt-oss-120b-hf]
hf = ggml-org/gpt-oss-120b-GGUF
ctx-size = 262144
batch-size = 2048
ubatch-size = 2048
top-p = 1.0
top-k = 0
min-p = 0.01
temp = 1.0
chat-template-kwargs = {"reasoning_effort": "high"}
```
For security reasons, only certain options are allowed. Please refer to [preset.cpp](../common/preset.cpp) for the complete list of permitted options.
Example usage:
Assuming your repository `username/my-model-with-preset` contains a `preset.ini` with the configuration above:
```sh
llama-cli -hf username/my-model-with-preset
# This is equivalent to:
llama-cli -hf username/my-model-with-preset \
--hf-repo-draft username/my-draft-model-GGUF \
--temp 0.5 \
--top-k 20 \
--top-p 0.95
```
You can also override preset arguments by specifying them on the command line:
The preset will be loaded similarly to the `--models-preset` option. Therefore, you can also override certain params via CLI arguments:
```sh
# Force temp = 0.1, overriding the preset value
llama-cli -hf username/my-model-with-preset --temp 0.1
```
If you want to define multiple preset configurations for one or more GGUF models, you can create a blank HF repo for each preset. Each HF repo should contain a `preset.ini` file that references the actual model(s):
```ini
hf-repo = user/my-model-main
hf-repo-draft = user/my-model-draft
temp = 0.8
ctx-size = 1024
; (and other configurations)
llama-cli -hf username/my-preset --temp 0.1
```
### Named presets
+104 -12
View File
@@ -69,6 +69,7 @@ static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE;
static int opt_opbatch = 1024; // max number of ops in a batch
static int opt_opqueue = 16; // max number of pending batches
static int opt_oppoll = 0; // polling for batch completions
static int opt_optrace = 0; // trace buffer size per thread (0 means default)
static std::regex* opt_opfilter = NULL; // regex of ops to not claim
@@ -118,20 +119,39 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct
ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no");
}
static const char * htp_event_name(uint16_t id) {
switch (id) {
case HTP_TRACE_EVT_DMA: return "DMA";
case HTP_TRACE_EVT_HVX_COMP: return "HVX_COMP";
case HTP_TRACE_EVT_HVX_A_QUANT: return "HVX_A_QUANT";
case HTP_TRACE_EVT_HVX_A_PREP: return "HVX_A_PREP";
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
default: return "UNKNOWN";
}
}
static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_opnode & node,
uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) {
const htp_prof_desc & pd) {
if (!opt_profile) return;
uint32_t op_usec = pd.usecs;
uint32_t op_cycles = pd.cycles_stop - pd.cycles_start;
const uint32_t * pmu = pd.pmu;
char pmu_str[256] = "";
if (opt_profile > 1) {
if (opt_profile == 2) {
static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters");
sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]",
pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]);
}
htp_opformat fmt(node);
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(),
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pmu_str);
float mhz = op_usec > 0 ? (float) op_cycles / op_usec : 0.0f;
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(),
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str);
}
// ** backend sessions
@@ -1995,10 +2015,16 @@ struct ggml_hexagon_opqueue {
size_t n_ops = batch_size;
size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS;
size_t tr_size = 0;
if (opt_profile == 3) {
tr_size = (HTP_MAX_NTHREADS + 1) * opt_optrace * sizeof(htp_trace_desc);
}
shm_blk_size = sizeof(htp_buf_desc) * n_bufs +
sizeof(htp_tensor) * n_tensors +
sizeof(htp_op_desc) * n_ops +
sizeof(htp_prof_desc) * n_ops;
sizeof(htp_prof_desc) * n_ops +
tr_size;
shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */);
@@ -2042,11 +2068,19 @@ struct ggml_hexagon_opqueue {
const size_t o_size = sizeof(htp_op_desc) * req.n_ops;
const size_t p_size = sizeof(htp_prof_desc) * req.n_ops;
size_t tr_size = 0;
if (opt_profile == 3) {
req.n_traces = opt_optrace;
tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(htp_trace_desc);
} else {
req.n_traces = 0;
}
dbuf.ptr = shm_buf->base + (req.id * shm_blk_size);
dbuf.fd = shm_buf->fd;
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base;
dbuf.size = b_size + t_size + o_size + p_size;
dbuf.size = b_size + t_size + o_size + p_size + tr_size;
GGML_ASSERT(dbuf.size <= shm_blk_size);
@@ -2092,7 +2126,14 @@ struct ggml_hexagon_opqueue {
const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops;
const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops;
const size_t m_size = b_size + t_size + o_size + p_size;
size_t tr_size = 0;
uint32_t n_traces = 0;
if (opt_profile == 3) {
n_traces = opt_optrace;
tr_size = (HTP_MAX_NTHREADS + 1) * n_traces * sizeof(htp_trace_desc);
}
const size_t m_size = b_size + t_size + o_size + p_size + tr_size;
GGML_ASSERT(m_size <= shm_blk_size);
HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n",
@@ -2111,13 +2152,62 @@ struct ggml_hexagon_opqueue {
GGML_ASSERT(rsp.n_ops <= ops.size());
const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr;
for (uint32_t i = 0; i < rsp.n_ops; i++) {
htp_usec += pd[i].usecs;
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu);
const htp_trace_desc * trace_events = nullptr;
if (opt_profile == 3) {
trace_events = (const htp_trace_desc *) (p_ptr + p_size);
}
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n",
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec);
uint32_t trace_idx[HTP_MAX_NTHREADS + 1] = {0};
uint32_t valid_cnt[HTP_MAX_NTHREADS + 1] = {0};
if (opt_profile == 3) {
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
uint32_t count = rsp.n_traces[t];
valid_cnt[t] = count > n_traces ? n_traces : count;
}
}
for (uint32_t i = 0; i < rsp.n_ops; i++) {
htp_usec += pd[i].usecs;
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i]);
if (opt_profile == 3) {
uint32_t op_duration = pd[i].cycles_stop - pd[i].cycles_start;
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
while (trace_idx[t] < valid_cnt[t]) {
const auto & e = trace_events[t * n_traces + trace_idx[t]];
uint32_t offset = e.cycles - pd[i].cycles_start;
if (offset >= 0x80000000) {
trace_idx[t]++;
continue;
}
if (offset > op_duration) {
break;
}
bool is_stop = (e.info & 0x8000) != 0;
uint16_t info = e.info & 0x7FFF;
GGML_LOG_DEBUG("ggml-hex: %s trace-op %s: thread %u event %s info %u %s %u\n",
shm_buf->sess->c_name(), ops[i].op_name().c_str(), t, htp_event_name(e.id), info, is_stop ? "stop" : "start", e.cycles);
trace_idx[t]++;
}
}
}
}
char evt_str[256] = "";
if (opt_profile == 3) {
sprintf(evt_str, " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]",
rsp.n_traces[0], rsp.n_traces[1], rsp.n_traces[2], rsp.n_traces[3],
rsp.n_traces[4], rsp.n_traces[5], rsp.n_traces[6], rsp.n_traces[7],
rsp.n_traces[8], rsp.n_traces[9], rsp.n_traces[10]);
}
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u%s\n",
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec, evt_str);
}
}
};
@@ -3901,6 +3991,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL");
const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE");
const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER");
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
const char * str_etm = getenv("GGML_HEXAGON_ETM");
@@ -3939,6 +4030,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll;
opt_optrace = str_optrace ? strtoul(str_optrace, NULL, 0) : (opt_opbatch * 128);
opt_profile = str_profile ? atoi(str_profile) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0;
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
+1 -1
View File
@@ -37,8 +37,8 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-matmul-ops.c
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
)
@@ -339,6 +339,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
if (ir0 >= ir1) return;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
dma_queue * dma = octx->ctx->dma[ith];
const uint32_t DK = nek0;
@@ -615,6 +618,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
+10 -3
View File
@@ -6,6 +6,8 @@
#include <stdbool.h>
#include <stdint.h>
#include "hex-profile.h"
#ifdef __cplusplus
extern "C" {
#endif
@@ -88,6 +90,7 @@ typedef struct {
uint32_t pop_idx;
uint32_t capacity;
uint32_t idx_mask;
struct htp_thread_trace * trace;
} dma_queue;
dma_queue * dma_queue_create(size_t capacity);
@@ -152,6 +155,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
q->dptr[q->push_idx] = dptr;
if (size) {
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
} else {
@@ -202,6 +206,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
q->dptr[q->push_idx] = dptr;
if (nrows) {
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = desc;
} else {
@@ -223,10 +228,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
// Wait for desc to complete
while (!desc->done) {
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
dmpoll();
if (!desc->done) {
while (!desc->done) {
dmpoll();
}
}
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_DMA, q->pop_idx);
dptr = q->dptr[q->pop_idx];
+64
View File
@@ -0,0 +1,64 @@
#ifndef HEX_PROFILE_H
#define HEX_PROFILE_H
#include <stdbool.h>
#include <stdint.h>
#include <qurt.h>
#include "hex-utils.h"
#include "htp-ops.h"
#define HTP_TRACE_EVT_START 0
#define HTP_TRACE_EVT_STOP 1
#ifndef HEX_NUM_PMU_COUNTERS
#define HEX_NUM_PMU_COUNTERS 8
#endif
static inline void hex_get_pmu(uint32_t counters[]) {
#if __HVX_ARCH__ >= 79
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
#else
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
#endif
}
struct htp_thread_trace {
uint32_t count;
uint32_t max_events;
struct htp_trace_desc * events;
};
static inline void htp_trace_event(struct htp_thread_trace * tr, uint16_t id, uint16_t info, uint32_t type) {
if (tr && tr->events && tr->count < tr->max_events) {
uint32_t idx = tr->count;
tr->events[idx].id = id;
tr->events[idx].info = info | (type == HTP_TRACE_EVT_STOP ? 0x8000 : 0);
tr->events[idx].cycles = (uint32_t) hex_get_cycles();
tr->count++;
}
}
static inline void htp_trace_event_start(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
htp_trace_event(tr, id, info, HTP_TRACE_EVT_START);
}
static inline void htp_trace_event_stop(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
htp_trace_event(tr, id, info, HTP_TRACE_EVT_STOP);
}
#endif /* HEX_PROFILE_H */
-27
View File
@@ -107,31 +107,4 @@ static inline void hex_pause() {
asm volatile(" pause(#255)\n");
}
#ifndef HEX_NUM_PMU_COUNTERS
#define HEX_NUM_PMU_COUNTERS 8
#endif
static inline void hex_get_pmu(uint32_t counters[]) {
#if __HVX_ARCH__ >= 79
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
#else
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
// qurt_pmu_get_pmucnt(counters);
#endif
}
#endif /* HEX_UTILS_H */
+26 -66
View File
@@ -18,7 +18,7 @@
#include "ggml-common.h"
#include "hex-dma.h"
#include "hex-fastdiv.h"
#include "hmx-profile.h"
#include "hex-profile.h"
#include "hmx-queue.h"
#include "hmx-utils.h"
#include "htp-ctx.h"
@@ -367,8 +367,11 @@ static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data)
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK,
(int) args->src_stride, start, end);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) {
@@ -408,8 +411,11 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV,
(int) args->src_stride, (int) args->n_col_tiles, start, end);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_v_interleave(struct hmx_fa_context * factx,
@@ -462,6 +468,9 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
const struct htp_tensor * q = args->q;
const uint32_t q_start = args->q_start;
const uint32_t kv_head = args->kv_head;
@@ -515,6 +524,7 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
}
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_q_load(struct hmx_fa_context * factx,
@@ -566,6 +576,9 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
const struct htp_tensor * dst = args->dst;
const __fp16 * o_tile_src = args->o_tile_src;
const uint32_t q_start = args->q_start;
@@ -611,6 +624,7 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
}
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_o_store(struct hmx_fa_context * factx,
@@ -680,6 +694,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
// Per-thread row scratch: thread i uses bufs at offset i * 2 * stride
const size_t row_buf_stride = factx->row_buf_stride;
HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride;
@@ -950,6 +967,7 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v;
factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v;
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
}
// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads).
@@ -1245,6 +1263,7 @@ static __attribute__((noinline)) void fa_compute_slopes(
// ============================================================================
int hmx_flash_attn_ext(struct htp_ops_context * octx) {
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[HTP_MAX_NTHREADS] : NULL;
const struct htp_tensor * q = octx->src[0];
const struct htp_tensor * k = octx->src[1];
const struct htp_tensor * v = octx->src[2];
@@ -1422,19 +1441,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
// Profiling timers
TIMER_DEFINE(total);
TIMER_DEFINE(q_load);
TIMER_DEFINE(kv_dma);
TIMER_DEFINE(k_interleave);
TIMER_DEFINE(v_interleave);
TIMER_DEFINE(qk_dot);
TIMER_DEFINE(softmax);
TIMER_DEFINE(o_update);
TIMER_DEFINE(o_norm);
TIMER_DEFINE(o_store);
TIMER_START(total);
// ======== DMA setup ========
dma_queue * const dma = ctx->dma[0];
@@ -1474,12 +1480,10 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS;
// ---- Load Q block [g_br, D] -> tiles, interleaving G heads ----
TIMER_START(q_load);
if (n_rows_g < g_br) {
hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes);
}
fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g);
TIMER_STOP(q_load);
// ---- Initialize per-block state ----
hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes);
@@ -1558,10 +1562,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
// Wait for current KV DMA
TIMER_START(kv_dma);
dma_queue_pop(dma); // K
dma_queue_pop(dma); // V
TIMER_STOP(kv_dma);
// Push mask DMA for this block (single 2D DMA when broadcast)
bool has_mask_dma = false;
@@ -1583,10 +1585,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
ou_job.DV = DV;
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
}
TIMER_START(k_interleave);
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
TIMER_STOP(k_interleave);
// ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ----
qk_job.q_tiles = factx.vtcm_q_tiles;
@@ -1597,15 +1596,11 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
qk_job.n_dot_tiles = DK / 32;
qk_job.n_tiles_per_bc = n_tiles_per_bc;
qk_job.hmx_scales = factx.vtcm_hmx_scales_qk;
TIMER_START(qk_dot);
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job));
// DMA push next block (non-blocking, before worker_pool)
DMA_PREFETCH_KV(kv_blk + 1);
TIMER_START(v_interleave);
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
TIMER_STOP(v_interleave);
// Pop and swap previous block's output update (deferred HMX pop)
if (kv_blk > 0) {
@@ -1615,7 +1610,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
// Pop current block's dot product job
hmx_queue_pop(hmx_q);
TIMER_STOP(qk_dot);
// ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ----
// Pop mask DMA before softmax (ensures VTCM buffer is ready)
@@ -1641,10 +1635,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
TIMER_STOP(softmax);
buf_idx = 1 - buf_idx;
} // end KV block loop (pipeline)
@@ -1664,11 +1655,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
ou_job.n_row_tiles_g_br = n_row_tiles_g_br;
ou_job.n_tiles_per_bc = n_tiles_per_bc;
ou_job.DV = DV;
TIMER_START(o_update);
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
hmx_queue_pop(hmx_q);
TIMER_STOP(o_update);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
@@ -1683,23 +1671,14 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const uint32_t kv_start = kv_blk * Bc;
const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start);
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
TIMER_START(kv_dma);
dma_queue_pop(dma); // K
dma_queue_pop(dma); // V
TIMER_STOP(kv_dma);
bool has_mask_dma = false;
MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma);
DMA_PREFETCH_KV(kv_blk + 1);
// K interleave (multi-thread HVX)
TIMER_START(k_interleave);
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
TIMER_STOP(k_interleave);
// QK dot (inline HMX on main thread)
TIMER_START(qk_dot);
{
const size_t n_dot_tiles = (size_t) (DK / 32);
const __fp16 * restrict q_base = factx.vtcm_q_tiles;
@@ -1709,6 +1688,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < n_col_tiles; ++c) {
@@ -1724,8 +1704,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
Q6_mxmem_AR_after_hf(out_tile, 0);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
TIMER_STOP(qk_dot);
// Pop mask DMA
MASK_DMA_POP(has_mask_dma);
@@ -1751,21 +1731,9 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
TIMER_STOP(softmax);
// V interleave (multi-thread HVX)
TIMER_START(v_interleave);
// FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout
// stride to match o_update's v_tile access. Using per-block n_col_tiles
// misplaces DV_tile 1..3 in the last partial KV block.
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
TIMER_STOP(v_interleave);
// O update (inline HMX on main thread)
TIMER_START(o_update);
{
const size_t DV_tiles = (size_t) (DV / 32);
const __fp16 * restrict d_base = factx.vtcm_d_tiles;
@@ -1777,6 +1745,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
__builtin_assume(n_col_tiles > 0);
__builtin_assume(DV_tiles > 0);
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
@@ -1798,16 +1767,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
Q6_mxmem_AR_after_hf(o_tile_out, 0);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
TIMER_STOP(o_update);
buf_idx = 1 - buf_idx;
} // end KV block loop (fallback)
}
// ---- Final normalization: O = diag(1/l) @ O ----
TIMER_START(o_norm);
{
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
@@ -1830,6 +1798,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
__builtin_assume(n_row_tiles > 0);
__builtin_assume(DV_tiles > 0);
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
@@ -1842,14 +1811,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
Q6_mxmem_AR_after_hf(o_out, 0);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
}
TIMER_STOP(o_norm);
// ---- Store O block ----
TIMER_START(o_store);
fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g);
TIMER_STOP(o_store);
#undef MASK_DMA_PUSH
#undef MASK_DMA_POP
@@ -1865,14 +1832,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
}
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total),
TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave));
FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax),
TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store));
#endif
return HTP_STATUS_OK;
}
+55 -41
View File
@@ -27,7 +27,7 @@
#include "hmx-ops.h"
#include "hmx-utils.h"
#include "hmx-queue.h"
#include "hmx-profile.h"
#include "hex-profile.h"
#include "vtcm-utils.h"
@@ -430,6 +430,7 @@ typedef struct {
int n_tasks;
int n_k_tiles;
struct fastdiv_values n_k_tiles_div;
struct htp_thread_trace * traces;
} x4x2_dequantize_state_t;
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
@@ -533,11 +534,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(
\
static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \
int start = task_id * state->n_tiles_per_task; \
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \
dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \
} \
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
}
DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
@@ -657,11 +661,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(
static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
@@ -717,11 +724,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
static void convert_f16_weight_to_fp16_tiles_task(
@@ -773,11 +783,14 @@ static void convert_f16_weight_to_fp16_tiles_task(
static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
convert_f16_weight_to_fp16_tiles_task(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
static void quantize_f32_weight_to_fp16_tiles_task(
@@ -833,11 +846,14 @@ static void quantize_f32_weight_to_fp16_tiles_task(
static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
quantize_f32_weight_to_fp16_tiles_task(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
@@ -868,6 +884,7 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
state.weight_type = weight_type;
state.n_k_tiles = n_k_tiles;
state.n_k_tiles_div = n_k_tiles_div;
state.traces = ctx ? ctx->trace : NULL;
if (state.n_tasks == 1 || n_threads == 1) {
dequant_worker_fn(1, 0, &state);
@@ -985,10 +1002,13 @@ typedef struct {
int n_chunks_per_task;
int n_cols;
int n; // DDR row stride (total output columns)
struct htp_thread_trace * traces;
} output_transfer_task_state_t;
static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
output_transfer_task_state_t *st = (output_transfer_task_state_t *) data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
int chunk_idx = task_id * st->n_chunks_per_task;
@@ -998,6 +1018,7 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void
const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols;
transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
}
static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src,
@@ -1015,6 +1036,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
state.vtcm_src = vtcm_src;
state.n_cols = n_cols;
state.n = n;
state.traces = ctx ? ctx->trace : NULL;
if (state.n_tasks == 1 || n_threads == 1) {
transfer_output_chunk_worker_fn(1, 0, &state);
@@ -1086,10 +1108,13 @@ typedef struct {
int n_chunks_per_task;
int k_block;
int k_stride;
struct htp_thread_trace * traces;
} activation_transfer_task_state_t;
static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
// one chunk: one row
@@ -1100,6 +1125,7 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i,
const float *src = st->src + chunk_idx * st->k_stride;
transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
}
static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) {
@@ -1117,6 +1143,7 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *
state.src = src;
state.k_block = k_block;
state.k_stride = k_stride;
state.traces = ctx ? ctx->trace : NULL;
if (state.n_tasks == 1 || n_threads == 1) {
transfer_activation_chunk_worker_fn(1, 0, &state);
@@ -1245,13 +1272,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu",
m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget);
TIMER_DEFINE(activation_load);
TIMER_DEFINE(weight_load);
TIMER_DEFINE(hmx_core);
TIMER_DEFINE(output_store);
TIMER_DEFINE(total);
TIMER_START(total);
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
@@ -1370,7 +1391,12 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
// C: HMX Compute (Synchronous)
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
{
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
// D: Output Store
float *output_chunk = dst + (mr * n + nc);
@@ -1380,18 +1406,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
}
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n);
if (!use_pipeline) {
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
size_t weight_size = (size_t)n * row_stride;
float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load);
FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth);
}
#endif
return 0;
}
@@ -1523,13 +1538,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
m_chunk_n_rows, n_chunk_n_cols,
(size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
TIMER_DEFINE(activation_load);
TIMER_DEFINE(weight_load);
TIMER_DEFINE(hmx_core);
TIMER_DEFINE(output_store);
TIMER_DEFINE(total);
TIMER_START(total);
const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16);
const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16);
@@ -1549,7 +1558,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
// contiguous rows into a VTCM scratch buffer first, then HVX
// converts from the contiguous VTCM buffer. This avoids L2 cache
// thrashing from HVX loads at large strides.
TIMER_START(activation_load);
for (int g = 0; g < group_size; ++g) {
const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride;
__fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
@@ -1569,7 +1577,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
params->k, params->act_stride, ctx->n_threads);
}
}
TIMER_STOP(activation_load);
void *buf_curr = vtcm_scratch0;
void *buf_next = vtcm_scratch1;
@@ -1584,7 +1591,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols);
const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
TIMER_START(weight_load);
{
dma_queue_pop(ctx->dma[0]);
@@ -1601,24 +1607,22 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
0, n_cols);
hex_swap_ptr(&buf_curr, &buf_next);
}
TIMER_STOP(weight_load);
// Reuse the interleaved weight for every q_head in this GQA group
for (int g = 0; g < group_size; ++g) {
TIMER_START(hmx_core);
{
const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles,
params->k / 32);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
TIMER_STOP(hmx_core);
TIMER_START(output_store);
{
float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc;
transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads);
}
TIMER_STOP(output_store);
}
}
}
@@ -1627,14 +1631,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total),
params->m, params->k, params->n, group_size);
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
#endif
return 0;
}
@@ -1668,6 +1665,7 @@ typedef struct {
size_t nb12;
int start_row;
int cne1;
struct htp_thread_trace *traces;
} activation_transfer_gathered_task_state_t;
typedef struct {
@@ -1684,6 +1682,7 @@ typedef struct {
size_t dst_nb2;
int start_row;
int cne1;
struct htp_thread_trace *traces;
} output_transfer_scattered_task_state_t;
static void transfer_activation_chunk_fp32_to_fp16_gathered(
@@ -1780,6 +1779,9 @@ static void transfer_activation_chunk_fp32_to_fp16_gathered(
static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) {
activation_transfer_gathered_task_state_t *st = data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
int chunk_idx = i;
int chunk_size = st->n_chunks_per_task;
int start_row = st->start_row + chunk_idx * chunk_size;
@@ -1791,6 +1793,7 @@ static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigne
st->matrix_rows, st->cur_a, st->mapping_stride,
st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
}
static void transfer_activation_chunk_gathered_threaded(
@@ -1830,6 +1833,7 @@ static void transfer_activation_chunk_gathered_threaded(
.nb12 = nb12,
.start_row = start_row,
.cne1 = cne1,
.traces = ctx ? ctx->trace : NULL,
};
if (actual_threads <= 1) {
@@ -1895,6 +1899,9 @@ static void transfer_output_chunk_fp16_to_fp32_scattered(
static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) {
output_transfer_scattered_task_state_t *st = data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
int chunk_idx = i;
int chunk_size = st->n_chunks_per_task;
int start_row = st->start_row + chunk_idx * chunk_size;
@@ -1906,6 +1913,7 @@ static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned i
st->matrix_rows, st->cur_a, st->mapping_stride,
st->dst_nb1, st->dst_nb2, st->cne1);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
}
static void transfer_output_chunk_scattered_threaded(
@@ -1942,6 +1950,7 @@ static void transfer_output_chunk_scattered_threaded(
.dst_nb2 = dst_nb2,
.start_row = start_row,
.cne1 = cne1,
.traces = ctx ? ctx->trace : NULL,
};
if (actual_threads <= 1) {
@@ -2053,7 +2062,12 @@ int hmx_matmul_id_2d_f32(struct htp_context *ctx,
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
{
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
transfer_output_chunk_scattered_threaded(
ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols,
-34
View File
@@ -1,34 +0,0 @@
// Conditional fine-grained profiling macros for HMX operations.
//
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
// header) to instrument sub-operation latencies with HAP qtimer. When the
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
// overhead.
//
// Usage:
// TIMER_DEFINE(my_phase); // declare accumulator variable
// TIMER_START(my_phase); // snapshot start time
// ... work ...
// TIMER_STOP(my_phase); // accumulate elapsed ticks
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
#ifndef HMX_PROFILE_H
#define HMX_PROFILE_H
#include <HAP_perf.h>
// #define ENABLE_PROFILE_TIMERS
#if defined(ENABLE_PROFILE_TIMERS)
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
#else
# define TIMER_DEFINE(name)
# define TIMER_START(name)
# define TIMER_STOP(name)
# define TIMER_US(name) 0LL
#endif
#endif // HMX_PROFILE_H
+2
View File
@@ -44,7 +44,9 @@ static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;
default:
hmx_lock(q);
htp_trace_event_start(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
d->func(d->data);
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
break;
}
+2
View File
@@ -11,6 +11,7 @@
#include <HAP_farf.h>
#include "hex-utils.h"
#include "hex-profile.h"
#ifdef __cplusplus
extern "C" {
@@ -47,6 +48,7 @@ struct hmx_queue {
void * stack;
uint32_t hap_rctx;
bool hmx_locked;
struct htp_thread_trace * trace;
};
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
+2
View File
@@ -4,6 +4,7 @@
#include "hex-dma.h"
#include "hmx-queue.h"
#include "htp-ops.h"
#include "hex-profile.h"
#include "worker-pool.h"
#include <assert.h>
@@ -70,6 +71,7 @@ struct htp_context {
bool hmx_enabled;
bool etm;
uint32_t profiler;
struct htp_thread_trace trace[HTP_MAX_NTHREADS + 1];
uint8_t * vtcm_base;
size_t vtcm_size;
+31 -4
View File
@@ -146,10 +146,36 @@ struct htp_op_desc {
uint16_t dst; // Output tensor index
};
#ifndef HTP_MAX_NTHREADS
#define HTP_MAX_NTHREADS 10
#endif
#define HTP_TRACE_MAX_EVENTS 256
enum htp_profiler_mode {
HTP_PROF_DISABLED = 0,
HTP_PROF_BASIC = 1,
HTP_PROF_PMU = 2,
HTP_PROF_TRACE = 3,
};
enum htp_trace_event_id {
HTP_TRACE_EVT_DMA = 0,
HTP_TRACE_EVT_HVX_COMP = 20,
HTP_TRACE_EVT_HVX_A_QUANT = 21,
HTP_TRACE_EVT_HVX_A_PREP = 22,
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
HTP_TRACE_EVT_HVX_W_PREP = 24,
HTP_TRACE_EVT_HVX_O_PROC = 25,
HTP_TRACE_EVT_HMX_COMP = 40,
};
struct htp_trace_desc {
uint32_t cycles; // lower 32-bits of cycle counter
uint16_t id; // Event ID
uint16_t info; // bit 15: is_stop. bits 14-0: tile/chunk index or other metadata.
};
#define HTP_PROF_PMU_NCNT 8
@@ -158,8 +184,8 @@ enum htp_profiler_mode {
struct htp_prof_desc {
uint32_t opcode; // GGML/HTP Op
uint32_t usecs; // Number of usec
uint32_t cycles; // Number of cycles
uint32_t pad; // Unused
uint32_t cycles_start; // Start cycle counter
uint32_t cycles_stop; // Stop cycle counter
uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters
};
@@ -168,7 +194,7 @@ struct htp_opbatch_req {
uint32_t n_bufs; // Number of buffers
uint32_t n_tensors; // Number of tensors
uint32_t n_ops; // Number of ops
uint32_t flags; // unused
uint32_t n_traces; // Number of trace descriptors per thread
uint32_t pad; // unused
// struct htp_buf_desc bufs[]; -- dspqueue buf 0
// struct htp_tensor tensors[]; -- dspqueue buf 0
@@ -181,7 +207,8 @@ struct htp_opbatch_rsp {
uint32_t n_bufs; // Number of buffers
uint32_t n_tensors; // Number of tensors
uint32_t n_ops; // Number of op profile descriptors
uint32_t pad; // unused
uint32_t n_traces[HTP_MAX_NTHREADS + 1];
uint8_t pad[8]; // align to 8 bytes
// struct htp_prof_desc profs[]; -- dspqueue buf 0
};
+41 -9
View File
@@ -400,7 +400,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->hmx_queue = NULL;
if (use_hmx) {
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
if (!ctx->hmx_queue) {
if (ctx->hmx_queue) {
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
} else {
FARF(ERROR, "hmx-queue-create failed");
ctx->hmx_enabled = false;
}
@@ -425,6 +427,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->n_threads = n_hvx;
for (int i = 0; i < ctx->n_threads; i++) {
ctx->dma[i] = dma_queue_create(256); // queue depth
if (ctx->dma[i]) {
ctx->dma[i]->trace = &ctx->trace[i];
}
}
ctx->ddr_spad_size = 512 * 1024; // 512 KB
@@ -502,7 +507,8 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) {
struct profile_data {
uint64_t usecs;
uint64_t cycles;
uint64_t cycles_start;
uint64_t cycles_stop;
uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS];
};
@@ -512,8 +518,9 @@ static inline void profile_start(uint32_t mode, struct profile_data * d) {
hex_get_pmu(d->pmu_counters);
// fallthrough
case HTP_PROF_BASIC:
case HTP_PROF_TRACE:
d->usecs = HAP_perf_get_qtimer_count();
d->cycles = hex_get_cycles();
d->cycles_start = hex_get_cycles();
break;
default:
break;
@@ -530,8 +537,9 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
}
// fallthrough
case HTP_PROF_BASIC:
case HTP_PROF_TRACE:
d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
d->cycles = hex_get_cycles() - d->cycles;
d->cycles_stop = hex_get_cycles();
break;
default:
break;
@@ -845,14 +853,15 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
const uint32_t t_size = sizeof(struct htp_tensor) * n_tens;
const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops;
const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops;
const uint32_t tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(struct htp_trace_desc);
if (dbuf.size < b_size + t_size + o_size + p_size) {
FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size);
if (dbuf.size < b_size + t_size + o_size + p_size + tr_size) {
FARF(ERROR, "invalid opbatch memory block size %u (req %u)", dbuf.size, b_size + t_size + o_size + p_size + tr_size);
break;
}
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id,
n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size);
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u n-traces %u : m-size %u b-size %u t-size %u o-size %u", req.id,
n_bufs, n_tens, n_ops, req.n_traces, dbuf.size, b_size, t_size, o_size);
// Setup descriptor pointers
uint8_t * m_ptr = dbuf.ptr;
@@ -869,6 +878,20 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
octx->n_threads = ctx->n_threads;
octx->ctx = ctx;
if (ctx->profiler == HTP_PROF_TRACE) {
memset(ctx->trace, 0, sizeof(ctx->trace));
struct htp_trace_desc * trace_events = (struct htp_trace_desc *) (m_ptr + p_size);
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
ctx->trace[t].events = &trace_events[t * req.n_traces];
ctx->trace[t].max_events = req.n_traces;
}
} else {
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
ctx->trace[t].events = NULL;
ctx->trace[t].max_events = 0;
}
}
for (uint32_t i=0; i < n_ops; i++) {
struct profile_data prof;
@@ -886,7 +909,8 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
if (ctx->profiler) {
pds[i].opcode = ops[i].opcode;
pds[i].usecs = prof.usecs;
pds[i].cycles = prof.cycles;
pds[i].cycles_start = prof.cycles_start;
pds[i].cycles_stop = prof.cycles_stop;
for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) {
pds[i].pmu[j] = prof.pmu_counters[j];
}
@@ -899,6 +923,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
rsp.n_bufs = n_bufs;
rsp.n_tensors = n_tens;
rsp.n_ops = n_ops;
memset(rsp.pad, 0, sizeof(rsp.pad));
if (ctx->profiler == HTP_PROF_TRACE) {
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
rsp.n_traces[t] = ctx->trace[t].count;
}
} else {
memset(rsp.n_traces, 0, sizeof(rsp.n_traces));
}
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
+46
View File
@@ -3350,6 +3350,7 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
@@ -3411,10 +3412,12 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
}
}
}
@@ -3430,6 +3433,7 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
// src1 tensor is already in VTCM spad
static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
@@ -3477,6 +3481,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Process src1 columns in pairs (2×2 tiling)
uint32_t ir1 = 0;
for (; ir1 + 1 < src1_nrows; ir1 += 2) {
@@ -3494,6 +3500,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
@@ -3511,12 +3519,14 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
#pragma unroll(2)
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
t2 = HAP_perf_get_qtimer_count();
@@ -3530,6 +3540,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
// q8x4x2 src1 tensor is already in VTCM spad
static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const uint32_t src0_nrows = ne01;
@@ -3581,7 +3592,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3599,7 +3612,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 2);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
ir0 += 2;
}
if (ir0 < src0_end_row) {
@@ -3607,7 +3622,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
ir0 += 1;
}
} else {
@@ -3627,7 +3644,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3645,7 +3664,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
}
@@ -3669,6 +3690,7 @@ struct mmid_row_mapping {
// src1 tensor is already in VTCM spad
static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * restrict ids = octx->src[2];
struct htp_spad * restrict src2_spad = &octx->src2_spad;
@@ -3735,6 +3757,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
for (uint32_t cid = 0; cid < cne1; ++cid) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
const int rm1 = row_mapping.i1; // expert idx
@@ -3746,6 +3769,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3764,6 +3788,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
for (uint32_t cid = 0; cid < cne1; ++cid) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
const int rm1 = row_mapping.i1; // expert idx
@@ -3775,6 +3800,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
}
@@ -3789,6 +3815,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
// src1 tensor is already in VTCM spad
static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * restrict ids = octx->src[2];
struct htp_spad * restrict src2_spad = &octx->src2_spad;
@@ -3847,7 +3874,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3865,7 +3894,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
}
@@ -4147,6 +4178,7 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui
static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4163,6 +4195,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = src->nb[1];
@@ -4189,6 +4222,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
@@ -4219,6 +4253,7 @@ static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y,
static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4235,6 +4270,7 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = src->nb[1];
@@ -4260,11 +4296,13 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4281,6 +4319,7 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
@@ -4301,11 +4340,13 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4322,6 +4363,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
@@ -4342,12 +4384,14 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
// TODO just a plain copy that should be done via the DMA during the Op setup
static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4364,6 +4408,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
@@ -4384,6 +4429,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
+270 -27
View File
@@ -6,6 +6,7 @@ import re
import argparse
import statistics
import logging
from typing import Any, Dict, List, Optional
from collections import defaultdict
@@ -25,12 +26,47 @@ COL_MAP = {
}
op_pattern = re.compile(
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?"
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
)
trace_pattern = re.compile(
r"trace-op\s+(?P<op_name>[A-Z_0-9+]+):\s+thread\s+(?P<thread>\d+)\s+event\s+(?P<event>[A-Z_0-9\-]+)\s+info\s+(?P<info>\d+)\s+(?P<state>start|stop)\s+(?P<cycles>\d+)"
)
logger = logging.getLogger("ggml-hexagon-profile")
def normalize_event_name(evt_type):
if evt_type == "HVX_COMP":
return "V-COMP"
if evt_type == "HMX_COMP":
return "M-COMP"
# Strip HVX_ or HMX_ prefixes
name = evt_type
if name.startswith("HVX_") or name.startswith("HMX_"):
name = name[4:]
return name.replace("_", "-")
class CycleUnwrapper:
def __init__(self):
self.last_raw = None
self.high_part = 0
def unwrap(self, raw):
if self.last_raw is None:
self.last_raw = raw
return raw
diff = raw - self.last_raw
if diff < -0x80000000:
self.high_part += 0x100000000
elif diff > 0x80000000:
self.high_part -= 0x100000000
self.last_raw = raw
return raw + self.high_part
def parse_log(file_path, pmu_index=None):
try:
if file_path != "-":
@@ -41,35 +77,211 @@ def parse_log(file_path, pmu_index=None):
logger.error(f"file '{file_path}' not found.")
sys.exit(1)
all_ops = []
all_ops: List[Dict[str, Any]] = []
current_op: Optional[Dict[str, Any]] = None
timestamp_pattern = re.compile(r"^(?P<min>\d+)\.(?P<sec>\d+)\.(?P<ms>\d+)\.(?P<us>\d+)\s+[A-Z]\s+")
unwrapper = CycleUnwrapper()
for line in f:
match = op_pattern.search(line)
if not match: continue
ts_match = timestamp_pattern.match(line)
abs_usec = 0
if ts_match:
abs_usec = (
(int(ts_match.group('min')) * 60 + int(ts_match.group('sec'))) * 1000000
+ int(ts_match.group('ms')) * 1000
+ int(ts_match.group('us'))
)
pmu_raw = match.group('pmu')
pmu_val = None
if pmu_raw and pmu_index is not None:
try:
pmu_list = [int(x.strip()) for x in pmu_raw.split(',')]
if len(pmu_list) > pmu_index:
pmu_val = pmu_list[pmu_index]
except (ValueError, IndexError):
pmu_val = None
op_match = op_pattern.search(line)
if op_match:
pmu_raw = op_match.group('pmu')
pmu_val = None
if pmu_raw and pmu_index is not None:
try:
pmu_list = [int(x.strip()) for x in pmu_raw.split(',')]
if len(pmu_list) > pmu_index:
pmu_val = pmu_list[pmu_index]
except (ValueError, IndexError):
pmu_val = None
all_ops.append({
'name': match.group('op_name'),
'dims': match.group('dims').strip(),
'types': match.group('types').strip(),
'usec': int(match.group('usec')),
'cycles': int(match.group('cycles')),
'pmu_val': pmu_val
})
evt_raw = op_match.group('evt')
evt_val = None
if evt_raw:
try:
evt_val = [int(x.strip()) for x in evt_raw.split(',')]
except ValueError:
evt_val = None
cycles_start_raw = op_match.group('start')
unwrapped_cycles_start = None
if cycles_start_raw:
unwrapped_cycles_start = unwrapper.unwrap(int(cycles_start_raw))
idx = line.find("profile-op ")
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
current_op = {
'name': op_match.group('op_name'),
'dims': op_match.group('dims').strip(),
'types': op_match.group('types').strip(),
'op_text': op_text,
'usec': int(op_match.group('usec')),
'cycles': int(op_match.group('cycles')),
'cycles_start': int(cycles_start_raw) if cycles_start_raw else None,
'unwrapped_cycles_start': unwrapped_cycles_start,
'pmu_val': pmu_val,
'evt_val': evt_val,
'abs_usec': abs_usec,
'trace_events': []
}
all_ops.append(current_op)
continue
trace_match = trace_pattern.search(line)
if trace_match and current_op:
if trace_match.group('op_name') == current_op['name']:
raw_cyc = int(trace_match.group('cycles'))
current_op['trace_events'].append({
'thread': int(trace_match.group('thread')),
'event': trace_match.group('event'),
'info': int(trace_match.group('info')),
'cycles': raw_cyc,
'unwrapped_cycles': unwrapper.unwrap(raw_cyc),
'state': trace_match.group('state')
})
f.close()
return all_ops
def print_ascii_timeline(op_name, dims, types, usec, cycles, events, evt_val=None):
evt_str = ""
if evt_val:
evt_str = " - evt [" + ",".join(str(x) for x in evt_val) + "]"
logger.info("=" * 100)
logger.info(f"{op_name} ({dims} : {types}) - {usec} usec {cycles} cycles{evt_str}")
logger.info("=" * 100)
events = sorted(events, key=lambda e: e['cycles'])
if not events:
logger.info(" No trace events recorded.")
return
min_cycles = events[0]['cycles']
logger.info("Cycles %-30s" % "EventDetails" + " ".join(f"T{i:<2}" for i in range(10)) + " HMX")
logger.info("-" * 100)
thread_stacks = [[] for _ in range(11)]
for e in events:
t = e['thread']
if t < 0 or t > 10:
continue
if e['cycles'] >= min_cycles:
rel_cycles = e['cycles'] - min_cycles
else:
rel_cycles = (e['cycles'] + 0x100000000) - min_cycles
state = e['state']
evt_type = e['event']
# Determine char representing the event
norm_evt = normalize_event_name(evt_type)
char = '?'
if norm_evt == 'V-COMP':
char = 'V'
elif norm_evt == 'M-COMP':
char = 'H'
elif norm_evt == 'A-QUANT':
char = 'Q'
elif norm_evt == 'A-PREP':
char = 'A'
elif norm_evt == 'W-DEQUANT':
char = 'D'
elif norm_evt == 'O-PROC':
char = 'O'
elif norm_evt == 'W-PREP':
char = 'P'
elif norm_evt == 'DMA':
char = 'M'
if state == 'start':
thread_stacks[t].append(char)
elif state == 'stop':
if thread_stacks[t]:
if thread_stacks[t][-1] == char:
thread_stacks[t].pop()
elif char in thread_stacks[t]:
thread_stacks[t].remove(char)
else:
thread_stacks[t].pop()
cols = []
for i in range(11):
if thread_stacks[i]:
cols.append(f"[{thread_stacks[i][-1]}]")
else:
cols.append(" | ")
evt_desc = f"T{t}: {evt_type} {state} ({e['info']})"
logger.info(f"{rel_cycles:10d} %-30s" % evt_desc + " ".join(cols[:10]) + " " + cols[10])
logger.info("-" * 100)
def print_ascii_summary(op_name, dims, types, usec, cycles, events, evt_val=None):
evt_str = ""
if evt_val:
evt_str = " - evt [" + ",".join(str(x) for x in evt_val) + "]"
logger.info("=" * 100)
logger.info(f"{op_name} ({dims} : {types}) - {usec} usec {cycles} cycles{evt_str}")
logger.info("=" * 100)
events = sorted(events, key=lambda e: e['cycles'])
if not events:
logger.info(" No trace events recorded.")
return
active_starts = {}
thread_totals = defaultdict(lambda: defaultdict(int))
for e in events:
t = e['thread']
evt = e['event']
info = e['info']
cyc = e['cycles']
state = e['state']
key = (t, evt, info)
if state == 'start':
active_starts[key] = cyc
elif state == 'stop':
if key in active_starts:
start_cyc = active_starts[key]
del active_starts[key]
if cyc >= start_cyc:
dur = cyc - start_cyc
else:
dur = (cyc + 0x100000000) - start_cyc
norm_evt = normalize_event_name(evt)
thread_totals[t][norm_evt] += dur
for t in sorted(thread_totals.keys()):
thread_name = f"Thread {t} (HVX)" if t != 10 else "Thread 10 (HMX)"
sorted_evts = sorted(thread_totals[t].items(), key=lambda item: item[0])
evt_strs = []
for evt, dur in sorted_evts:
pct = (dur / cycles * 100) if cycles > 0 else 0
evt_strs.append(f"{evt} {dur} ({pct:.1f}%)")
logger.info(f" {thread_name:<16}: " + " | ".join(evt_strs))
def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
if not ops:
logger.info("No valid records found.")
@@ -115,7 +327,6 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
# Sorting logic
actual_sort_key = COL_MAP[sort_col][2]
# We sort numeric fields descending, strings (op/dims) ascending
is_numeric = actual_sort_key.startswith("_") or actual_sort_key == "count"
sorted_groups = sorted(group_stats, key=lambda x: x[actual_sort_key], reverse=is_numeric)[:top_n]
@@ -132,7 +343,7 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
if "pmu" in col_name and pmu_name:
header_text = header_text.replace("PMU", pmu_name)
natural_width = max([len(row[data_key]) for row in sorted_groups] + [len(header_text)])
natural_width = max([len(str(row[data_key])) for row in sorted_groups] + [len(header_text)])
target_width = width_overrides.get(col_name, natural_width)
if target_width == 0:
@@ -152,7 +363,7 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
for group in sorted_groups:
row_vals = []
for i, key in enumerate(final_keys):
val = group[key]
val = str(group[key])
if len(val) > final_widths[i]:
val = val[:final_widths[i] - 3] + "..."
row_vals.append(f"{val:<{final_widths[i]}}")
@@ -167,12 +378,18 @@ def main():
parser.add_argument("--pmu-index", type=int)
parser.add_argument("--pmu-name", type=str)
parser.add_argument("--width", action='append', default=['dims:40'], help="Override column width, e.g. --width dims:50")
parser.add_argument("--timeline", type=str, nargs='?', const='summary', choices=["summary", "diagram"],
help="Output ASCII art event summary or timing diagram (default: summary)")
parser.add_argument("--filter", type=str, help="Regex filter matching against the original profile-op line")
group = parser.add_mutually_exclusive_group()
group.add_argument("--head", type=int, help="Limit to first N ops")
group.add_argument("--tail", type=int, help="Limit to last N ops")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format='%(message)s')
# Sort validation: can't sort by PMU if index isn't provided
if "pmu" in args.sort and args.pmu_index is None:
logger.error(f"Cannot sort by '{args.sort}' without --pmu-index.")
sys.exit(1)
@@ -188,7 +405,33 @@ def main():
final_pmu_name = (args.pmu_name or f"#{args.pmu_index}") if args.pmu_index is not None else None
ops = parse_log(args.logfile, pmu_index=args.pmu_index)
generate_report(ops, args.top, overrides, args.sort, pmu_name=final_pmu_name)
if args.filter:
try:
filter_re = re.compile(args.filter)
except re.error as e:
logger.error(f"Invalid regex filter: {e}")
sys.exit(1)
ops = [op for op in ops if filter_re.search(op['op_text'])]
if args.head is not None:
ops = ops[:args.head]
elif args.tail is not None:
ops = ops[-args.tail:]
if args.timeline:
logger.info(f"\n# ASCII Timing {args.timeline.capitalize()}\n")
printed_cnt = 0
for op in ops:
if args.timeline == "summary":
print_ascii_summary(op['name'], op['dims'], op['types'], op['usec'], op['cycles'], op['trace_events'], op.get('evt_val'))
elif args.timeline == "diagram":
print_ascii_timeline(op['name'], op['dims'], op['types'], op['usec'], op['cycles'], op['trace_events'], op.get('evt_val'))
printed_cnt += 1
if printed_cnt >= args.top:
break
else:
generate_report(ops, args.top, overrides, args.sort, pmu_name=final_pmu_name)
if __name__ == "__main__":
+463
View File
@@ -0,0 +1,463 @@
#!/usr/bin/env python3
import sys
import os
import re
import argparse
import statistics
import logging
from typing import Any, Dict, List, Optional
from collections import defaultdict
logger = logging.getLogger("ggml-hexagon-trace")
op_pattern = re.compile(
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+(?P<strides>[\d:x\s\->!]+)\s+:\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
)
trace_pattern = re.compile(
r"trace-op\s+(?P<op_name>[A-Z_0-9+]+):\s+thread\s+(?P<thread>\d+)\s+event\s+(?P<event>[A-Z_0-9\-]+)\s+info\s+(?P<info>\d+)\s+(?P<state>start|stop)\s+(?P<cycles>\d+)"
)
def normalize_event_name(evt_type):
if evt_type == "HVX_COMP":
return "V-COMP"
if evt_type == "HMX_COMP":
return "M-COMP"
name = evt_type
if name.startswith("HVX_") or name.startswith("HMX_"):
name = name[4:]
return name.replace("_", "-")
class CycleUnwrapper:
def __init__(self):
self.last_raw = None
self.high_part = 0
def unwrap(self, raw):
if self.last_raw is None:
self.last_raw = raw
return raw
diff = raw - self.last_raw
if diff < -0x80000000:
self.high_part += 0x100000000
elif diff > 0x80000000:
self.high_part -= 0x100000000
self.last_raw = raw
return raw + self.high_part
def parse_log(file_path):
try:
if file_path != "-":
f = open(file_path, 'r', encoding='utf-8', errors='ignore')
else:
f = os.fdopen(0, 'r', encoding='utf-8', errors='ignore')
except FileNotFoundError:
logger.error(f"file '{file_path}' not found.")
sys.exit(1)
all_ops: List[Dict[str, Any]] = []
current_op: Optional[Dict[str, Any]] = None
unwrapper = CycleUnwrapper()
line_idx = 0
for line in f:
line_idx += 1
op_match = op_pattern.search(line)
if op_match:
cycles_start_raw = op_match.group('start')
unwrapped_cycles_start = None
if cycles_start_raw:
unwrapped_cycles_start = unwrapper.unwrap(int(cycles_start_raw))
idx = line.find("profile-op ")
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
current_op = {
'name': op_match.group('op_name'),
'dims': op_match.group('dims').strip() if op_match.group('dims') else '',
'types': op_match.group('types').strip() if op_match.group('types') else '',
'strides': op_match.group('strides').strip() if op_match.group('strides') else '',
'op_text': op_text,
'usec': int(op_match.group('usec')),
'cycles': int(op_match.group('cycles')),
'cycles_start': int(cycles_start_raw) if cycles_start_raw else None,
'unwrapped_cycles_start': unwrapped_cycles_start,
'trace_events': [],
'line_num': line_idx
}
all_ops.append(current_op)
continue
trace_match = trace_pattern.search(line)
if trace_match and current_op:
if trace_match.group('op_name') == current_op['name']:
raw_cyc = int(trace_match.group('cycles'))
current_op['trace_events'].append({
'thread': int(trace_match.group('thread')),
'event': trace_match.group('event'),
'info': int(trace_match.group('info')),
'cycles': raw_cyc,
'unwrapped_cycles': unwrapper.unwrap(raw_cyc),
'state': trace_match.group('state')
})
f.close()
return all_ops
# --- Simple protobuf encoder ---
def write_varint(val):
if val < 0:
val = (1 << 64) + val
res = bytearray()
while True:
towrite = val & 0x7f
val >>= 7
if val > 0:
res.append(towrite | 0x80)
else:
res.append(towrite)
break
return bytes(res)
def pb_field(num, wire, data):
return write_varint((num << 3) | wire) + data
def pb_varint(num, val):
return pb_field(num, 0, write_varint(val))
def pb_length_delimited(num, data):
return pb_field(num, 2, write_varint(len(data)) + data)
def pb_string(num, text):
return pb_length_delimited(num, text.encode('utf-8'))
# Message Encoders
def make_process_descriptor(pid, name):
return pb_varint(1, pid) + pb_string(6, name)
def make_thread_descriptor(pid, tid, name, sort_index=None):
payload = pb_varint(1, pid) + pb_varint(2, tid) + pb_string(5, name)
if sort_index is not None:
payload += pb_varint(3, sort_index)
return payload
def make_track_descriptor(uuid, name=None, parent_uuid=None, thread=None, process=None, sibling_merge_behavior=None, child_ordering=None, sibling_order_rank=None):
payload = pb_varint(1, uuid)
if name is not None:
payload += pb_string(2, name)
if parent_uuid is not None:
payload += pb_varint(5, parent_uuid)
if process is not None:
payload += pb_length_delimited(3, process)
if thread is not None:
payload += pb_length_delimited(4, thread)
if sibling_merge_behavior is not None:
payload += pb_varint(15, sibling_merge_behavior)
if child_ordering is not None:
payload += pb_varint(11, child_ordering)
if sibling_order_rank is not None:
payload += pb_varint(12, sibling_order_rank)
return payload
def make_debug_annotation(name, string_val=None, int_val=None):
payload = pb_string(10, name)
if string_val is not None:
payload += pb_string(6, string_val)
elif int_val is not None:
payload += pb_varint(4, int_val)
return payload
def make_track_event(event_type, track_uuid, name=None, category=None, debug_annotations=None):
payload = pb_varint(9, event_type)
payload += pb_varint(11, track_uuid)
if name is not None:
payload += pb_string(23, name)
if category is not None:
payload += pb_string(22, category)
if debug_annotations is not None:
for da in debug_annotations:
payload += pb_length_delimited(4, da)
return payload
def make_trace_packet(timestamp, track_event=None, track_descriptor=None, seq_id=1):
payload = pb_varint(8, timestamp)
payload += pb_varint(10, seq_id)
if track_event is not None:
payload += pb_length_delimited(11, track_event)
if track_descriptor is not None:
payload += pb_length_delimited(60, track_descriptor)
return payload
def write_trace_packet_to_file(f, packet_bytes):
# Write as field 1 of top-level Trace message
f.write(pb_length_delimited(1, packet_bytes))
# --- End Protobuf Encoder ---
def generate_perfetto_trace(filtered_ops, output_path):
if not filtered_ops:
logger.warning("No operators found after filtering.")
return
# Compute average frequency
frequencies = []
for op in filtered_ops:
if op['usec'] > 0 and op['cycles'] > 0:
frequencies.append(op['cycles'] / op['usec'])
avg_freq_mhz = statistics.mean(frequencies) if frequencies else 1000.0
if avg_freq_mhz <= 0:
avg_freq_mhz = 1000.0
# Assign start and end cycles to each operator
for op in filtered_ops:
op['start_cycles'] = op['unwrapped_cycles_start']
op['end_cycles'] = op['start_cycles'] + op['cycles']
global_min_cyc = min(op['start_cycles'] for op in filtered_ops if op['start_cycles'] is not None)
# Process events
completed_events = []
for op in filtered_ops:
events = op['trace_events']
if not events:
continue
events = sorted(events, key=lambda e: e['unwrapped_cycles'])
active_starts = {}
for e in events:
t = e['thread']
evt = e['event']
info = e['info']
state = e['state']
cyc = e['unwrapped_cycles']
key = (t, evt, info)
if state == 'start':
active_starts[key] = cyc
elif state == 'stop':
if key in active_starts:
start_cyc = active_starts[key]
del active_starts[key]
completed_events.append({
'thread': t,
'event': evt,
'info': info,
'start_cyc': start_cyc,
'end_cyc': cyc,
'op_name': op['name']
})
completed_events.sort(key=lambda e: e['start_cyc'])
# Convert event times to microseconds and apply clamp rounded to 1ns resolution (3 decimals)
for e in completed_events:
start_us = (e['start_cyc'] - global_min_cyc) / avg_freq_mhz
dur_us = (e['end_cyc'] - e['start_cyc']) / avg_freq_mhz
e['ts_ns'] = int(round(start_us * 1000))
e['dur_ns'] = int(round(max(dur_us, 0.1) * 1000))
# Allocate slots (sub-tracks) to prevent overlaps on same virtual track
active_slots = defaultdict(list)
for e in completed_events:
t = e['thread']
evt = e['event']
ts = e['ts_ns']
dur = e['dur_ns']
norm_evt = normalize_event_name(evt)
if norm_evt == "DMA":
track_key = (t, "DMA")
elif t == 10:
track_key = (t, "HMX")
else:
track_key = (t, "HVX")
slots = active_slots[track_key]
allocated_slot = -1
for idx, slot_end_ns in enumerate(slots):
if ts >= slot_end_ns:
slots[idx] = ts + dur
allocated_slot = idx
break
if allocated_slot == -1:
slots.append(ts + dur)
allocated_slot = len(slots) - 1
e['slot'] = allocated_slot
# Generate Track IDs and track definitions
used_tracks = {}
for e in completed_events:
t = e['thread']
evt = e['event']
slot = e['slot']
norm_evt = normalize_event_name(evt)
if norm_evt == "DMA":
track_evt = "DMA"
evt_id = 1
elif t == 10:
track_evt = "HMX"
evt_id = 3
else:
track_evt = "HVX"
evt_id = 2
t_sort = 1 if t == 10 else t + 2
# Unique UUID for each sub-track
if t == 10:
uuid = 20 # HMX thread track UUID
else:
uuid = int(t_sort * 1000000 + evt_id * 1000 + slot)
e['uuid'] = uuid
used_tracks[uuid] = (t, track_evt, slot)
with open(output_path, "wb") as f:
# Define Process with EXPLICIT child sorting
proc_desc = make_process_descriptor(1, "HTP NPU")
proc_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(1, process=proc_desc, child_ordering=3))
write_trace_packet_to_file(f, proc_packet)
# Define Operators Track (UUID = 2) as a thread track at rank 1, tid 8
op_thread_desc = make_thread_descriptor(1, 8, "Ops", sort_index=1)
op_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(2, parent_uuid=1, thread=op_thread_desc))
write_trace_packet_to_file(f, op_packet)
# Define HMX Thread Track (UUID = 20) at rank 2, tid 9
hmx_thread_desc = make_thread_descriptor(1, 9, "HMX", sort_index=2)
hmx_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(20, parent_uuid=1, thread=hmx_thread_desc))
write_trace_packet_to_file(f, hmx_packet)
# Define Thread Tracks (T0, T1, ..., T9)
unique_threads = sorted(list(set(t for (t, _, _) in used_tracks.values() if t != 10)))
for t in unique_threads:
thread_uuid = 10 + t
thread_name = f"T{t}"
# Sort order starts from index 3 (T0 -> 3, T1 -> 4, etc.)
sort_index = 3 + t
tid = 10 + t
thread_desc = make_thread_descriptor(1, tid, thread_name, sort_index=sort_index)
thread_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(
thread_uuid,
parent_uuid=1,
thread=thread_desc,
sibling_order_rank=sort_index,
child_ordering=3 # Explicit child sorting for sub-tracks
))
write_trace_packet_to_file(f, thread_packet)
# Define Track descriptors for sub-tracks parented to thread tracks
for uuid in sorted(used_tracks.keys()):
if uuid == 20:
continue
t, evt, slot = used_tracks[uuid]
name = f"T{t} {evt}"
rank = 0 if evt == "HVX" else 1
parent_thread_uuid = 10 + t
# Sibling merge behavior: 1 (SIBLING_MERGE_BEHAVIOR_BY_TRACK_NAME)
track_desc = make_track_descriptor(
uuid=uuid,
name=name,
parent_uuid=parent_thread_uuid,
sibling_merge_behavior=1,
sibling_order_rank=rank
)
track_packet = make_trace_packet(0, track_descriptor=track_desc)
write_trace_packet_to_file(f, track_packet)
# Emit Operators
last_op_end_ns = 0
for op in filtered_ops:
op_start_ns = int(round(((op['start_cycles'] - global_min_cyc) / avg_freq_mhz) * 1000))
op_dur_ns = int(round((op['cycles'] / avg_freq_mhz) * 1000))
if op_start_ns < last_op_end_ns:
op_start_ns = last_op_end_ns
clamped_dur = max(op_dur_ns, 100) # Clamp to 100ns (0.1us)
# Debug annotations for Ops
debug_annots = []
if 'line_num' in op:
debug_annots.append(make_debug_annotation("line", int_val=op['line_num']))
if 'strides' in op and op['strides']:
debug_annots.append(make_debug_annotation("strides", string_val=op['strides']))
# Slice Begin
evt_begin = make_track_event(1, 2, name=f"{op['name']} ({op['dims']})", category="operator", debug_annotations=debug_annots)
packet_begin = make_trace_packet(op_start_ns, track_event=evt_begin)
write_trace_packet_to_file(f, packet_begin)
# Slice End
evt_end = make_track_event(2, 2)
packet_end = make_trace_packet(op_start_ns + clamped_dur, track_event=evt_end)
write_trace_packet_to_file(f, packet_end)
last_op_end_ns = op_start_ns + clamped_dur
# Emit Thread Trace Events
for e in completed_events:
norm_name = normalize_event_name(e['event'])
name = f"DMA {e['info']}" if norm_name == "DMA" else norm_name
# Slice Begin
evt_begin = make_track_event(1, e['uuid'], name=name, category="trace")
packet_begin = make_trace_packet(e['ts_ns'], track_event=evt_begin)
write_trace_packet_to_file(f, packet_begin)
# Slice End
evt_end = make_track_event(2, e['uuid'])
packet_end = make_trace_packet(e['ts_ns'] + e['dur_ns'], track_event=evt_end)
write_trace_packet_to_file(f, packet_end)
logger.info(f"Successfully generated Perfetto trace at {output_path}")
def main():
parser = argparse.ArgumentParser(description="Convert Hexagon Op profile logs to native Perfetto Protobuf traces.")
parser.add_argument("logfile", help="Path to hex-log profile file")
parser.add_argument("-o", "--output", default="optrace.perfetto-trace", help="Output trace file path (default: optrace.perfetto-trace)")
parser.add_argument("--filter", type=str, help="Regex filter matching against the original profile-op line")
group = parser.add_mutually_exclusive_group()
group.add_argument("--head", type=int, help="Limit to first N ops")
group.add_argument("--tail", type=int, help="Limit to last N ops")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format='%(message)s')
ops = parse_log(args.logfile)
if args.filter:
try:
filter_re = re.compile(args.filter)
except re.error as e:
logger.error(f"Invalid regex filter: {e}")
sys.exit(1)
ops = [op for op in ops if filter_re.search(op['op_text'])]
if args.head is not None:
ops = ops[:args.head]
elif args.tail is not None:
ops = ops[-args.tail:]
generate_perfetto_trace(ops, args.output)
if __name__ == "__main__":
main()
+23 -4
View File
@@ -20,6 +20,7 @@ set(LLAMA_UI_GZIP "" CACHE STRING "Apply gzip compress to assets to save ban
set(DIST_DIR "${UI_BINARY_DIR}/dist")
set(SRC_DIST_DIR "${UI_SOURCE_DIR}/dist")
set(WORK_DIR "${UI_BINARY_DIR}/ui-src")
set(STAMP_FILE "${UI_BINARY_DIR}/.ui-stamp")
set(UI_CPP "${UI_BINARY_DIR}/ui.cpp")
set(UI_H "${UI_BINARY_DIR}/ui.h")
@@ -64,6 +65,22 @@ function(npm_build_should_skip out_var)
set(${out_var} TRUE PARENT_SCOPE)
endfunction()
function(stage_sources)
if(EXISTS "${WORK_DIR}")
file(GLOB staged RELATIVE "${WORK_DIR}" "${WORK_DIR}/*")
list(REMOVE_ITEM staged "node_modules")
foreach(entry ${staged})
file(REMOVE_RECURSE "${WORK_DIR}/${entry}")
endforeach()
endif()
file(COPY "${UI_SOURCE_DIR}/"
DESTINATION "${WORK_DIR}"
NO_SOURCE_PERMISSIONS
PATTERN "node_modules" EXCLUDE
)
endfunction()
function(npm_build out_var)
set(${out_var} FALSE PARENT_SCOPE)
@@ -89,14 +106,16 @@ function(npm_build out_var)
return()
endif()
stage_sources()
# npm writes node_modules/.package-lock.json on every successful install,
# so a package-lock.json newer than this marker means node_modules is stale
set(NPM_MARKER "${UI_SOURCE_DIR}/node_modules/.package-lock.json")
set(NPM_MARKER "${WORK_DIR}/node_modules/.package-lock.json")
set(need_install FALSE)
if(NOT EXISTS "${NPM_MARKER}")
set(need_install TRUE)
else()
file(TIMESTAMP "${UI_SOURCE_DIR}/package-lock.json" lock_ts)
file(TIMESTAMP "${WORK_DIR}/package-lock.json" lock_ts)
file(TIMESTAMP "${NPM_MARKER}" marker_ts)
if(lock_ts STRGREATER marker_ts)
set(need_install TRUE)
@@ -107,7 +126,7 @@ function(npm_build out_var)
message(STATUS "UI: running npm install")
execute_process(
COMMAND ${NPM_EXECUTABLE} install
WORKING_DIRECTORY "${UI_SOURCE_DIR}"
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE rc
ERROR_VARIABLE err
)
@@ -124,7 +143,7 @@ function(npm_build out_var)
execute_process(
COMMAND ${CMAKE_COMMAND} -E env "LLAMA_UI_OUT_DIR=${DIST_DIR}" "LLAMA_UI_VERSION=${HF_VERSION}" "LLAMA_BUILD_NUMBER=${LLAMA_BUILD_NUMBER}"
${NPM_EXECUTABLE} run build
WORKING_DIRECTORY "${UI_SOURCE_DIR}"
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE rc
ERROR_VARIABLE err
)
+6 -8
View File
@@ -6,11 +6,10 @@ Apply LORA adapters to base model and export the resulting model.
usage: llama-export-lora [options]
options:
-m, --model model path from which to load base model (default '')
--lora FNAME path to LoRA adapter (can be repeated to use multiple adapters)
--lora-scaled FNAME S path to LoRA adapter with user defined scaling S (can be repeated to use multiple adapters)
-t, --threads N number of threads to use during computation (default: 4)
-o, --output FNAME output file (default: 'ggml-lora-merged-f16.gguf')
-m, --model FNAME model path from which to load base model
--lora FNAME path to LoRA adapter (use comma-separated values to load multiple adapters)
--lora-scaled FNAME:SCALE,... path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)
-o, --output, --output-file FNAME output file (default: 'ggml-lora-merged-f16.gguf')
```
For example:
@@ -22,12 +21,11 @@ For example:
--lora lora-open-llama-3b-v2-english2tokipona-chat-LATEST.gguf
```
Multiple LORA adapters can be applied by passing multiple `--lora FNAME` or `--lora-scaled FNAME S` command line parameters:
Multiple LORA adapters can be applied by passing comma-separated values to `--lora FNAME` or `--lora-scaled FNAME:SCALE,...`:
```bash
./bin/llama-export-lora \
-m your_base_model.gguf \
-o your_merged_model.gguf \
--lora-scaled lora_task_A.gguf 0.5 \
--lora-scaled lora_task_B.gguf 0.5
--lora-scaled lora_task_A.gguf:0.5,lora_task_B.gguf:0.5
```
+35
View File
@@ -0,0 +1,35 @@
# libmtmd dev guide
## History
Please refer to [multimodal.md](../../docs/multimodal.md) for a broader context.
In short:
- `libmtmd` started as a wrapper around `libllava` / `clip.cpp`
- Various components that used to be in `clip.cpp` are moved progressively to mtmd. For example, preprocessor is now part of mtmd
## Terminologies
- mtmd: **M**ul**T**i**M**o**D**al
- bitmap: representing a raw input data, for example: RGB image, PCM audio
- tiles / slices: for llava-uhd-style models, the preprocessor breaks a large input into smaller square images called tiles or slices
- chunk: a mtmd_input_chunk represents a preprocessed input that can then be passed through `mtmd_encode()`
## Pipeline
A typical pipeline of the core libmtmd is as follows:
- A bitmap (RGB image or PCM audio) is created
- Bitmap and the text prompt is provided to `mtmd_tokenize()` that breaks the input into chunks
- The tokenizer function first expands a "lazy" bitmap if it finds one. Typically, this is used by video, so that one media token corresponds to one input bitmap
- For models that support "fused" temporal frames like Qwen-VL, the tokenizer tries to merge pair of consecutive frames into one batch
- The preprocessor will then be called, which produces a list of chunks
- Depending on the model itself, special tokens will be injected to separate image chunks (i.e. llava-uhd-style models)
- Multiple bitmaps may be batched together to form a larger `mtmd_batch()`
- Single image or batch is encoded, via `mtmd_encode()` or `mtmd_batch_encode()`
- Get the output embeddings
## Helper
We provide a set of helper functions via `mtmd_helper` to make using libmtmd easier. The helper provides:
- Image, audio and video file decoding (for example, decode raw JPEG into RGB bitmap)
- Manage `llama_batch` and calls to `llama_decode`
+52 -81
View File
@@ -367,56 +367,56 @@ enum projector_type {
};
static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_MLP, "mlp" },
{ PROJECTOR_TYPE_LDP, "ldp" },
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
{ PROJECTOR_TYPE_MINICPMV, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
{ PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"},
{ PROJECTOR_TYPE_STEP3VL, "step3vl"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_GEMMA3NV, "gemma3nv"},
{ PROJECTOR_TYPE_GEMMA3NA, "gemma3na"},
{ PROJECTOR_TYPE_GEMMA4V, "gemma4v"},
{ PROJECTOR_TYPE_GEMMA4A, "gemma4a"},
{ PROJECTOR_TYPE_GEMMA4UV, "gemma4uv"},
{ PROJECTOR_TYPE_GEMMA4UA, "gemma4ua"},
{ PROJECTOR_TYPE_PHI4, "phi4"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_ULTRAVOX, "ultravox"},
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
{ PROJECTOR_TYPE_QWEN3A, "qwen3a"},
{ PROJECTOR_TYPE_GLMA, "glma"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_MERALION, "meralion"},
{ PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_PADDLEOCR, "paddleocr"},
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_DOTS_OCR, "dots_ocr"},
{ PROJECTOR_TYPE_DEEPSEEKOCR,"deepseekocr"},
{ PROJECTOR_TYPE_DEEPSEEKOCR2,"deepseekocr2"},
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
{ PROJECTOR_TYPE_YASA2, "yasa2"},
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
{ PROJECTOR_TYPE_EXAONE4_5, "exaone4_5"},
{ PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"},
{ PROJECTOR_TYPE_MINICPMV4_6, "minicpmv4_6"},
{ PROJECTOR_TYPE_GRANITE_SPEECH, "granite_speech"},
{ PROJECTOR_TYPE_MIMOVL, "mimovl"},
{ PROJECTOR_TYPE_GRANITE4_VISION, "granite4_vision"},
{ PROJECTOR_TYPE_MLP, "mlp" },
{ PROJECTOR_TYPE_LDP, "ldp" },
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
{ PROJECTOR_TYPE_MINICPMV, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
{ PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"},
{ PROJECTOR_TYPE_STEP3VL, "step3vl"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_GEMMA3NV, "gemma3nv"},
{ PROJECTOR_TYPE_GEMMA3NA, "gemma3na"},
{ PROJECTOR_TYPE_GEMMA4V, "gemma4v"},
{ PROJECTOR_TYPE_GEMMA4A, "gemma4a"},
{ PROJECTOR_TYPE_GEMMA4UV, "gemma4uv"},
{ PROJECTOR_TYPE_GEMMA4UA, "gemma4ua"},
{ PROJECTOR_TYPE_PHI4, "phi4"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_ULTRAVOX, "ultravox"},
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
{ PROJECTOR_TYPE_QWEN3A, "qwen3a"},
{ PROJECTOR_TYPE_GLMA, "glma"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_MERALION, "meralion"},
{ PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_PADDLEOCR, "paddleocr"},
{ PROJECTOR_TYPE_LIGHTONOCR, "lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_DOTS_OCR, "dots_ocr"},
{ PROJECTOR_TYPE_DEEPSEEKOCR, "deepseekocr"},
{ PROJECTOR_TYPE_DEEPSEEKOCR2, "deepseekocr2"},
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
{ PROJECTOR_TYPE_YASA2, "yasa2"},
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
{ PROJECTOR_TYPE_EXAONE4_5, "exaone4_5"},
{ PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"},
{ PROJECTOR_TYPE_MINICPMV4_6, "minicpmv4_6"},
{ PROJECTOR_TYPE_GRANITE_SPEECH, "granite_speech"},
{ PROJECTOR_TYPE_MIMOVL, "mimovl"},
{ PROJECTOR_TYPE_GRANITE4_VISION, "granite4_vision"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {
@@ -640,47 +640,18 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
// cpp wrappers
//
// wrapper for clip_image_size
struct clip_image_size_deleter {
void operator()(clip_image_size * val) { clip_image_size_free(val); }
};
typedef std::unique_ptr<clip_image_size, clip_image_size_deleter> clip_image_size_ptr;
// wrapper for clip_image_u8
struct clip_image_u8_deleter {
void operator()(clip_image_u8 * val) { clip_image_u8_free(val); }
};
typedef std::unique_ptr<clip_image_u8, clip_image_u8_deleter> clip_image_u8_ptr;
// wrapper for clip_image_f32
struct clip_image_f32_deleter {
void operator()(clip_image_f32 * val) { clip_image_f32_free(val); }
};
typedef std::unique_ptr<clip_image_f32, clip_image_f32_deleter> clip_image_f32_ptr;
struct clip_image_u8_batch {
std::vector<clip_image_u8_ptr> entries;
};
struct clip_image_f32_batch {
std::vector<clip_image_f32_ptr> entries;
std::vector<clip_image_f32> entries;
bool is_audio = false;
// for llava-uhd style models, we need to know the grid size
// note: entries.size() == grid_x * grid_y + 1 (one overview image)
int grid_x = 0;
int grid_y = 0;
clip_image_f32_batch clone() const {
clip_image_f32_batch new_batch{
/* entries */ {},
/* is_audio */ is_audio,
/* grid_x */ grid_x,
/* grid_y */ grid_y,
};
new_batch.entries.reserve(entries.size());
for (const auto & entry : entries) {
new_batch.entries.emplace_back(new clip_image_f32(*entry));
new_batch.entries.emplace_back(entry); // copy
}
return new_batch;
}
+24 -95
View File
@@ -865,7 +865,7 @@ ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale
}
static std::unique_ptr<clip_graph> clip_get_graph_builder(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
const clip_image_f32 & img = *imgs.entries[0];
const clip_image_f32 & img = imgs.entries[0];
std::unique_ptr<clip_graph> builder;
switch (ctx->proj_type()) {
@@ -2825,16 +2825,16 @@ struct clip_model_loader {
// create a fake batch
const auto & hparams = ctx_clip.model.hparams;
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
clip_image_f32 img;
if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
const int sz = hparams.warmup_image_size;
img->set_size({sz, sz}, false, false);
img.set_size({sz, sz}, false, false);
LOG_INF("%s: warmup with image size = %d x %d\n", __func__, sz, sz);
} else {
img->set_size({hparams.warmup_audio_size, hparams.n_mel_bins}, false, false);
img.set_size({hparams.warmup_audio_size, hparams.n_mel_bins}, false, false);
LOG_INF("%s: warmup with audio size = %d\n", __func__, hparams.warmup_audio_size);
}
batch.entries.push_back(std::move(img));
batch.entries.push_back(img);
return batch;
}
@@ -3124,64 +3124,6 @@ struct clip_cap clip_get_cap(const char * fname) {
return res;
}
struct clip_image_size * clip_image_size_init() {
struct clip_image_size * load_image_size = new struct clip_image_size();
load_image_size->width = 448;
load_image_size->height = 448;
return load_image_size;
}
struct clip_image_u8 * clip_image_u8_init() {
return new clip_image_u8();
}
struct clip_image_f32 * clip_image_f32_init() {
return new clip_image_f32();
}
struct clip_image_f32_batch * clip_image_f32_batch_init() {
return new clip_image_f32_batch();
}
void clip_image_size_free(struct clip_image_size * load_image_size) {
if (load_image_size == nullptr) {
return;
}
delete load_image_size;
}
void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
return batch->entries.size();
}
size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) {
if (idx < 0 || idx >= (int)batch->entries.size()) {
LOG_ERR("%s: invalid index %d\n", __func__, idx);
return 0;
}
return batch->entries[idx]->nx();
}
size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) {
if (idx < 0 || idx >= (int)batch->entries.size()) {
LOG_ERR("%s: invalid index %d\n", __func__, idx);
return 0;
}
return batch->entries[idx]->ny();
}
clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) {
if (idx < 0 || idx >= (int)batch->entries.size()) {
LOG_ERR("%s: invalid index %d\n", __func__, idx);
return nullptr;
}
return batch->entries[idx].get();
}
void clip_free(clip_ctx * ctx) {
if (ctx == nullptr) {
return;
@@ -3189,23 +3131,11 @@ void clip_free(clip_ctx * ctx) {
delete ctx;
}
int32_t clip_get_image_size(const struct clip_ctx * ctx) {
return ctx->model.hparams.image_size;
}
int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
return ctx->model.hparams.patch_size;
}
int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
return ctx->model.hparams.n_embd;
}
const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
return ctx->model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat";
}
int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
int clip_n_output_tokens_x(const clip_ctx * ctx, const clip_image_f32 * img) {
const auto & params = ctx->model.hparams;
const int n_total = clip_n_output_tokens(ctx, img);
const auto & proj = ctx->proj_type();
@@ -3228,7 +3158,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
return n_total;
}
int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
int clip_n_output_tokens_y(const clip_ctx * ctx, const clip_image_f32 * img) {
const auto & params = ctx->model.hparams;
const auto & proj = ctx->proj_type();
switch (proj) {
@@ -3250,7 +3180,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
return 1;
}
int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
int clip_n_output_tokens(const clip_ctx * ctx, const clip_image_f32 * img) {
const auto & params = ctx->model.hparams;
// for models with fixed size image, the input image is already pre-processed and resized to square
@@ -3500,16 +3430,15 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
return n_patches;
}
bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, std::vector<float> & out_vec) {
bool clip_image_encode(struct clip_ctx * ctx, int n_threads, const clip_image_f32 * img, std::vector<float> & out_vec) {
clip_image_f32_batch imgs;
clip_image_f32_ptr img_copy(clip_image_f32_init());
*img_copy = *img;
clip_image_f32 img_copy = *img;
imgs.entries.push_back(std::move(img_copy));
return clip_image_batch_encode(ctx, n_threads, &imgs, out_vec);
}
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, std::vector<float> & out_batch_embd) {
bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32_batch * imgs_c_ptr, std::vector<float> & out_batch_embd) {
const clip_image_f32_batch & imgs = *imgs_c_ptr;
int n_batch_cur = imgs.entries.size();
@@ -3533,8 +3462,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
const auto & model = ctx->model;
const auto & hparams = model.hparams;
const int image_size_width = imgs.entries[0]->nx();
const int image_size_height = imgs.entries[0]->ny();
const int image_size_width = imgs.entries[0].nx();
const int image_size_height = imgs.entries[0].ny();
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
@@ -3572,7 +3501,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
if (!imgs.is_audio) {
size_t nelem = 0;
for (const auto & img : imgs.entries) {
nelem += img->nx() * img->ny() * 3;
nelem += img.nx() * img.ny() * 3;
}
std::vector<float> inp_raw(nelem);
@@ -3590,13 +3519,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// IMPORTANT: [QWEN_VIDEO] the batch dim is currently used for temporal dim in Qwen-VL models
// All entries must have the same spatial size (enforced by can_batch_with() during merging)
{
const int nx = imgs.entries[0]->nx();
const int ny = imgs.entries[0]->ny();
const int nx = imgs.entries[0].nx();
const int ny = imgs.entries[0].ny();
const int n = nx * ny;
for (int b = 0; b < n_batch_cur; b++) {
LOG_DBG("%s: copying image %d/%d to input buffer (nx=%d, ny=%d)\n", __func__, b+1, n_batch_cur, nx, ny);
const auto & buf = imgs.entries[b]->get_ro_buf();
const auto & buf = imgs.entries[b].get_ro_buf();
float * batch_entry = inp_raw.data() + b * (3*n);
for (int y = 0; y < ny; y++) {
for (int x = 0; x < nx; x++) {
@@ -3616,9 +3545,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
GGML_ASSERT(imgs.entries.size() == 1);
const auto & mel_inp = imgs.entries[0];
const auto & buf = mel_inp->get_ro_buf();
const int n_step = mel_inp->nx();
const int n_mel = mel_inp->ny();
const auto & buf = mel_inp.get_ro_buf();
const int n_step = mel_inp.nx();
const int n_mel = mel_inp.ny();
GGML_ASSERT((size_t)n_step * n_mel == buf.size());
set_input_f32("inp_raw", buf);
@@ -4232,7 +4161,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
GGML_ASSERT(imgs.entries.size() == 1);
const auto & img0 = imgs.entries.front();
// Compute n_pos matching SSCP output: two stride-2 convs
int n_pos = img0->nx();
int n_pos = img0.nx();
for (int i = 0; i < 2; i++) { n_pos = (n_pos - 1) / 2 + 1; }
// Chunked local attention: blocked causal mask and RPE
@@ -4280,7 +4209,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_LFM2A:
{
GGML_ASSERT(imgs.entries.size() == 1);
const auto n_frames = clip_n_output_tokens(ctx, imgs.entries.front().get());
const auto n_frames = clip_n_output_tokens(ctx, &imgs.entries.front());
auto d_model = 512;
auto seq_len = n_frames * 2 - 1;
@@ -4338,7 +4267,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// reshapes as ggml_get_rows gathers. The names are set
// by g4v_gather() in models/granite4-vision.cpp.
const int patch_size = model.hparams.patch_size;
const int image_side = imgs.entries.front()->nx() / patch_size;
const int image_side = imgs.entries.front().nx() / patch_size;
const int window_side = hparams.downsample_window_side;
const int query_side = hparams.downsample_query_side;
const int n = image_side / window_side;
@@ -4432,7 +4361,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// sanity check (assuming that all images in batch have the same number of tokens, so we only check the first one)
const int n_tokens_out = embeddings->ne[1];
const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
const int expected_n_tokens_out = clip_n_output_tokens(ctx, &imgs.entries[0]);
if (n_tokens_out != expected_n_tokens_out) {
LOG_ERR("%s: expected output %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
GGML_ABORT("Invalid number of output tokens");
+5 -26
View File
@@ -29,7 +29,6 @@ struct clip_image_size {
};
struct clip_image_f32;
struct clip_image_u8_batch;
struct clip_image_f32_batch;
enum clip_modality {
@@ -63,41 +62,21 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
void clip_free(struct clip_ctx * ctx);
int32_t clip_get_image_size (const struct clip_ctx * ctx);
int32_t clip_get_patch_size (const struct clip_ctx * ctx);
int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
// TODO: should be enum, not string
const char * clip_patch_merge_type(const struct clip_ctx * ctx);
int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
int clip_n_output_tokens(const clip_ctx * ctx, const clip_image_f32 * img);
// for M-RoPE, this will be the number of token positions in X and Y directions
// for other models, X will be the total number of tokens and Y will be 1
int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
int clip_n_output_tokens_x(const clip_ctx * ctx, const clip_image_f32 * img);
int clip_n_output_tokens_y(const clip_ctx * ctx, const clip_image_f32 * img);
// this should be equal to the embedding dimension of the text model
int clip_n_mmproj_embd(const struct clip_ctx * ctx);
struct clip_image_size * clip_image_size_init(void);
struct clip_image_u8 * clip_image_u8_init (void);
struct clip_image_f32 * clip_image_f32_init(void);
struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
void clip_image_size_free (struct clip_image_size * img_size);
void clip_image_u8_free (struct clip_image_u8 * img);
void clip_image_f32_free(struct clip_image_f32 * img);
void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
// use for accessing underlay data of clip_image_f32_batch
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, std::vector<float> & out_vec);
// TODO: remove clip_image_encode() and always use batched version
bool clip_image_encode (struct clip_ctx * ctx, int n_threads, const clip_image_f32 * img, std::vector<float> & out_vec);
bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, std::vector<float> & out_batch_embd);
bool clip_is_llava(const struct clip_ctx * ctx);
+102 -123
View File
@@ -4,17 +4,33 @@
#include <cmath>
#include <vector>
//
// base implementation
//
void mtmd_image_preprocessor::img_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
dst.from_u8(src);
dst.normalize(mean, std);
void mtmd_image_preproc_out::append(const clip_hparams & hparams, const clip_image_u8 & img, bool normalized) {
clip_image_f32 dst;
dst.from_u8(img);
if (normalized) {
dst.normalize(hparams.image_mean, hparams.image_std);
}
entries.push_back(std::move(dst));
}
void mtmd_image_preprocessor::img_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst) {
dst.from_u8(src);
void mtmd_image_preproc_out::append(const clip_hparams & hparams, const std::vector<clip_image_u8> & imgs, bool normalized) {
for (const auto & img : imgs) {
append(hparams, img, normalized);
}
}
void mtmd_image_preproc_out::append(const clip_hparams & hparams, clip_image_f32 & img, bool normalized) {
if (normalized) {
img.normalize(hparams.image_mean, hparams.image_std);
}
entries.push_back(std::move(img));
}
void mtmd_image_preproc_out::append_overview(const clip_hparams & hparams, const clip_image_u8 & img, bool normalized) {
overview.from_u8(img);
if (normalized) {
overview.normalize(hparams.image_mean, hparams.image_std);
}
}
// set of tools to manipulate images
@@ -595,21 +611,18 @@ private:
// mtmd_image_preprocessor_llava_uhd
//
bool mtmd_image_preprocessor_llava_uhd::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_llava_uhd::preprocess(const clip_image_u8 & img) {
const clip_image_size original_size = img.get_size();
auto const inst = get_slice_instructions(original_size);
std::vector<clip_image_u8_ptr> imgs = slice_image(img, inst);
for (size_t i = 0; i < imgs.size(); ++i) {
// clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
clip_image_f32_ptr res(clip_image_f32_init());
img_u8_to_f32(*imgs[i], *res, hparams.image_mean, hparams.image_std);
output.entries.push_back(std::move(res));
}
auto sliced = slice_image(img, inst);
mtmd_image_preproc_out output;
output.append_overview(hparams, sliced.overview, true);
output.append(hparams, sliced.slices, true);
output.grid_x = inst.grid_size.width;
output.grid_y = inst.grid_size.height;
return true;
return output;
}
mtmd_image_preprocessor_llava_uhd::slice_instructions mtmd_image_preprocessor_llava_uhd::get_slice_instructions(const clip_image_size & original_size) {
@@ -717,28 +730,21 @@ mtmd_image_preprocessor_llava_uhd::slice_instructions mtmd_image_preprocessor_ll
return res;
}
std::vector<clip_image_u8_ptr> mtmd_image_preprocessor_llava_uhd::slice_image(const clip_image_u8 & img, const mtmd_image_preprocessor_llava_uhd::slice_instructions & inst, bool overview_first) {
std::vector<clip_image_u8_ptr> output;
mtmd_image_preprocessor_llava_uhd::slice_output mtmd_image_preprocessor_llava_uhd::slice_image(const clip_image_u8 & img, const mtmd_image_preprocessor_llava_uhd::slice_instructions & inst) {
slice_output output;
// resize to overview size
clip_image_u8_ptr resized_img(clip_image_u8_init());
img_tool::resize(img, *resized_img, inst.overview_size, hparams.image_resize_algo_ov,
img_tool::resize(img, output.overview, inst.overview_size, hparams.image_resize_algo_ov,
hparams.image_pad_ov, hparams.image_pad_color_ov);
if (overview_first) {
output.push_back(std::move(resized_img));
}
if (inst.slices.empty()) {
// no slices, just return the resized image
if (!overview_first) {
output.push_back(std::move(resized_img));
}
// no slices, just return the overview image
return output;
}
// resize to refined size
clip_image_u8_ptr refined_img(clip_image_u8_init());
img_tool::resize(img, *refined_img, inst.refined_size, hparams.image_resize_algo_rf,
clip_image_u8 refined_img;
img_tool::resize(img, refined_img, inst.refined_size, hparams.image_resize_algo_rf,
hparams.image_pad_rf, hparams.image_pad_color_rf);
// create slices
@@ -748,13 +754,9 @@ std::vector<clip_image_u8_ptr> mtmd_image_preprocessor_llava_uhd::slice_image(co
int w = slice.size.width;
int h = slice.size.height;
clip_image_u8_ptr img_slice(clip_image_u8_init());
img_tool::crop(*refined_img, *img_slice, x, y, w, h);
output.push_back(std::move(img_slice));
}
if (!overview_first) {
output.push_back(std::move(resized_img));
clip_image_u8 img_slice;
img_tool::crop(refined_img, img_slice, x, y, w, h);
output.slices.push_back(std::move(img_slice));
}
return output;
@@ -871,24 +873,23 @@ clip_image_size mtmd_image_preprocessor_llava_uhd::get_best_grid(const int max_s
// mtmd_image_preprocessor_fixed_size
//
bool mtmd_image_preprocessor_fixed_size::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_fixed_size::preprocess(const clip_image_u8 & img) {
clip_image_u8 resized_image;
int sz = hparams.image_size;
img_tool::resize(img, resized_image, {sz, sz},
hparams.image_resize_algo,
hparams.image_resize_pad,
hparams.image_pad_color);
clip_image_f32_ptr img_f32(clip_image_f32_init());
img_u8_to_f32(resized_image, *img_f32, hparams.image_mean, hparams.image_std);
output.entries.push_back(std::move(img_f32));
return true;
mtmd_image_preproc_out output;
output.append(hparams, resized_image, true);
return output;
}
//
// mtmd_image_preprocessor_dyn_size
//
bool mtmd_image_preprocessor_dyn_size::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_dyn_size::preprocess(const clip_image_u8 & img) {
GGML_ASSERT(hparams.image_min_pixels > 0 && hparams.image_max_pixels > 0);
clip_image_u8 resized_image;
const clip_image_size original_size = img.get_size();
@@ -903,17 +904,16 @@ bool mtmd_image_preprocessor_dyn_size::preprocess(const clip_image_u8 & img, cli
hparams.image_resize_algo,
hparams.image_resize_pad,
hparams.image_pad_color);
clip_image_f32_ptr img_f32(clip_image_f32_init());
img_u8_to_f32(resized_image, *img_f32, hparams.image_mean, hparams.image_std);
output.entries.push_back(std::move(img_f32));
return true;
mtmd_image_preproc_out output;
output.append(hparams, resized_image, true);
return output;
}
//
// mtmd_image_preprocessor_longest_edge
//
bool mtmd_image_preprocessor_longest_edge::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_longest_edge::preprocess(const clip_image_u8 & img) {
GGML_ASSERT(hparams.image_longest_edge > 0);
clip_image_u8 resized_image;
const clip_image_size original_size = img.get_size();
@@ -927,10 +927,9 @@ bool mtmd_image_preprocessor_longest_edge::preprocess(const clip_image_u8 & img,
hparams.image_resize_algo,
hparams.image_resize_pad,
hparams.image_pad_color);
clip_image_f32_ptr img_f32(clip_image_f32_init());
img_u8_to_f32(resized_image, *img_f32, hparams.image_mean, hparams.image_std);
output.entries.push_back(std::move(img_f32));
return true;
mtmd_image_preproc_out output;
output.append(hparams, resized_image, true);
return output;
}
//
@@ -1040,7 +1039,7 @@ clip_image_size mtmd_image_preprocessor_lfm2::get_grid_layout(int height, int wi
// mtmd_image_preprocessor_idefics3
//
bool mtmd_image_preprocessor_idefics3::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_idefics3::preprocess(const clip_image_u8 & img) {
// The refined size has two steps:
// 1. Resize w/ aspect-ratio preserving such that the longer side is
// the preprocessor longest size
@@ -1075,46 +1074,40 @@ bool mtmd_image_preprocessor_idefics3::preprocess(const clip_image_u8 & img, cli
});
}
}
auto imgs = slice_image(img, instructions);
// cast and normalize to f32
for (size_t i = 0; i < imgs.size(); ++i) {
// clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
clip_image_f32_ptr res(clip_image_f32_init());
img_u8_to_f32(*imgs[i], *res, hparams.image_mean, hparams.image_std);
output.entries.push_back(std::move(res));
}
auto sliced = slice_image(img, instructions);
mtmd_image_preproc_out output;
output.append_overview(hparams, sliced.overview, true);
output.append(hparams, sliced.slices, true);
output.grid_x = instructions.grid_size.width;
output.grid_y = instructions.grid_size.height;
return true;
return output;
}
//
// mtmd_image_preprocessor_internvl
//
bool mtmd_image_preprocessor_internvl::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_internvl::preprocess(const clip_image_u8 & img) {
GGML_ASSERT(!hparams.image_res_candidates.empty());
const clip_image_size original_size = img.get_size();
auto const inst = get_slice_instructions(original_size);
std::vector<clip_image_u8_ptr> imgs = slice_image(img, inst, false);
auto sliced = slice_image(img, inst);
for (size_t i = 0; i < imgs.size(); ++i) {
clip_image_f32_ptr res(clip_image_f32_init());
img_u8_to_f32(*imgs[i], *res, hparams.image_mean, hparams.image_std);
output.entries.push_back(std::move(res));
}
mtmd_image_preproc_out output;
// InternVL: slices first, then overview
output.append(hparams, sliced.slices, true);
output.append_overview(hparams, sliced.overview, true);
output.grid_x = inst.grid_size.width;
output.grid_y = inst.grid_size.height;
return true;
return output;
}
//
// mtmd_image_preprocessor_deepseekocr
//
bool mtmd_image_preprocessor_deepseekocr::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_deepseekocr::preprocess(const clip_image_u8 & img) {
static constexpr int native_resolutions[] = { 1024 /* base */, 1280 /* large */ };
// TODO: support 512 (tiny) and 640 (small) once we have eval data for them
@@ -1137,14 +1130,12 @@ bool mtmd_image_preprocessor_deepseekocr::preprocess(const clip_image_u8 & img,
clip_image_u8 padded;
img_tool::resize(img, padded, {image_size, image_size}, RESIZE_ALGO_BICUBIC_PILLOW,
PAD_NEAREST, hparams.image_pad_color);
clip_image_f32_ptr res(clip_image_f32_init());
img_u8_to_f32(padded, *res, hparams.image_mean, hparams.image_std);
output.entries.push_back(std::move(res));
output.grid_x = 1;
output.grid_y = 1;
return true;
mtmd_image_preproc_out output;
output.append_overview(hparams, padded, true);
output.grid_x = 0;
output.grid_y = 0;
// TODO @ngxson : support slicing for DeepSeek-OCR, to do in another PR
return output;
}
//
@@ -1207,10 +1198,11 @@ clip_image_size mtmd_image_preprocessor_deepseekocr2::find_closest_aspect_ratio(
return best_ratio;
}
bool mtmd_image_preprocessor_deepseekocr2::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_deepseekocr2::preprocess(const clip_image_u8 & img) {
// emit 768x768 local tiles when the image is larger than a tile in either
// dimension, then always a 1024x1024 global view. order: [tiles..., global].
mtmd_image_preproc_out output;
const auto img_size = img.get_size();
if (img_size.width > tile_size || img_size.height > tile_size) {
const float aspect_ratio = static_cast<float>(img_size.width) / img_size.height;
@@ -1226,9 +1218,7 @@ bool mtmd_image_preprocessor_deepseekocr2::preprocess(const clip_image_u8 & img,
for (int col = 0; col < grid.width; col++) {
clip_image_u8 tile;
img_tool::crop(refined, tile, col * tile_size, row * tile_size, tile_size, tile_size);
clip_image_f32_ptr res(clip_image_f32_init());
img_u8_to_f32(tile, *res, hparams.image_mean, hparams.image_std);
output.entries.push_back(std::move(res));
output.append(hparams, tile, true);
}
}
}
@@ -1237,14 +1227,9 @@ bool mtmd_image_preprocessor_deepseekocr2::preprocess(const clip_image_u8 & img,
clip_image_u8 padded;
img_tool::resize(img, padded, { base_size, base_size }, RESIZE_ALGO_BICUBIC_PILLOW,
PAD_NEAREST, hparams.image_pad_color);
clip_image_f32_ptr global(clip_image_f32_init());
img_u8_to_f32(padded, *global, hparams.image_mean, hparams.image_std);
global->add_viewsep = true;
output.entries.push_back(std::move(global));
output.grid_x = 1;
output.grid_y = 1;
return true;
output.append_overview(hparams, padded, true);
output.overview.add_viewsep = true;
return output;
}
//
@@ -1260,7 +1245,8 @@ void mtmd_image_preprocessor_step3vl::img_u8_resize_bilinear_to_f32(
const float std[3]) {
const auto src_size = src.get_size();
if (src_size.width == target_width && src_size.height == target_height) {
img_u8_to_f32(src, dst, mean, std);
dst.from_u8(src);
dst.normalize(mean, std);
return;
}
@@ -1455,24 +1441,24 @@ mtmd_image_preprocessor_step3vl::slice_instructions mtmd_image_preprocessor_step
return instructions;
}
bool mtmd_image_preprocessor_step3vl::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_step3vl::preprocess(const clip_image_u8 & img) {
clip_image_u8 prepared = prepare_image(img, hparams);
const auto instructions = build_slice_instructions(hparams, prepared.get_size());
clip_image_f32_ptr overview_f32(clip_image_f32_init());
mtmd_image_preproc_out output;
// overview (normalized f32, already includes mean/std)
img_u8_resize_bilinear_to_f32(
prepared,
*overview_f32,
output.overview,
hparams.image_size,
hparams.image_size,
hparams.image_mean,
hparams.image_std);
output.entries.push_back(std::move(overview_f32));
if (instructions.slices.empty()) {
output.grid_x = 0;
output.grid_y = 0;
return true;
return output;
}
clip_image_u8 img_for_crop = prepared;
@@ -1488,28 +1474,28 @@ bool mtmd_image_preprocessor_step3vl::preprocess(const clip_image_u8 & img, clip
// If the requested patch extends past the source image, pad the out-of-bounds area with black.
clip_image_u8 patch = crop_with_black_padding(img_for_crop, slice.x, slice.y, slice.size.width, slice.size.height);
clip_image_f32_ptr patch_f32(clip_image_f32_init());
clip_image_f32 patch_f32;
img_u8_resize_bilinear_to_f32(
patch,
*patch_f32,
patch_f32,
crop_size,
crop_size,
hparams.image_mean,
hparams.image_std);
output.entries.push_back(std::move(patch_f32));
output.append(hparams, patch_f32, false);
}
output.grid_x = instructions.grid_size.width;
output.grid_y = instructions.grid_size.height;
return true;
return output;
}
//
// mtmd_image_preprocessor_youtuvl
//
bool mtmd_image_preprocessor_youtuvl::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
mtmd_image_preproc_out mtmd_image_preprocessor_youtuvl::preprocess(const clip_image_u8 & img) {
const int patch_size = hparams.patch_size; // typically 16
const int merge_size = hparams.n_merge; // typically 2
const int align_size = patch_size * merge_size; // 32
@@ -1553,29 +1539,22 @@ bool mtmd_image_preprocessor_youtuvl::preprocess(const clip_image_u8 & img, clip
clip_image_u8 resized;
img_tool::resize(img, resized, new_size, hparams.image_resize_algo, hparams.image_resize_pad);
// Normalize to float32
clip_image_f32_ptr img_f32(clip_image_f32_init());
img_u8_to_f32(resized, *img_f32, hparams.image_mean, hparams.image_std);
// Add to results
output.entries.push_back(std::move(img_f32));
return true;
mtmd_image_preproc_out output;
output.append(hparams, resized, true);
return output;
}
bool mtmd_image_preprocessor_granite::preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) {
// call super class preprocessor
bool ok = mtmd_image_preprocessor_llava_uhd::preprocess(img, output);
if (!ok) {
return false;
}
if (output.entries.size() == 1) {
mtmd_image_preproc_out mtmd_image_preprocessor_granite::preprocess(const clip_image_u8 & img) {
auto output = mtmd_image_preprocessor_llava_uhd::preprocess(img);
if (output.entries.size() == 0) {
// Single-tile (overview only): append one newline row.
output.entries[0]->add_newline = true;
output.overview.add_newline = true;
} else {
// Multi-tile: overview gets no newline, grid tiles get one.
output.entries[0]->add_newline = false;
for (size_t i = 1; i < output.entries.size(); ++i) {
output.entries[i]->add_newline = true;
output.overview.add_newline = false;
for (size_t i = 0; i < output.entries.size(); ++i) {
output.entries[i].add_newline = true;
}
}
return true;
return output;
}
+37 -16
View File
@@ -8,6 +8,24 @@
#define MTMD_INTERNAL_HEADER
struct mtmd_image_preproc_out {
std::vector<clip_image_f32> entries;
// grid size is required for llava-uhd style models
clip_image_f32 overview; // overview image (downscaled image)
int grid_x = 0;
int grid_y = 0;
void append(const clip_hparams & hparams, const clip_image_u8 & img, bool normalized = true);
void append(const clip_hparams & hparams, const std::vector<clip_image_u8> & imgs, bool normalized = true);
void append(const clip_hparams & hparams, clip_image_f32 & img, bool normalized = true);
void append_overview(const clip_hparams & hparams, const clip_image_u8 & img, bool normalized = true);
bool has_overview() const {
return overview.nx() > 0 || overview.ny() > 0;
}
};
// base class, models must inherit from this class
struct mtmd_image_preprocessor {
const clip_hparams & hparams;
@@ -15,10 +33,7 @@ struct mtmd_image_preprocessor {
mtmd_image_preprocessor(const clip_ctx * ctx): hparams(*clip_get_hparams(ctx)) {}
virtual ~mtmd_image_preprocessor() = default;
virtual bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) = 0;
void img_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]);
void img_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst);
virtual mtmd_image_preproc_out preprocess(const clip_image_u8 & img) = 0;
};
/**
@@ -39,10 +54,12 @@ struct mtmd_image_preprocessor {
* [overview] --> [slice 1] --> [slice 2]
* | |
* +--> [slice 3] --> [slice 4]
*
* NOTE: for the ordering of overview, set "ov_img_first" on the mtmd_context
*/
struct mtmd_image_preprocessor_llava_uhd : mtmd_image_preprocessor {
mtmd_image_preprocessor_llava_uhd(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
struct slice_coordinates {
int x;
@@ -60,7 +77,11 @@ struct mtmd_image_preprocessor_llava_uhd : mtmd_image_preprocessor {
// LFM2 override this function to implement its custom slicing logic
virtual slice_instructions get_slice_instructions(const clip_image_size & original_size);
std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 & img, const slice_instructions & inst, bool overview_first = true);
struct slice_output {
clip_image_u8 overview;
std::vector<clip_image_u8> slices;
};
slice_output slice_image(const clip_image_u8 & img, const slice_instructions & inst);
private:
clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false);
@@ -91,7 +112,7 @@ private:
// downscale or upscale the input image to fixed size
struct mtmd_image_preprocessor_fixed_size : mtmd_image_preprocessor {
mtmd_image_preprocessor_fixed_size(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
};
// resize image to multiple of patch_size*n_merge, while preserving aspect ratio
@@ -99,13 +120,13 @@ struct mtmd_image_preprocessor_fixed_size : mtmd_image_preprocessor {
// this is used by models with native support for dynamic image size, for example: Qwen-VL, Pixtral, Kimi-VL, etc
struct mtmd_image_preprocessor_dyn_size : mtmd_image_preprocessor {
mtmd_image_preprocessor_dyn_size(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
};
// similar to mtmd_image_preprocessor_dyn_size, but resize the image to have longest edge equal to hparams.image_longest_edge, while preserving aspect ratio
struct mtmd_image_preprocessor_longest_edge : mtmd_image_preprocessor {
mtmd_image_preprocessor_longest_edge(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
};
// custom llava-uhd slicing logic for LFM2
@@ -131,17 +152,17 @@ private:
struct mtmd_image_preprocessor_idefics3 : mtmd_image_preprocessor_llava_uhd {
mtmd_image_preprocessor_idefics3(const clip_ctx * ctx) : mtmd_image_preprocessor_llava_uhd(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
};
struct mtmd_image_preprocessor_internvl : mtmd_image_preprocessor_llava_uhd {
mtmd_image_preprocessor_internvl(const clip_ctx * ctx) : mtmd_image_preprocessor_llava_uhd(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
};
struct mtmd_image_preprocessor_deepseekocr : mtmd_image_preprocessor {
mtmd_image_preprocessor_deepseekocr(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
};
// DeepSeek-OCR-2: a 1024x1024 global view, plus InternVL-style 768x768 local
@@ -153,7 +174,7 @@ struct mtmd_image_preprocessor_deepseekocr2 : mtmd_image_preprocessor {
static constexpr int max_tiles = 6;
mtmd_image_preprocessor_deepseekocr2(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
private:
static std::vector<clip_image_size> get_target_ratios();
@@ -168,7 +189,7 @@ private:
// ref: https://huggingface.co/stepfun-ai/Step3-VL-10B/blob/main/processing_step3.py
struct mtmd_image_preprocessor_step3vl : mtmd_image_preprocessor_llava_uhd {
mtmd_image_preprocessor_step3vl(const clip_ctx * ctx) : mtmd_image_preprocessor_llava_uhd(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
static slice_instructions build_slice_instructions(const clip_hparams & params, const clip_image_size & prepared_size);
private:
@@ -195,11 +216,11 @@ private:
struct mtmd_image_preprocessor_youtuvl : mtmd_image_preprocessor {
mtmd_image_preprocessor_youtuvl(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
};
// similar to llava_uhd, but has add_newline
struct mtmd_image_preprocessor_granite : mtmd_image_preprocessor_llava_uhd {
mtmd_image_preprocessor_granite(const clip_ctx * ctx) : mtmd_image_preprocessor_llava_uhd(ctx) {}
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
mtmd_image_preproc_out preprocess(const clip_image_u8 & img) override;
};
+105 -86
View File
@@ -114,7 +114,7 @@ struct mtmd_image_tokens {
// true if one of entries in batch_f32 is a placeholder
bool is_placeholder() const {
for (const auto & entry : batch_f32.entries) {
if (entry->is_placeholder()) {
if (entry.is_placeholder()) {
return true;
}
}
@@ -147,7 +147,7 @@ struct mtmd_audio_tokens {
// true if one of entries in batch_f32 is a placeholder
bool is_placeholder() const {
for (const auto & entry : batch_f32.entries) {
if (entry->is_placeholder()) {
if (entry.is_placeholder()) {
return true;
}
}
@@ -516,6 +516,7 @@ struct mtmd_context {
LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
" https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
image_preproc = std::make_unique<mtmd_image_preprocessor_llava_uhd>(ctx_v);
ov_img_first = false;
} break;
case PROJECTOR_TYPE_STEP3VL:
{
@@ -539,6 +540,7 @@ struct mtmd_context {
img_beg = "<img>";
img_end = "</img>";
image_preproc = std::make_unique<mtmd_image_preprocessor_internvl>(ctx_v);
ov_img_first = false;
} break;
case PROJECTOR_TYPE_KIMIVL:
{
@@ -615,11 +617,13 @@ struct mtmd_context {
{
img_end = "\n"; // prevent empty batch on llama-server
image_preproc = std::make_unique<mtmd_image_preprocessor_deepseekocr>(ctx_v);
ov_img_first = false;
} break;
case PROJECTOR_TYPE_DEEPSEEKOCR2:
{
img_end = "\n"; // prevent empty batch on llama-server
image_preproc = std::make_unique<mtmd_image_preprocessor_deepseekocr2>(ctx_v);
ov_img_first = false;
} break;
case PROJECTOR_TYPE_HUNYUANVL:
{
@@ -640,6 +644,7 @@ struct mtmd_context {
img_beg = "<image>";
img_end = "";
image_preproc = std::make_unique<mtmd_image_preprocessor_granite>(ctx_v);
ov_img_first = true;
} break;
default:
throw std::runtime_error(string_format("%s: unexpected vision projector type %d\n", __func__, proj));
@@ -1050,7 +1055,7 @@ struct mtmd_tokenizer {
// TODO @ngxson : this is quite hacky because preprocessor only support batch with one single element, that need to be fixed in the future (e.g. by changing the preprocessor interface always take single input)
clip_image_f32_batch batch_f32;
mtmd_image_preproc_out preproc_out;
for (const auto * bmp : bitmaps) {
// sanity check
@@ -1063,44 +1068,54 @@ struct mtmd_tokenizer {
}
// convert mtmd_bitmap to clip_image_u8
clip_image_u8_ptr img_u8(clip_image_u8_init());
img_u8->set_size(
clip_image_u8 img_u8;
img_u8.set_size(
{(int)bmp->nx, (int)bmp->ny},
bmp->is_placeholder());
img_u8->cpy_buf(bmp->get_ro_buf());
img_u8.cpy_buf(bmp->get_ro_buf());
// preprocess image
clip_image_f32_batch tmp_batch;
bool ok = ctx->image_preproc->preprocess(*img_u8, tmp_batch);
if (!ok) {
LOG_ERR("Unable to preprocess image\n");
return 2;
}
mtmd_image_preproc_out tmp_preproc_out = ctx->image_preproc->preprocess(img_u8);
// move entries and grid dimensions to the "global" batch_f32
for (auto & entry : tmp_batch.entries) {
batch_f32.entries.emplace_back(std::move(entry));
// move entries and grid dimensions to the "global" preproc_out
for (auto & entry : tmp_preproc_out.entries) {
preproc_out.entries.emplace_back(std::move(entry));
}
// for llava-uhd style, we need to handle grid too
// we don't care about overwriting these values for now because llama-uhd doesn't support batching anyway
batch_f32.grid_x = tmp_batch.grid_x;
batch_f32.grid_y = tmp_batch.grid_y;
// we don't care about overwriting these values for now because the case where bitmaps.size() > 1 is only for frame merging (qwen-vl), not supported by llava-uhd
if ((tmp_preproc_out.grid_x > 0 && tmp_preproc_out.grid_y > 0)
|| tmp_preproc_out.has_overview()) {
GGML_ASSERT(bitmaps.size() == 1);
preproc_out.grid_x = tmp_preproc_out.grid_x;
preproc_out.grid_y = tmp_preproc_out.grid_y;
preproc_out.overview = std::move(tmp_preproc_out.overview);
}
}
LOG_DBG("%s: preproc_out has %zu entries, grid_x = %d, grid_y = %d, has_overview = %d\n",
__func__, preproc_out.entries.size(), preproc_out.grid_x, preproc_out.grid_y,
preproc_out.has_overview() ? 1 : 0);
// handle llava-uhd style preprocessing
const bool has_tiling_grid = batch_f32.grid_x > 0 && batch_f32.grid_y > 0;
// (output either a grid, or overview-only)
const bool has_tiling_grid = (preproc_out.grid_x > 0 && preproc_out.grid_y > 0)
|| preproc_out.has_overview();
if (has_tiling_grid) {
// [QWEN_VIDEO] we do not support "frame merging" for llama-uhd style, so no batching for now
GGML_ASSERT(bitmaps.size() == 1);
const int n_col = batch_f32.grid_x;
const int n_row = batch_f32.grid_y;
const int n_col = preproc_out.grid_x;
const int n_row = preproc_out.grid_y;
// split batch into chunks of single images
// NOTE: batch_f32 will be invalidated after this call
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[0]->id);
auto chunks = split_batch_to_chunk(std::move(preproc_out), bitmaps[0]->id);
GGML_ASSERT(chunks.size() > 0);
// NOTE: preproc_out is invalidated after this point, do not use it anymore
// split_batch_to_chunk must always put the overview image first
auto ov_chunk = std::move(chunks.front());
chunks.erase(chunks.begin());
@@ -1127,7 +1142,16 @@ struct mtmd_tokenizer {
std::snprintf(buf.get(), sz, ctx->sli_img_start_tmpl.c_str(), y+1, x+1);
add_text(std::string(buf.get(), buf.get() + sz - 1), true);
}
cur.entries.emplace_back(std::move(chunks[y * n_col + x]));
auto & curr_chunk = chunks[y * n_col + x];
auto & curr_batch = curr_chunk.tokens_image->batch_f32;
if (curr_batch.entries.size() != 1) {
throw std::runtime_error(string_format("%s: expect 1 image in batch_f32", __func__));
}
LOG_DBG("%s: adding slice image at row %d col %d\n", __func__, y, x);
cur.entries.emplace_back(std::move(curr_chunk));
add_text(ctx->tok_sli_img_end);
if (!is_last_in_row) {
add_text(ctx->tok_sli_img_mid);
@@ -1149,9 +1173,14 @@ struct mtmd_tokenizer {
} else {
if (preproc_out.entries.size() == 0) {
LOG_ERR("%s: no image tokens produced by preprocessor (ref: https://github.com/ggml-org/llama.cpp/pull/24769)\n", __func__);
return 2;
}
size_t n_tokens = 0;
for (const auto & e : batch_f32.entries) {
n_tokens += clip_n_output_tokens(ctx->ctx_v, e.get());
for (auto & e : preproc_out.entries) {
n_tokens += clip_n_output_tokens(ctx->ctx_v, &e);
if (clip_model_n_temporal_merge(ctx->ctx_v) == 2) {
// [QWEN_VIDEO] pair input is merged to the same embd, so only count as one image
break;
@@ -1165,8 +1194,8 @@ struct mtmd_tokenizer {
if (mtmd_decode_use_mrope(ctx)) {
// for Qwen2VL, we need this information for M-RoPE decoding positions
image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get());
image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get());
image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, &preproc_out.entries[0]);
image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, &preproc_out.entries[0]);
} else {
// other models, we only need the total number of tokens
image_tokens->nx = n_tokens;
@@ -1181,6 +1210,12 @@ struct mtmd_tokenizer {
image_tokens->image_idx = n_images_added;
GGML_ASSERT(n_tokens == (size_t)image_tokens->n_tokens());
}
clip_image_f32_batch batch_f32;
batch_f32.is_audio = false;
batch_f32.entries = std::move(preproc_out.entries);
// do NOT use preproc_out from this point on, it's moved
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmaps[0]->id; // optional
@@ -1260,13 +1295,13 @@ struct mtmd_tokenizer {
for (auto & mel_spec : mel_spec_chunks) {
const bool is_placeholder = mel_spec.data.empty();
clip_image_f32_ptr mel_f32(clip_image_f32_init());
mel_f32->set_size(
clip_image_f32 mel_f32;
mel_f32.set_size(
{mel_spec.n_len, mel_spec.n_mel},
is_placeholder, /* is_audio */ true);
mel_f32->cpy_buf(mel_spec.data);
mel_f32.cpy_buf(mel_spec.data);
size_t n_tokens = clip_n_output_tokens(ctx->ctx_a, mel_f32.get());
size_t n_tokens = clip_n_output_tokens(ctx->ctx_a, &mel_f32);
clip_image_f32_batch batch_f32;
batch_f32.is_audio = true;
@@ -1296,16 +1331,18 @@ struct mtmd_tokenizer {
return 0;
}
std::vector<mtmd_input_chunk> split_batch_to_chunk(clip_image_f32_batch && batch_f32, const std::string & id) {
std::vector<mtmd_input_chunk> split_batch_to_chunk(mtmd_image_preproc_out && preproc_out, const std::string & id) {
std::vector<mtmd_input_chunk> chunks;
for (auto & entry : batch_f32.entries) {
auto process_chunk = [&](clip_image_f32 && img) {
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
image_tokens->nx = clip_n_output_tokens(ctx->ctx_v, entry.get());
image_tokens->nx = clip_n_output_tokens(ctx->ctx_v, &img);
image_tokens->ny = 1;
image_tokens->batch_f32.entries.push_back(std::move(entry));
image_tokens->batch_f32.entries.push_back(std::move(img));
image_tokens->id = id;
GGML_ASSERT(image_tokens->nx > 0);
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{}, // text tokens
@@ -1313,6 +1350,21 @@ struct mtmd_tokenizer {
nullptr, // audio tokens
};
chunks.emplace_back(std::move(chunk));
};
// overview image first
auto & overview = preproc_out.overview;
if (overview.nx() == 0 || overview.ny() == 0) {
throw std::runtime_error(string_format("%s: invalid overview image for llava-uhd style preprocessing\n", __func__));
}
process_chunk(std::move(preproc_out.overview));
// then, process slices
for (auto & entry : preproc_out.entries) {
if (entry.nx() == 0 || entry.ny() == 0) {
throw std::runtime_error(string_format("%s: invalid image slice for llava-uhd style preprocessing\n", __func__));
}
process_chunk(std::move(entry));
}
return chunks;
@@ -1386,57 +1438,22 @@ static int32_t mtmd_encode_impl(mtmd_context * ctx, const mtmd_image_tokens * im
LOG_ERR("%s: this API does not support non-vision input, please use mtmd_encode_chunk instead\n", __func__);
return 1;
}
auto proj_type = clip_get_projector_type(ctx_clip);
int n_embd_out = ctx->n_embd_out();
auto n_tokens_out = image_tokens->n_tokens();
out_embd.resize((size_t)n_embd_out * n_tokens_out);
bool ok = false;
if (clip_is_llava(ctx_clip)
|| proj_type == PROJECTOR_TYPE_MINICPMV
|| proj_type == PROJECTOR_TYPE_GLM_EDGE
|| proj_type == PROJECTOR_TYPE_INTERNVL
|| proj_type == PROJECTOR_TYPE_DEEPSEEKOCR2
|| proj_type == PROJECTOR_TYPE_GRANITE4_VISION) {
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries;
// entries may have different token counts
// e.g., DeepSeek-OCR-2: 144 per tile views, 257 for the global view
size_t offset = 0;
for (size_t i = 0; i < entries.size(); i++) {
if (entries[i]->is_placeholder()) {
LOG_ERR("%s: image tokens batch entry %zu is placeholder\n", __func__, i);
return 1;
}
int n_tokens_per_image = clip_n_output_tokens(ctx_clip, entries[i].get());
std::vector<float> tmp_embd((size_t)n_tokens_per_image * n_embd_out);
bool ok_i = clip_image_encode(
ctx_clip,
ctx->n_threads,
entries[i].get(),
tmp_embd);
if (!ok_i) {
LOG_ERR("%s: failed to encode image %zu\n", __func__, i);
return 1;
}
ok = true;
std::copy(tmp_embd.begin(), tmp_embd.end(), out_embd.begin() + offset);
offset += static_cast<size_t>(n_embd_out) * n_tokens_per_image;
}
} else {
if (image_tokens->is_placeholder()) {
LOG_ERR("%s: image tokens batch is placeholder\n", __func__);
return 1;
}
ok = clip_image_batch_encode(
ctx_clip,
ctx->n_threads,
&image_tokens->batch_f32,
out_embd);
if (image_tokens->is_placeholder()) {
LOG_ERR("%s: image tokens batch is placeholder\n", __func__);
return 1;
}
bool ok = clip_image_batch_encode(
ctx_clip,
ctx->n_threads,
&image_tokens->batch_f32,
out_embd);
return ok ? 0 : 1;
}
@@ -2063,16 +2080,18 @@ void mtmd_debug_preprocess_image(mtmd_context * ctx, const std::vector<uint8_t>
clip_image_u8 img_u8;
img_u8.set_size({nx, ny}, false);
img_u8.cpy_buf(rgb_values);
clip_image_f32_batch batch_f32;
GGML_ASSERT(ctx->image_preproc != nullptr);
bool ok = ctx->image_preproc->preprocess(img_u8, batch_f32);
if (!ok) {
LOG_ERR("%s: failed to preprocess image\n", __func__);
return;
mtmd_image_preproc_out preproc_out = ctx->image_preproc->preprocess(img_u8);
clip_image_f32_batch batch_f32;
batch_f32.is_audio = false;
for (auto & entry : preproc_out.entries) {
batch_f32.entries.push_back(std::move(entry));
}
LOG_INF("%s: preprocessed image to batch_f32 with %d entries\n", __func__, (int)batch_f32.entries.size());
for (size_t i = 0; i < batch_f32.entries.size(); i++) {
LOG_INF("%s: entry %zu has nx=%d, ny=%d\n", __func__, i, batch_f32.entries[i]->nx(), batch_f32.entries[i]->ny());
LOG_INF("%s: entry %zu has nx=%d, ny=%d\n", __func__, i, batch_f32.entries[i].nx(), batch_f32.entries[i].ny());
// TODO: better way to dump entry content?
}
}
+2
View File
@@ -17,6 +17,8 @@ add_library(${TARGET} STATIC
server-context.h
server-tools.cpp
server-tools.h
server-schema.cpp
server-schema.h
)
if (BUILD_SHARED_LIBS)
+18 -12
View File
@@ -4,6 +4,7 @@
#include "server-http.h"
#include "server-task.h"
#include "server-queue.h"
#include "server-schema.h"
#include "build-info.h"
#include "common.h"
@@ -189,9 +190,10 @@ struct server_slot {
// stats
size_t n_sent_text = 0; // number of sent text character
int64_t t_print_last = 0;
int64_t t_start_process_prompt;
int64_t t_start_generation;
int64_t t_print_last = 0;
int32_t n_decoded_last = 0;
double t_prompt_processing = 0.0; // ms
double t_token_generation = 0.0; // ms
@@ -470,11 +472,13 @@ struct server_slot {
return;
}
const double n_gen_second = 1e3 / (t_token_generation) * (n_decoded);
const double n_gen_second_win = 1e6 / (t_now - t_print_last) * (n_decoded - n_decoded_last);
t_print_last = t_now;
n_decoded_last = n_decoded;
const double n_gen_second = 1e3 / t_token_generation * n_decoded;
SLT_INF(*this, "n_decoded = %6d, tg = %6.2f t/s\n", n_decoded, n_gen_second);
SLT_INF(*this, "n_decoded = %6d, tg = %6.2f t/s, tg_3s = %6.2f t/s\n", n_decoded, n_gen_second, n_gen_second_win);
}
void print_timings_pp() const {
@@ -3038,8 +3042,8 @@ private:
}
}
const int64_t t_current = ggml_time_us();
slot.t_prompt_processing = (t_current - slot.t_start_process_prompt) / 1e3;
const int64_t t_now = ggml_time_us();
slot.t_prompt_processing = (t_now - slot.t_start_process_prompt) / 1e3;
slot.print_timings_pp();
// truncate any tokens that are beyond n_past for this slot
@@ -3447,17 +3451,19 @@ private:
common_sampler_accept(slot.smpl.get(), id, true);
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
const int64_t t_current = ggml_time_us();
const int64_t t_now = ggml_time_us();
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = t_current;
slot.t_start_generation = t_now;
slot.t_print_last = t_now;
slot.n_decoded_last = 0;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
completion_token_output result;
result.tok = id;
@@ -3551,11 +3557,11 @@ private:
slot.spec_draft = std::move(accepted);
}
const int64_t t_current = ggml_time_us();
const int64_t t_now = ggml_time_us();
const auto ids = std::move(slot.spec_draft);
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
@@ -3820,7 +3826,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
task.id = rd.get_new_id();
task.tokens = std::move(inputs[i]);
task.params = server_task::params_from_json_cmpl(
task.params = server_schema::eval_llama_cmpl_schema(
ctx_server.vocab,
params,
meta->slot_n_ctx,
+43 -22
View File
@@ -54,7 +54,7 @@ extern char **environ;
struct server_subproc {
std::optional<subprocess_s> sproc; // empty while in DOWNLOADING state
std::atomic<bool> stop_download{false}; // flag to signal download cancellation
std::atomic<bool> stopped{false}; // set to cancel a download or signal child process exit
subprocess_s & get() {
GGML_ASSERT(sproc.has_value() && "subprocess not initialized");
@@ -64,6 +64,22 @@ struct server_subproc {
bool is_alive() {
return sproc.has_value() && subprocess_alive(&sproc.value());
}
void terminate() {
if (!sproc.has_value()) {
return;
}
#if defined(_WIN32)
if (sproc->hProcess == NULL) {
return;
}
#else
if (sproc->child <= 0) {
return;
}
#endif
subprocess_terminate(&sproc.value());
}
};
@@ -351,6 +367,12 @@ void server_models::load_models() {
source_map[name] = SERVER_MODEL_SOURCE_PRESET;
}
// overlay router's own CLI args on top of every model preset so that
// e.g. `llama-server --temp 0` is honoured by all child processes
for (auto & [name, preset] : final_presets) {
preset.merge(base_preset);
}
auto get_source = [&](const std::string & name) {
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
};
@@ -896,50 +918,49 @@ void server_models::load(const std::string & name) {
});
std::thread stopping_thread([&]() {
// thread to monitor stopping signal OR child crash
// thread to monitor explicit stop requests; child crash is signalled via child_proc->stopped
auto is_stopping = [this, &name]() {
return this->stopping_models.find(name) != this->stopping_models.end();
};
auto should_wake = [&]() {
return is_stopping() || !child_proc->is_alive();
};
{
std::unique_lock<std::mutex> lk(this->mutex);
this->cv_stop.wait(lk, should_wake);
this->cv_stop.wait(lk, [&]() {
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
});
}
// child may have already exited (e.g. crashed) — skip shutdown sequence
if (!child_proc->is_alive()) {
// child crashed or finished on its own — skip graceful shutdown sequence
if (child_proc->stopped.load(std::memory_order_acquire)) {
return;
}
SRV_INF("stopping model instance name=%s\n", name.c_str());
// send interrupt to child process
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
fflush(stdin_file);
// wait to stop gracefully or timeout
int64_t start_time = ggml_time_ms();
while (true) {
std::unique_lock<std::mutex> lk(this->mutex);
if (!is_stopping()) {
return; // already stopped
if (!is_stopping() || child_proc->stopped.load(std::memory_order_acquire)) {
return;
}
int64_t elapsed = ggml_time_ms() - start_time;
if (elapsed >= stop_timeout * 1000) {
// timeout, force kill
lk.unlock();
SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
subprocess_terminate(&child_proc->get());
child_proc->terminate();
return;
}
this->cv_stop.wait_for(lk, std::chrono::seconds(1));
this->cv_stop.wait_for(lk, std::chrono::seconds(1), [&]() {
return !is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
});
}
});
// we reach here when the child process exits
// we reach here when the child process exits (stdout EOF)
// note: we cannot join() prior to this point because it will close stdin_file
if (log_thread.joinable()) {
log_thread.join();
}
// stop the timeout monitoring thread
child_proc->stopped.store(true, std::memory_order_release);
{
std::lock_guard<std::mutex> lk(this->mutex);
stopping_models.erase(name);
@@ -965,7 +986,7 @@ void server_models::load(const std::string & name) {
// old process should have exited already, but just in case, we clean it up here
if (old_instance.subproc->is_alive()) {
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
subprocess_terminate(&old_instance.subproc->get()); // force kill
old_instance.subproc->terminate(); // force kill
}
if (old_instance.th.joinable()) {
old_instance.th.join();
@@ -1033,7 +1054,7 @@ void server_models::download(common_params_model && model, common_download_opts
dl->opts = opts; // copy
dl->should_stop = [sp = inst.subproc]() {
return sp->stop_download.load(std::memory_order_relaxed);
return sp->stopped.load(std::memory_order_relaxed);
};
dl->on_progress = [this, name](const common_download_progress & p) {
@@ -1063,7 +1084,7 @@ void server_models::unload(const std::string & name) {
if (it != mapping.end()) {
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
SRV_INF("cancelling download for model name=%s\n", name.c_str());
it->second.subproc->stop_download.store(true, std::memory_order_relaxed);
it->second.subproc->stopped.store(true, std::memory_order_relaxed);
// for convenience, we wait the status change here
wait(lk, name, [](const server_model_meta & new_meta) {
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
@@ -1074,7 +1095,7 @@ void server_models::unload(const std::string & name) {
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
// special case: if model is in loading state, unloading means force-killing it
SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str());
subprocess_terminate(&it->second.subproc->get());
it->second.subproc->terminate();
}
cv_stop.notify_all();
// status change will be handled by the managing thread
@@ -1091,7 +1112,7 @@ void server_models::unload_all() {
for (auto & [name, inst] : mapping) {
if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
SRV_INF("cancelling download for model name=%s\n", name.c_str());
inst.subproc->stop_download.store(true, std::memory_order_relaxed);
inst.subproc->stopped.store(true, std::memory_order_relaxed);
} else if (inst.meta.is_running()) {
SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name);
+635
View File
@@ -0,0 +1,635 @@
#include "server-schema.h"
#include "json-schema-to-grammar.h"
namespace server_schema {
//
// llama.cpp-specific completion schema
//
std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(const common_params & params_base, task_params & params) {
std::vector<std::unique_ptr<field>> fields;
auto add = [&](field * f) {
fields.emplace_back(f);
};
add((new field_bool("timings_per_token", params.timings_per_token))
->set_desc("Include prompt processing and text generation speed information in each response"));
add((new field_bool("stream", params.stream))
->set_desc("Allows receiving each predicted token in real-time instead of waiting for the completion to finish"));
add((new field_nested("stream_options"))
->add_subfield((new field_bool("include_usage", params.include_usage))
->set_desc("Whether to include usage information in the stream"))
->set_desc("Additional options for streaming responses"));
add((new field_bool("cache_prompt", params.cache_prompt))
->set_desc("Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests"));
add((new field_bool("return_tokens", params.return_tokens))
->set_desc("Return the raw generated token ids in the `tokens` field"));
add((new field_bool("return_progress", params.return_progress))
->set_desc("Include prompt processing progress events in stream mode"));
add((new field_num("n_predict", params.n_predict))
->set_hard_limits(-1, INT32_MAX)
->add_alias("max_completion_tokens")
->add_alias("max_tokens")
->set_desc("Set the maximum number of tokens to predict. When 0, no tokens will be generated but the prompt is evaluated into the cache"));
add((new field_num("n_indent", params.n_indent))
->set_hard_limits(0, INT32_MAX)
->set_desc("Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks"));
add((new field_num("n_keep", params.n_keep))
->set_hard_limits(-1, INT32_MAX)
->set_desc("Specify the number of tokens from the initial prompt to retain when context size is exceeded. Use -1 to retain all tokens from the prompt"));
add((new field_num("n_discard", params.n_discard))
->set_hard_limits(0, INT32_MAX)
->set_desc("Number of tokens after n_keep that may be discarded when shifting context (0 = half context)"));
add((new field_num("n_cmpl", params.n_cmpl))
->set_hard_limits(1, params_base.n_parallel)
->add_alias("n") // alias "n" as fallback (OpenAI completions API)
->set_desc("Number of completions to generate. If the input has multiple prompts, total outputs will be N prompts times n_cmpl"));
add((new field_num("n_cache_reuse", params.n_cache_reuse))
->set_hard_limits(0, INT32_MAX)
->set_desc("Min chunk size to attempt reusing from the cache via KV shifting. See --cache-reuse arg"));
// TODO: implement t_max_prompt_ms
// add((new field_num("t_max_prompt_ms", params.t_max_prompt_ms))
add((new field_num("t_max_predict_ms", params.t_max_predict_ms))
->set_hard_limits(-1, std::numeric_limits<int64_t>::max())
->set_desc("Set a time limit in milliseconds for the prediction phase. The timeout triggers if generation exceeds this time (measured since the first token) and a newline has been generated. Useful for FIM applications"));
add((new field_json("response_fields"))
->set_desc("A list of response fields to return. Missing fields are omitted without error. Fields with a slash are unnested (e.g. generation_settings/n_predict moves n_predict to the root)")
->set_handler([&](field_eval_context & ctx, const json & data) {
ctx.params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
}));
//
// Sampling params
//
add((new field_num("top_k", params.sampling.top_k))
->set_limits(0, INT32_MAX)
->set_desc("Limit the next token selection to the K most probable tokens (0 = disabled)"));
add((new field_num("top_p", params.sampling.top_p))
->set_limits(0.0f, 1.0f)
->set_desc("Limit the next token selection to a subset of tokens with cumulative probability above threshold P (1.0 = disabled)"));
add((new field_num("min_p", params.sampling.min_p))
->set_limits(0.0f, 1.0f)
->set_desc("The minimum probability for a token to be considered, relative to the probability of the most likely token (0 = disabled)"));
add((new field_num("top_n_sigma", params.sampling.top_n_sigma))
->set_desc("Keep tokens within n standard deviations of the top token logit (< 0 = disabled)"));
add((new field_num("xtc_probability", params.sampling.xtc_probability))
->set_limits(0.0f, 1.0f)
->set_desc("Set the chance for token removal via XTC sampler (0 = disabled)"));
add((new field_num("xtc_threshold", params.sampling.xtc_threshold))
->set_limits(0.0f, 1.0f)
->set_desc("Set a minimum probability threshold for tokens to be removed via XTC sampler (> 0.5 disables XTC)"));
add((new field_num("typical_p", params.sampling.typ_p))
// ->set_limits(0.0f, 1.0f) // what's the valid range?
->set_desc("Enable locally typical sampling with parameter p (1.0 = disabled)"));
add((new field_num("temperature", params.sampling.temp))
->set_limits(0.0f, std::numeric_limits<float>::infinity())
->set_desc("Adjust the randomness of the generated text (0 = greedy)"));
add((new field_num("dynatemp_range", params.sampling.dynatemp_range))
->set_desc("Dynamic temperature range. The final temperature will be in [temperature - range, temperature + range] (0 = disabled)"));
add((new field_num("dynatemp_exponent", params.sampling.dynatemp_exponent))
->set_desc("Dynamic temperature exponent, controls how entropy maps to temperature"));
add((new field_num("repeat_last_n", params.sampling.penalty_last_n))
->set_hard_limits(-1, INT32_MAX)
->set_desc("Last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size)"));
add((new field_num("repeat_penalty", params.sampling.penalty_repeat))
->set_desc("Control the repetition of token sequences in the generated text (1.0 = disabled)"));
add((new field_num("frequency_penalty", params.sampling.penalty_freq))
->set_desc("Repeat alpha frequency penalty (0 = disabled)"));
add((new field_num("presence_penalty", params.sampling.penalty_present))
->set_desc("Repeat alpha presence penalty (0 = disabled)"));
add((new field_num("dry_multiplier", params.sampling.dry_multiplier))
->set_desc("Set the DRY (Don't Repeat Yourself) repetition penalty multiplier (0 = disabled)"));
add((new field_num("dry_base", params.sampling.dry_base))
->set_desc("Set the DRY repetition penalty base value (must be >= 1.0, any values < 1.0 will be replaced with the default value)")
->set_handler([&](field_eval_context & ctx, const json & data) {
float v = data.at("dry_base").get<float>();
ctx.params.sampling.dry_base = (v < 1.0f) ? params_base.sampling.dry_base : v;
}));
add((new field_num("dry_allowed_length", params.sampling.dry_allowed_length))
->set_hard_limits(0, INT32_MAX)
->set_desc("Tokens that extend repetition beyond this length receive exponentially increasing penalty: multiplier * base ^ (sequence_length - allowed_length)"));
add((new field_num("dry_penalty_last_n", params.sampling.dry_penalty_last_n))
->set_hard_limits(-1, INT32_MAX)
->set_desc("How many tokens to scan for repetitions (0 = disabled, -1 = context size)"));
add((new field_num("mirostat", params.sampling.mirostat))
->set_limits(0, 2)
->set_desc("Enable Mirostat sampling, controlling perplexity during text generation (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"));
add((new field_num("mirostat_tau", params.sampling.mirostat_tau))
->set_desc("Set the Mirostat target entropy, parameter tau"));
add((new field_num("mirostat_eta", params.sampling.mirostat_eta))
->set_desc("Set the Mirostat learning rate, parameter eta"));
add((new field_num("adaptive_target", params.sampling.adaptive_target))
->set_limits(-std::numeric_limits<float>::max(), 1.0f)
->set_desc("Adaptive sampling target entropy (valid range 0.0 to 1.0; negative = disabled)"));
add((new field_num("adaptive_decay", params.sampling.adaptive_decay))
->set_hard_limits(0.0f, 0.99f)
->set_desc("EMA decay for adaptive sampling; history approximates 1/(1-decay) tokens"));
// seed is uint32_t; field_num uses int32_t so use a handler
add((new field_num("seed", params.sampling.seed))
->set_desc("Set the random number generator (RNG) seed (-1 = random)"));
add((new field_num("n_probs", params.sampling.n_probs))
->add_alias("logprobs") // use "logprobs" if "n_probs" wasn't provided
->set_desc("If greater than 0, output the probabilities of top N tokens for each generated token"));
add((new field_num("min_keep", params.sampling.min_keep))
->set_hard_limits(0, INT32_MAX)
->set_desc("If greater than 0, force samplers to return at least N possible tokens"));
add((new field_bool("backend_sampling", params.sampling.backend_sampling))
->set_desc("Use backend sampling instead of llama.cpp sampling"));
add((new field_bool("post_sampling_probs", params.post_sampling_probs))
->set_desc("Return probabilities of top n_probs tokens after applying the sampling chain"));
//
// Speculative decoding params
//
// TODO: to keep things simple, we disable speculative parameter adjustments for now
#if 0
// TODO: for now, be able to adjust only the draft-model based speculative parameters
add((new field_num("speculative.n_max", params.speculative.draft.n_max))
->set_hard_limits(0, INT32_MAX)
->set_desc("Maximum number of tokens to draft during speculative decoding"));
add((new field_num("speculative.n_min", params.speculative.draft.n_min))
->set_hard_limits(0, INT32_MAX)
->set_desc("Minimum number of draft tokens to use for speculative decoding");
add((new field_num("speculative.p_min", params.speculative.draft.p_min))
->set_hard_limits(0.0f, 1.0f)
->set_desc("Minimum speculative decoding probability for draft tokens (0 = greedy)"));
add((new field_str("speculative.type"))
->set_desc("Speculative decoding method (for debugging and research purposes)")
->set_handler([&](field_eval_context & ctx, const json & data) {
ctx.params.speculative.types = { common_speculative_type_from_name(data.at("speculative.type").get<std::string>()) };
}));
add((new field_num("speculative.ngram_size_n", params.speculative.ngram_simple.size_n))
->set_desc("Ngram size for lookup in ngram-based speculative decoding"));
add((new field_num("speculative.ngram_size_m", params.speculative.ngram_simple.size_m))
->set_desc("Mgram size for speculative tokens in ngram-based speculative decoding"));
add((new field_num("speculative.ngram_min_hits", params.speculative.ngram_simple.min_hits))
->set_desc("Minimum hits at ngram lookup for mgram to be proposed"));
#endif
add((new field_json("lora"))
->set_desc("A list of LoRA adapters to apply to this request. Each entry must have `id` and `scale` fields. Adapters not listed default to scale 0.0")
->set_handler([&](field_eval_context & ctx, const json & data) {
const auto & lora = data.at("lora");
if (!lora.is_array()) {
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
}
ctx.params.lora = parse_lora_request(lora);
}));
// sequence breakers for DRY
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
add((new field_json("dry_sequence_breakers"))
->set_desc("Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted")
->set_handler([&](field_eval_context & ctx, const json & data) {
ctx.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
if (ctx.params.sampling.dry_sequence_breakers.empty()) {
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
}
}));
// handle both "json_schema" and "grammar"
add((new field_json("json_schema"))
->add_alias("grammar")
->set_desc("Set a JSON schema (json_schema) or GBNF grammar string (grammar) for constrained generation. json_schema takes precedence if both are provided")
->set_handler([&](field_eval_context & ctx, const json & data) {
auto & params = ctx.params;
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
std::string grammar_str = json_schema_to_grammar(schema);
SRV_DBG("Converted grammar: %s\n", grammar_str.c_str());
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, std::move(grammar_str)};
} catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
std::string grammar_str = json_value(data, "grammar", std::string());
if (!grammar_str.empty()) {
// grammar_type key is set by the server when converting chat template grammars
std::string grammar_type = json_value(data, "grammar_type", std::string());
if (grammar_type == "tool_calls") {
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_TOOL_CALLS, std::move(grammar_str)};
} else {
// explicit grammar from the user (API field "grammar")
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)};
}
SRV_DBG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(params.sampling.grammar).c_str());
}
}
}));
add((new field_bool("grammar_lazy", params.sampling.grammar_lazy))
->set_desc("Whether to apply grammar constraints lazily, only when triggered (instead of at every step)"));
//
// Chat parser params
//
// TODO: change this to string field instead
add((new field_json("chat_format"))
->set_desc("Chat format used internally by the server")
->set_handler([&](field_eval_context & ctx, const json & data) {
ctx.params.chat_parser_params.format = static_cast<common_chat_format>(data.at("chat_format").get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(ctx.params.chat_parser_params.format));
}));
add((new field_str("reasoning_format"))
->set_desc("Reasoning format for chain-of-thought models")
->set_handler([&](field_eval_context & ctx, const json & data) {
auto reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
ctx.params.chat_parser_params.reasoning_format = reasoning_format;
ctx.params.chat_parser_params.reasoning_in_content = ctx.params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
}));
add((new field_str("generation_prompt"))
->set_desc("Generation prompt appended to the chat template output")
->set_handler([&](field_eval_context & ctx, const json & data) {
std::string s = data.at("generation_prompt").get<std::string>();
ctx.params.chat_parser_params.generation_prompt = s;
ctx.params.sampling.generation_prompt = s;
}));
add((new field_bool("parse_tool_calls", params.chat_parser_params.parse_tool_calls))
->set_desc("Whether to parse tool calls from the generated output"));
add((new field_str("chat_parser"))
->set_desc("Chat parser configuration string")
->set_handler([&](field_eval_context & ctx, const json & data) {
ctx.params.chat_parser_params.parser.load(data.at("chat_parser").get<std::string>());
}));
add((new field_json("continue_final_message"))
->set_desc("Whether to continue the final message of the chat template")
->set_handler([&](field_eval_context & ctx, const json & data) {
auto continuation = common_chat_continuation_parse(data.at("continue_final_message"));
ctx.params.chat_parser_params.is_continuation = continuation != COMMON_CHAT_CONTINUATION_NONE;
}));
add((new field_bool("echo", params.chat_parser_params.echo))
->set_desc("Whether to echo the input tokens in the output"));
//
// Token-level fields (require vocab)
//
add((new field_json("preserved_tokens"))
->set_desc("List of token strings that must not be split during tokenization")
->set_handler([&](field_eval_context & ctx, const json & data) {
GGML_ASSERT(ctx.vocab != nullptr);
for (const auto & t : data.at("preserved_tokens")) {
auto ids = common_tokenize(ctx.vocab, t.get<std::string>(), false, true);
if (ids.size() == 1) {
ctx.params.sampling.preserved_tokens.insert(ids[0]);
}
}
}));
add((new field_json("grammar_triggers"))
->set_desc("List of strings or patterns that trigger grammar-constrained generation")
->set_handler([&](field_eval_context & ctx, const json & data) {
GGML_ASSERT(ctx.vocab != nullptr);
for (const auto & t : data.at("grammar_triggers")) {
server_grammar_trigger ct(t);
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
const auto & word = ct.value.value;
auto ids = common_tokenize(ctx.vocab, word, false, true);
if (ids.size() == 1) {
auto token = ids[0];
if (std::find(ctx.params.sampling.preserved_tokens.begin(), ctx.params.sampling.preserved_tokens.end(), (llama_token) token) == ctx.params.sampling.preserved_tokens.end()) {
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
}
common_grammar_trigger trigger;
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
trigger.value = word;
trigger.token = token;
ctx.params.sampling.grammar_triggers.push_back(std::move(trigger));
} else {
ctx.params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
}
} else {
ctx.params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
}
}
if (ctx.params.sampling.grammar_lazy && ctx.params.sampling.grammar_triggers.empty()) {
throw std::runtime_error("Error: no triggers set for lazy grammar!");
}
}));
add((new field_bool("reasoning_control", params.sampling.reasoning_control))
->set_desc("Create the budget sampler on demand so reasoning can be ended at runtime"));
add((new field_num("reasoning_budget_tokens", params.sampling.reasoning_budget_tokens))
->set_hard_limits(-1, INT32_MAX)
->set_desc("Number of tokens in the reasoning budget (-1 = disabled)"));
add((new field_str("reasoning_budget_start_tag"))
->set_desc("Token string marking the start of the reasoning budget section")
->set_handler([&](field_eval_context & ctx, const json & data) {
GGML_ASSERT(ctx.vocab != nullptr);
ctx.params.sampling.reasoning_budget_start = common_tokenize(ctx.vocab, data.at("reasoning_budget_start_tag").get<std::string>(), false, true);
}));
add((new field_str("reasoning_budget_end_tag"))
->set_desc("Token string marking the end of the reasoning budget section")
->set_handler([&](field_eval_context & ctx, const json & data) {
GGML_ASSERT(ctx.vocab != nullptr);
std::string end_tag = data.at("reasoning_budget_end_tag").get<std::string>();
ctx.params.sampling.reasoning_budget_end = common_tokenize(ctx.vocab, end_tag, false, true);
}));
add((new field_str("reasoning_budget_message"))
->set_desc("Message to prepend to the reasoning budget end tag when forcing it")
->set_handler([&](field_eval_context & ctx, const json & data) {
GGML_ASSERT(ctx.vocab != nullptr);
std::string end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
std::string message = data.at("reasoning_budget_message").get<std::string>();
ctx.params.sampling.reasoning_budget_forced = common_tokenize(ctx.vocab, message + end_tag, false, true);
}));
add((new field_json("logit_bias"))
->set_desc("Modify the likelihood of specific tokens. Accepts an array of [token, bias] pairs or an object mapping token to bias. Use false as bias to ban a token")
->set_handler([&](field_eval_context & ctx, const json & data) {
GGML_ASSERT(ctx.vocab != nullptr);
ctx.params.sampling.logit_bias.clear();
const auto & logit_bias = data.at("logit_bias");
const int n_vocab = llama_vocab_n_tokens(ctx.vocab);
auto parse_bias = [](const json & v, float & bias) -> bool {
if (v.is_number()) { bias = v.get<float>(); return true; }
if (v.is_boolean() && !v.get<bool>()) { bias = -INFINITY; return true; }
return false;
};
if (logit_bias.is_array()) {
for (const auto & el : logit_bias) {
if (!el.is_array() || el.size() != 2) continue;
float bias;
if (!parse_bias(el[1], bias)) continue;
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) ctx.params.sampling.logit_bias.push_back({tok, bias});
} else if (el[0].is_string()) {
for (auto tok : common_tokenize(ctx.vocab, el[0].get<std::string>(), false))
ctx.params.sampling.logit_bias.push_back({tok, bias});
}
}
} else if (logit_bias.is_object()) {
for (const auto & el : logit_bias.items()) {
float bias;
if (!parse_bias(el.value(), bias)) continue;
char * end;
llama_token tok = strtol(el.key().c_str(), &end, 10);
if (*end == 0) {
if (tok >= 0 && tok < n_vocab) ctx.params.sampling.logit_bias.push_back({tok, bias});
} else {
for (auto t : common_tokenize(ctx.vocab, el.key(), false))
ctx.params.sampling.logit_bias.push_back({t, bias});
}
}
}
}));
add((new field_bool("ignore_eos", params.sampling.ignore_eos))
->set_desc("Ignore the end-of-sequence token and continue generating")
->set_handler([&](field_eval_context & ctx, const json & data) {
GGML_ASSERT(ctx.logit_bias_eog != nullptr);
ctx.params.sampling.ignore_eos = data.at("ignore_eos").get<bool>();
if (ctx.params.sampling.ignore_eos && ctx.logit_bias_eog) {
ctx.params.sampling.logit_bias.insert(
ctx.params.sampling.logit_bias.end(),
ctx.logit_bias_eog->begin(), ctx.logit_bias_eog->end());
}
}));
add((new field_json("stop"))
->set_desc("Specify stopping strings. Generation stops when one is produced, and the string is not included in the output")
->set_handler([&](field_eval_context & ctx, const json & data) {
ctx.params.antiprompt.clear();
const auto & stop = data.at("stop");
if (stop.is_array()) {
for (const auto & word : stop) {
if (!word.empty()) ctx.params.antiprompt.push_back(word);
}
} else if (stop.is_string()) {
ctx.params.antiprompt.push_back(stop.get<std::string>());
}
// fall back to CLI defaults if the request provided no effective stop strings
if (ctx.params.antiprompt.empty()) {
ctx.params.antiprompt = params_base.antiprompt;
}
}));
add((new field_json("samplers"))
->set_desc("The order in which samplers are applied. An array of sampler type names, or a single string of sampler chars")
->set_handler([&](field_eval_context & ctx, const json & data) {
const auto & samplers = data.at("samplers");
if (samplers.is_array()) {
ctx.params.sampling.samplers = common_sampler_types_from_names(samplers);
} else if (samplers.is_string()) {
ctx.params.sampling.samplers = common_sampler_types_from_chars(samplers.get<std::string>());
}
}));
return fields;
}
task_params eval_llama_cmpl_schema(
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const std::vector<llama_logit_bias> & logit_bias_eog,
const json & data) {
task_params params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still them)
params.sampling = params_base.sampling;
params.speculative = params_base.speculative;
params.n_keep = params_base.n_keep;
params.n_predict = params_base.n_predict;
params.n_cache_reuse = params_base.n_cache_reuse;
params.cache_prompt = params_base.cache_prompt;
params.antiprompt = params_base.antiprompt;
// enabling this will output extra debug information in the HTTP responses from the server
params.verbose = params_base.verbosity > 9;
params.chat_parser_params.reasoning_format = params_base.reasoning_format;
// create context and schema
field_eval_context ctx(params);
ctx.vocab = vocab;
ctx.logit_bias_eog = &logit_bias_eog;
auto schema = make_llama_cmpl_schema(params_base, params);
// eval all fields in the schema
for (const auto & f : schema) {
f->eval(ctx, data);
}
// post-processing
{
if (params.sampling.penalty_last_n == -1) {
// note: should be the slot's context and not the full context, but it's ok
params.sampling.penalty_last_n = n_ctx_slot;
}
if (params.sampling.dry_penalty_last_n == -1) {
params.sampling.dry_penalty_last_n = n_ctx_slot;
}
// if "reasoning_format" is not provided, its handler will not be called, we will need to handle it here
auto reasoning_format = params.chat_parser_params.reasoning_format;
params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
}
// debugging
{
auto budget = params.sampling.reasoning_budget_tokens;
SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
budget, params.sampling.generation_prompt.c_str(),
params.sampling.reasoning_budget_start.size(),
params.sampling.reasoning_budget_end.size(),
params.sampling.reasoning_budget_forced.size());
}
return params;
}
//
// eval() implementations
//
static void handle_with_catch(const char * name, std::function<void()> func) {
try {
func();
} catch (const std::exception & e) {
throw std::invalid_argument(string_format("Field '%s': %s", name, e.what()));
}
}
template <typename T>
void field_num<T>::eval(field_eval_context & ctx, const json & data) {
for (const auto & n : name) {
if (data.contains(n)) {
handle_with_catch(n, [&]() {
if (custom_handler) {
custom_handler(ctx, data);
} else if (!is_hard_limit) {
val = std::max(min, std::min(max, data.at(n).template get<T>()));
} else {
T tmp = data.at(n).template get<T>();
if (tmp < min || tmp > max) {
throw std::invalid_argument(std::string("Value must be between ") + std::to_string(min) + " <= value <= " + std::to_string(max) + ", but got " + std::to_string(tmp));
}
val = tmp;
}
});
return;
}
}
}
void field_str::eval(field_eval_context & ctx, const json & data) {
GGML_ASSERT(custom_handler);
for (const auto & n : name) {
if (data.contains(n)) {
handle_with_catch(n, [&]() {
custom_handler(ctx, data);
});
return;
}
}
}
void field_bool::eval(field_eval_context & ctx, const json & data) {
for (const auto & n : name) {
if (data.contains(n)) {
handle_with_catch(n, [&]() {
if (custom_handler) {
custom_handler(ctx, data);
} else {
val = data.at(n).get<bool>();
}
});
return;
}
}
}
void field_json::eval(field_eval_context & ctx, const json & data) {
GGML_ASSERT(custom_handler);
for (const auto & n : name) {
if (data.contains(n)) {
handle_with_catch(n, [&]() {
custom_handler(ctx, data);
});
return;
}
}
}
void field_nested::eval(field_eval_context & ctx, const json & data) {
for (const auto & n : name) {
if (data.contains(n) && data.at(n).is_object()) {
for (auto & f : subfields) {
f->eval(ctx, data.at(n));
}
return;
}
}
}
} // namespace server_schema
+105
View File
@@ -0,0 +1,105 @@
#pragma once
#include "server-common.h"
#include "server-task.h"
#include "sampling.h"
#include "speculative.h"
#include <climits>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <vector>
namespace server_schema {
struct field_eval_context {
task_params & params;
const llama_vocab * vocab = nullptr;
const std::vector<llama_logit_bias> * logit_bias_eog = nullptr;
field_eval_context(task_params & params) : params(params) {}
};
using field_handler = std::function<void(field_eval_context &, const json &)>;
struct field {
std::vector<const char *> name;
const char * desc = "";
field_handler custom_handler;
field() = default;
field(const char * n) : name({n}) {}
virtual ~field() = default;
field * set_desc(const char * s) {
desc = s;
return this;
}
// if 'name' is present, use it, otherwise look for aliases following the order they were added
field * add_alias(const char * n) {
name.push_back(n);
return this;
}
field * set_handler(field_handler h) { this->custom_handler = h; return this; }
virtual void eval(field_eval_context & ctx, const json & data) = 0;
};
template <typename T = int32_t>
struct field_num : public field {
T & val;
T min = std::numeric_limits<T>::lowest();
T max = std::numeric_limits<T>::max();
bool is_hard_limit = false; // if true, throw error if the value is invalid
field_num(const char * n, T & val) : field(n), val(val) {}
// limits are inclusive, min <= value <= max
field_num * set_limits(T min, T max) {
this->min = min;
this->max = max;
return this;
}
field_num * set_hard_limits(T min, T max) {
set_limits(min, max);
is_hard_limit = true;
return this;
}
virtual void eval(field_eval_context & ctx, const json & data) override;
};
struct field_str : public field {
field_str(const char * n) : field(n) {}
virtual void eval(field_eval_context & ctx, const json & data) override;
};
struct field_bool : public field {
bool & val;
field_bool(const char * n, bool & val) : field(n), val(val) {}
virtual void eval(field_eval_context & ctx, const json & data) override;
};
struct field_json : public field {
field_json(const char * n) : field(n) {}
virtual void eval(field_eval_context & ctx, const json & data) override;
};
struct field_nested : public field {
std::vector<std::unique_ptr<field>> subfields;
field_nested(const char * n) : field(n) {}
field_nested * add_subfield(field * f) {
subfields.emplace_back(std::unique_ptr<field>(f));
return this;
}
virtual void eval(field_eval_context & ctx, const json & data) override;
};
std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(
const common_params & params_base,
task_params & params);
task_params eval_llama_cmpl_schema(
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const std::vector<llama_logit_bias> & logit_bias_eog,
const json & data);
} // namespace server_schema
-388
View File
@@ -232,396 +232,8 @@ common_chat_msg task_result_state::update_chat_msg(
return chat_msg;
}
//
// server_task
//
task_params server_task::params_from_json_cmpl(
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const std::vector<llama_logit_bias> & logit_bias_eog,
const json & data) {
task_params params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still them)
task_params defaults;
defaults.sampling = params_base.sampling;
defaults.speculative = params_base.speculative;
defaults.n_keep = params_base.n_keep;
defaults.n_predict = params_base.n_predict;
defaults.n_cache_reuse = params_base.n_cache_reuse;
defaults.cache_prompt = params_base.cache_prompt;
defaults.antiprompt = params_base.antiprompt;
// enabling this will output extra debug information in the HTTP responses from the server
params.verbose = params_base.verbosity > 9;
params.timings_per_token = json_value(data, "timings_per_token", false);
params.stream = json_value(data, "stream", false);
auto stream_opt = json_value(data, "stream_options", json::object());
params.include_usage = json_value(stream_opt, "include_usage", false);
params.cache_prompt = json_value(data, "cache_prompt", defaults.cache_prompt);
params.return_tokens = json_value(data, "return_tokens", false);
params.return_progress = json_value(data, "return_progress", false);
auto max_tokens = json_value(data, "max_tokens", defaults.n_predict);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_completion_tokens", max_tokens));
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
params.n_discard = std::max(0, params.n_discard);
params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse);
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
params.sampling.adaptive_target = json_value(data, "adaptive_target", defaults.sampling.adaptive_target);
params.sampling.adaptive_decay = json_value(data, "adaptive_decay", defaults.sampling.adaptive_decay);
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
params.speculative = defaults.speculative;
// TODO: to keep things simple, we disable speculative parameter adjustments for now
#if 0
// TODO: for now, be able to adjust only the draft-model based speculative parameters
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
params.speculative.draft.p_min = json_value(data, "speculative.p_min", defaults.speculative.draft.p_min);
params.speculative.draft.n_min = std::min(params.speculative.draft.n_max, params.speculative.draft.n_min);
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
// for debugging and research purposes
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);
params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024);
params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024);
params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024);
#endif
// Use OpenAI API logprobs only if n_probs wasn't provided
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
}
if (data.contains("lora")) {
if (data.at("lora").is_array()) {
params.lora = parse_lora_request(data.at("lora"));
} else {
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
}
} else {
params.lora = {};
}
// TODO: add more sanity checks for the input parameters
if (params.sampling.penalty_last_n < -1) {
throw std::runtime_error("Error: repeat_last_n must be >= -1");
}
if (params.sampling.dry_penalty_last_n < -1) {
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
}
if (params.sampling.penalty_last_n == -1) {
// note: should be the slot's context and not the full context, but it's ok
params.sampling.penalty_last_n = n_ctx_slot;
}
if (params.sampling.dry_penalty_last_n == -1) {
params.sampling.dry_penalty_last_n = n_ctx_slot;
}
if (params.sampling.dry_base < 1.0f) {
params.sampling.dry_base = defaults.sampling.dry_base;
}
// sequence breakers for DRY
{
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) {
params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
if (params.sampling.dry_sequence_breakers.empty()) {
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
}
}
}
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
std::string grammar_str = json_schema_to_grammar(schema);
SRV_DBG("Converted grammar: %s\n", grammar_str.c_str());
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, std::move(grammar_str)};
} catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
params.sampling.grammar = defaults.sampling.grammar;
std::string grammar_str = json_value(data, "grammar", std::string());
if (!grammar_str.empty()) {
// grammar_type key is set by the server when converting chat template grammars
std::string grammar_type = json_value(data, "grammar_type", std::string());
if (grammar_type == "tool_calls") {
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_TOOL_CALLS, std::move(grammar_str)};
} else {
// explicit grammar from the user (API field "grammar")
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)};
}
SRV_DBG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(params.sampling.grammar).c_str());
}
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
}
{
auto it = data.find("chat_format");
if (it != data.end()) {
params.chat_parser_params.format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.chat_parser_params.format));
} else {
params.chat_parser_params.format = defaults.chat_parser_params.format;
}
common_reasoning_format reasoning_format = params_base.reasoning_format;
if (data.contains("reasoning_format")) {
reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
}
params.chat_parser_params.reasoning_format = reasoning_format;
params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
params.chat_parser_params.generation_prompt = json_value(data, "generation_prompt", std::string());
params.sampling.generation_prompt = params.chat_parser_params.generation_prompt;
SRV_DBG("Generation prompt: '%s'\n", params.chat_parser_params.generation_prompt.c_str());
params.chat_parser_params.parse_tool_calls = json_value(data, "parse_tool_calls", false);
if (data.contains("chat_parser")) {
params.chat_parser_params.parser.load(data.at("chat_parser").get<std::string>());
}
if (data.contains("continue_final_message")) {
auto continuation = common_chat_continuation_parse(data.at("continue_final_message"));
params.chat_parser_params.is_continuation = continuation != COMMON_CHAT_CONTINUATION_NONE;
}
params.chat_parser_params.echo = json_value(data, "echo", false);
}
{
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto & t : *preserved_tokens) {
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
SRV_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
}
}
}
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
for (const auto & t : *grammar_triggers) {
server_grammar_trigger ct(t);
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
const auto & word = ct.value.value;
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
auto token = ids[0];
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
}
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
common_grammar_trigger trigger;
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
trigger.value = word;
trigger.token = token;
params.sampling.grammar_triggers.push_back(std::move(trigger));
} else {
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
}
} else {
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
} else {
throw std::runtime_error("Unknown grammar trigger type");
}
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
}
}
}
if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
throw std::runtime_error("Error: no triggers set for lazy grammar!");
}
}
// Parse reasoning budget sampler parameters
{
const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
const auto message = json_value(data, "reasoning_budget_message", std::string());
params.sampling.reasoning_budget_tokens = budget;
params.sampling.reasoning_control = json_value(data, "reasoning_control", false);
if (!start_tag.empty()) {
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
}
if (!end_tag.empty()) {
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
budget, params.sampling.generation_prompt.c_str(),
params.sampling.reasoning_budget_start.size(),
params.sampling.reasoning_budget_end.size(),
params.sampling.reasoning_budget_forced.size());
}
}
{
params.sampling.logit_bias.clear();
const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_vocab_n_tokens(vocab);
for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) {
float bias;
if (el[1].is_number()) {
bias = el[1].get<float>();
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
bias = -INFINITY;
} else {
continue;
}
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
params.sampling.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
for (auto tok : toks) {
params.sampling.logit_bias.push_back({tok, bias});
}
}
}
}
} else if (logit_bias != data.end() && logit_bias->is_object()) {
const int n_vocab = llama_vocab_n_tokens(vocab);
for (const auto & el : logit_bias->items()) {
float bias;
const auto & key = el.key();
const auto & value = el.value();
if (value.is_number()) {
bias = value.get<float>();
} else if (value.is_boolean() && !value.get<bool>()) {
bias = -INFINITY;
} else {
continue;
}
char *end;
llama_token tok = strtol(key.c_str(), &end, 10);
if (*end == 0) {
if (tok >= 0 && tok < n_vocab) {
params.sampling.logit_bias.push_back({tok, bias});
}
} else {
auto toks = common_tokenize(vocab, key, false);
for (auto tok : toks) {
params.sampling.logit_bias.push_back({tok, bias});
}
}
}
}
params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
if (params.sampling.ignore_eos) {
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
logit_bias_eog.begin(), logit_bias_eog.end());
}
}
{
params.antiprompt.clear();
const auto & stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
for (const auto & word : *stop) {
if (!word.empty()) {
params.antiprompt.push_back(word);
}
}
}
// set reverse prompt from cli args if not set in the request
if (params.antiprompt.empty()) {
params.antiprompt = defaults.antiprompt;
}
}
{
const auto samplers = data.find("samplers");
if (samplers != data.end()) {
if (samplers->is_array()) {
params.sampling.samplers = common_sampler_types_from_names(*samplers);
} else if (samplers->is_string()){
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
}
} else {
params.sampling.samplers = defaults.sampling.samplers;
}
}
if (params.n_cmpl > params_base.n_parallel) {
throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
}
return params;
}
//
// result_timings
//
-7
View File
@@ -210,13 +210,6 @@ struct server_task {
}
}
static task_params params_from_json_cmpl(
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const std::vector<llama_logit_bias> & logit_bias_eog,
const json & data);
// utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
std::unordered_set<int> ids(tasks.size());
+6
View File
@@ -349,6 +349,12 @@ int llama_server(int argc, char ** argv) {
SRV_INF("router server is listening on %s\n", ctx_http.listening_address.c_str());
SRV_WRN("%s", "NOTE: router mode is experimental\n");
SRV_WRN("%s", " it is not recommended to use this mode in untrusted environments\n");
if (!params.models_preset_hf.empty()) {
SRV_WRN( "NOTE: using preset.ini from HF repo '%s'\n", params.models_preset_hf.c_str());
SRV_WRN("%s", " please only use presets that you can trust! Unknown presets may be unsafe\n");
}
if (ctx_http.thread.joinable()) {
ctx_http.thread.join(); // keep the main thread alive
}
@@ -307,6 +307,20 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
def test_completion_with_invalid_grammar():
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"messages": [
{"role": "user", "content": "Does not matter what I say, does it?"},
],
"grammar": "root ::= this is (not valid GBNF",
})
assert res.status_code == 400, res.body
assert "error" in res.body
@pytest.mark.parametrize("messages", [
None,
"string",
@@ -79,7 +79,7 @@
<!-- svelte-ignore a11y_no_static_element_interactions -->
<!-- svelte-ignore a11y_click_events_have_key_events -->
<div
class="pointer-events-none flex items-center justify-center gap-0.75 pl-2 opacity-0 group-hover:pointer-events-auto group-hover:opacity-100"
class="pointer-events-none flex items-center justify-center gap-0.75 pl-2 opacity-0 group-hover:pointer-events-auto group-hover:opacity-100 [@media(pointer:coarse)]:pointer-events-auto [@media(pointer:coarse)]:opacity-100"
onclick={(e) => e.stopPropagation()}
>
{#if isFav}
@@ -113,12 +113,16 @@
</div>
{#if isLoading}
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
<div class="flex w-4 [@media(pointer:coarse)]:w-5 items-center justify-center">
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
</div>
{:else if isFailed}
<div class="flex w-4 items-center justify-center">
<CircleAlert class="h-3.5 w-3.5 text-red-500 group-hover:hidden" />
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<CircleAlert
class="h-3.5 w-3.5 text-red-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
/>
<div class="hidden group-hover:flex">
<div class="hidden group-hover:flex [@media(pointer:coarse)]:flex">
<ActionIcon
iconSize="h-2.5 w-2.5"
icon={RotateCw}
@@ -130,15 +134,17 @@
</div>
</div>
{:else if isSleeping}
<div class="flex w-4 items-center justify-center">
<span class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden"></span>
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<span
class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
<div class="hidden group-hover:flex">
<div class="hidden group-hover:flex [@media(pointer:coarse)]:flex">
<ActionIcon
iconSize="h-2.5 w-2.5"
icon={PowerOff}
tooltip="Unload model"
class="h-3 w-3 text-red-500 hover:text-red-600"
class="h-3 w-3 text-red-500 hover:text-red-600 [@media(pointer:coarse)]:text-amber-500 [@media(pointer:coarse)]:hover:text-amber-600"
onclick={(e) => {
e?.stopPropagation();
modelsStore.unloadModel(option.model);
@@ -147,30 +153,34 @@
</div>
</div>
{:else if isLoaded}
<div class="flex w-4 items-center justify-center">
<span class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden"></span>
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<span
class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
<div class="hidden group-hover:flex">
<div class="hidden group-hover:flex [@media(pointer:coarse)]:flex">
<ActionIcon
iconSize="h-2.5 w-2.5"
icon={PowerOff}
tooltip="Unload model"
class="h-3 w-3 text-red-500 hover:text-red-600"
class="h-3 w-3 text-red-500 hover:text-red-600 [@media(pointer:coarse)]:text-green-500 [@media(pointer:coarse)]:hover:text-green-600"
onclick={() => modelsStore.unloadModel(option.model)}
stopPropagationOnClick
/>
</div>
</div>
{:else}
<div class="flex w-4 items-center justify-center">
<span class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden"></span>
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<span
class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
<div class="hidden group-hover:flex">
<div class="hidden group-hover:flex [@media(pointer:coarse)]:flex">
<ActionIcon
iconSize="h-2.5 w-2.5"
icon={Power}
tooltip="Load model"
class="h-3 w-3"
class="h-3 w-3 [@media(pointer:coarse)]:text-muted-foreground"
onclick={() => modelsStore.loadModel(option.model)}
stopPropagationOnClick
/>
@@ -66,7 +66,7 @@
<button
type="button"
class={[
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 max-sm:px-3 max-sm:py-2 text-xs max-sm:text-sm shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
!ms.isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
@@ -0,0 +1,269 @@
<script module lang="ts">
import { defineMeta } from '@storybook/addon-svelte-csf';
import ModelsSelectorList from '$lib/components/app/models/ModelsSelectorList.svelte';
import ModelsSelectorOption from '$lib/components/app/models/ModelsSelectorOption.svelte';
import type { GroupedModelOptions, ModelItem } from '$lib/components/app/models/utils';
import { modelsStore } from '$lib/stores/models.svelte';
import { ServerModelStatus } from '$lib/enums';
const { Story } = defineMeta({
title: 'Components/ModelsSelector',
parameters: {
layout: 'centered'
}
});
const mockModel = (id: string, name: string, orgName?: string, tags?: string[]): ModelOption => ({
id,
name,
model: orgName ? `${orgName}/${name}` : name,
capabilities: [],
parsedId: {
raw: orgName ? `${orgName}/${name}` : name,
orgName: orgName ?? null,
modelName: name,
params: null,
activatedParams: null,
quantization: null,
tags: tags ?? []
},
tags
});
const mockRouterEntry = (modelName: string, status: ServerModelStatus): ApiModelDataEntry => ({
id: modelName,
object: 'model',
owned_by: 'llamacpp',
created: Date.now(),
in_cache: true,
path: `/models/${modelName}`,
status: { value: status }
});
</script>
<script lang="ts">
let selectedModel = $state<string | null>(null);
let activeId = $state<string | null>(null);
function mockModelsStore() {
modelsStore.favoriteModelIds = new Set(['qwen2.5-7b', 'llama3.2-3b']);
// Mock router models with various statuses for ModelLoadedStates story
modelsStore.routerModels = [
mockRouterEntry('meta/Model (loading)', ServerModelStatus.LOADING),
mockRouterEntry('meta/Model (loaded)', ServerModelStatus.LOADED),
mockRouterEntry('meta/Model (sleeping)', ServerModelStatus.SLEEPING),
mockRouterEntry('meta/Model (failed)', ServerModelStatus.FAILED)
];
}
mockModelsStore();
const loadedModels: ModelItem[] = [
{ option: mockModel('llama3.1-8b', 'Llama-3.1-8B-Instruct', 'meta'), flatIndex: 0 },
{ option: mockModel('mistral-7b', 'Mistral-7B-v0.3', 'mistralai'), flatIndex: 1 }
];
const favoriteModels: ModelItem[] = [
{ option: mockModel('qwen2.5-7b', 'Qwen2.5-7B-Instruct', 'Qwen'), flatIndex: 2 },
{ option: mockModel('llama3.2-3b', 'Llama-3.2-3B-Instruct', 'meta'), flatIndex: 3 }
];
const availableModels: ModelItem[] = [
{
option: mockModel('deepseek-coder-6.7b', 'DeepSeek-Coder-6.7B', 'deepseek', ['coding']),
flatIndex: 4
},
{ option: mockModel('gemma-2-9b', 'Gemma-2-9B-IT', 'google'), flatIndex: 5 },
{ option: mockModel('phi-3-mini', 'Phi-3-mini-4k', 'microsoft'), flatIndex: 6 },
{ option: mockModel('codellama-7b', 'CodeLlama-7B', 'codellama', ['coding']), flatIndex: 7 },
{ option: mockModel('neural-chat-7b', 'Neural-Chat-7B-v3-3', 'intel'), flatIndex: 8 }
];
const groupedOptions: GroupedModelOptions = {
loaded: loadedModels,
favorites: favoriteModels,
available: [
{
orgName: 'deepseek',
items: [availableModels[0]]
},
{
orgName: 'google',
items: [availableModels[1]]
},
{
orgName: 'microsoft',
items: [availableModels[2]]
},
{
orgName: 'codellama',
items: [availableModels[3]]
},
{
orgName: 'intel',
items: [availableModels[4]]
}
]
};
function handleSelect(modelId: string) {
const opt = [...loadedModels, ...favoriteModels, ...availableModels].find(
(m) => m.option.id === modelId
);
if (opt) {
selectedModel = opt.option.model;
activeId = modelId;
}
}
</script>
<Story name="List">
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
<ModelsSelectorList
groups={groupedOptions}
currentModel={selectedModel}
{activeId}
onSelect={handleSelect}
onInfoClick={(modelName) => console.log('Info clicked:', modelName)}
/>
</div>
</Story>
<Story name="SingleLoaded">
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
<ModelsSelectorList
groups={{
loaded: [loadedModels[0]],
favorites: [],
available: []
}}
currentModel={null}
activeId={null}
onSelect={handleSelect}
onInfoClick={(modelName) => console.log('Info clicked:', modelName)}
/>
</div>
</Story>
<Story name="WithFavoritesOnly">
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
<ModelsSelectorList
groups={{
loaded: [],
favorites: favoriteModels,
available: []
}}
currentModel={null}
activeId={null}
onSelect={handleSelect}
onInfoClick={(modelName) => console.log('Info clicked:', modelName)}
/>
</div>
</Story>
<Story name="ModelLoadedStates">
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
<div class="px-2 py-2 text-[13px] font-semibold text-muted-foreground/70 select-none">
Server model states
</div>
<ModelsSelectorOption
option={mockModel('model-idle', 'Model (idle)', 'meta')}
isSelected={false}
isHighlighted={false}
isFav={false}
hideOrgName={true}
onSelect={() => {}}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
<ModelsSelectorOption
option={mockModel('model-loading', 'Model (loading)', 'meta')}
isSelected={false}
isHighlighted={false}
isFav={false}
hideOrgName={true}
onSelect={() => {}}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
<ModelsSelectorOption
option={mockModel('model-loaded', 'Model (loaded)', 'meta')}
isSelected={false}
isHighlighted={false}
isFav={false}
hideOrgName={true}
onSelect={() => {}}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
<ModelsSelectorOption
option={mockModel('model-sleeping', 'Model (sleeping)', 'meta')}
isSelected={false}
isHighlighted={false}
isFav={false}
hideOrgName={true}
onSelect={() => {}}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
<ModelsSelectorOption
option={mockModel('model-failed', 'Model (failed)', 'meta')}
isSelected={false}
isHighlighted={false}
isFav={false}
hideOrgName={true}
onSelect={() => {}}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
</div>
</Story>
<Story name="ModelSelectedStates">
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
<div class="px-2 py-2 text-[13px] font-semibold text-muted-foreground/70 select-none">
Selection states
</div>
<ModelsSelectorOption
option={mockModel('normal-model', 'Normal Model', 'meta')}
isSelected={false}
isHighlighted={false}
isFav={false}
hideOrgName={true}
onSelect={() => {}}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
<ModelsSelectorOption
option={mockModel('selected-model', 'Selected Model', 'meta')}
isSelected={true}
isHighlighted={false}
isFav={false}
hideOrgName={true}
onSelect={() => {}}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
<ModelsSelectorOption
option={mockModel('highlighted-model', 'Highlighted Model', 'meta')}
isSelected={false}
isHighlighted={true}
isFav={false}
hideOrgName={true}
onSelect={() => {}}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
<ModelsSelectorOption
option={mockModel('fav-model', 'Favorite Model', 'Qwen')}
isSelected={false}
isHighlighted={false}
isFav={true}
hideOrgName={true}
onSelect={() => {}}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
</div>
</Story>