mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-18 19:57:46 +02:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 32eddaf2ea | |||
| 060ce1bf72 | |||
| d2c67959b3 | |||
| 7b6c5a2aed | |||
| fe7c8b2414 | |||
| e1efd0991d | |||
| 08023072ef | |||
| 20832179e2 |
@@ -69,6 +69,7 @@ static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE;
|
||||
static int opt_opbatch = 1024; // max number of ops in a batch
|
||||
static int opt_opqueue = 16; // max number of pending batches
|
||||
static int opt_oppoll = 0; // polling for batch completions
|
||||
static int opt_optrace = 0; // trace buffer size per thread (0 means default)
|
||||
|
||||
static std::regex* opt_opfilter = NULL; // regex of ops to not claim
|
||||
|
||||
@@ -118,20 +119,39 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct
|
||||
ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no");
|
||||
}
|
||||
|
||||
static const char * htp_event_name(uint16_t id) {
|
||||
switch (id) {
|
||||
case HTP_TRACE_EVT_DMA: return "DMA";
|
||||
case HTP_TRACE_EVT_HVX_COMP: return "HVX_COMP";
|
||||
case HTP_TRACE_EVT_HVX_A_QUANT: return "HVX_A_QUANT";
|
||||
case HTP_TRACE_EVT_HVX_A_PREP: return "HVX_A_PREP";
|
||||
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
|
||||
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
|
||||
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
|
||||
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
|
||||
default: return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_opnode & node,
|
||||
uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) {
|
||||
const htp_prof_desc & pd) {
|
||||
if (!opt_profile) return;
|
||||
|
||||
uint32_t op_usec = pd.usecs;
|
||||
uint32_t op_cycles = pd.cycles_stop - pd.cycles_start;
|
||||
const uint32_t * pmu = pd.pmu;
|
||||
|
||||
char pmu_str[256] = "";
|
||||
if (opt_profile > 1) {
|
||||
if (opt_profile == 2) {
|
||||
static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters");
|
||||
sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]",
|
||||
pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]);
|
||||
}
|
||||
|
||||
htp_opformat fmt(node);
|
||||
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(),
|
||||
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pmu_str);
|
||||
float mhz = op_usec > 0 ? (float) op_cycles / op_usec : 0.0f;
|
||||
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(),
|
||||
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str);
|
||||
}
|
||||
|
||||
// ** backend sessions
|
||||
@@ -1995,10 +2015,16 @@ struct ggml_hexagon_opqueue {
|
||||
size_t n_ops = batch_size;
|
||||
size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS;
|
||||
|
||||
size_t tr_size = 0;
|
||||
if (opt_profile == 3) {
|
||||
tr_size = (HTP_MAX_NTHREADS + 1) * opt_optrace * sizeof(htp_trace_desc);
|
||||
}
|
||||
|
||||
shm_blk_size = sizeof(htp_buf_desc) * n_bufs +
|
||||
sizeof(htp_tensor) * n_tensors +
|
||||
sizeof(htp_op_desc) * n_ops +
|
||||
sizeof(htp_prof_desc) * n_ops;
|
||||
sizeof(htp_prof_desc) * n_ops +
|
||||
tr_size;
|
||||
|
||||
shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */);
|
||||
|
||||
@@ -2042,11 +2068,19 @@ struct ggml_hexagon_opqueue {
|
||||
const size_t o_size = sizeof(htp_op_desc) * req.n_ops;
|
||||
const size_t p_size = sizeof(htp_prof_desc) * req.n_ops;
|
||||
|
||||
size_t tr_size = 0;
|
||||
if (opt_profile == 3) {
|
||||
req.n_traces = opt_optrace;
|
||||
tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(htp_trace_desc);
|
||||
} else {
|
||||
req.n_traces = 0;
|
||||
}
|
||||
|
||||
dbuf.ptr = shm_buf->base + (req.id * shm_blk_size);
|
||||
dbuf.fd = shm_buf->fd;
|
||||
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
|
||||
dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base;
|
||||
dbuf.size = b_size + t_size + o_size + p_size;
|
||||
dbuf.size = b_size + t_size + o_size + p_size + tr_size;
|
||||
|
||||
GGML_ASSERT(dbuf.size <= shm_blk_size);
|
||||
|
||||
@@ -2092,7 +2126,14 @@ struct ggml_hexagon_opqueue {
|
||||
const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops;
|
||||
const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops;
|
||||
|
||||
const size_t m_size = b_size + t_size + o_size + p_size;
|
||||
size_t tr_size = 0;
|
||||
uint32_t n_traces = 0;
|
||||
if (opt_profile == 3) {
|
||||
n_traces = opt_optrace;
|
||||
tr_size = (HTP_MAX_NTHREADS + 1) * n_traces * sizeof(htp_trace_desc);
|
||||
}
|
||||
|
||||
const size_t m_size = b_size + t_size + o_size + p_size + tr_size;
|
||||
GGML_ASSERT(m_size <= shm_blk_size);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n",
|
||||
@@ -2111,13 +2152,62 @@ struct ggml_hexagon_opqueue {
|
||||
GGML_ASSERT(rsp.n_ops <= ops.size());
|
||||
|
||||
const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr;
|
||||
for (uint32_t i = 0; i < rsp.n_ops; i++) {
|
||||
htp_usec += pd[i].usecs;
|
||||
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu);
|
||||
|
||||
const htp_trace_desc * trace_events = nullptr;
|
||||
|
||||
if (opt_profile == 3) {
|
||||
trace_events = (const htp_trace_desc *) (p_ptr + p_size);
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n",
|
||||
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec);
|
||||
uint32_t trace_idx[HTP_MAX_NTHREADS + 1] = {0};
|
||||
uint32_t valid_cnt[HTP_MAX_NTHREADS + 1] = {0};
|
||||
|
||||
if (opt_profile == 3) {
|
||||
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
uint32_t count = rsp.n_traces[t];
|
||||
valid_cnt[t] = count > n_traces ? n_traces : count;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < rsp.n_ops; i++) {
|
||||
htp_usec += pd[i].usecs;
|
||||
|
||||
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i]);
|
||||
|
||||
if (opt_profile == 3) {
|
||||
uint32_t op_duration = pd[i].cycles_stop - pd[i].cycles_start;
|
||||
|
||||
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
while (trace_idx[t] < valid_cnt[t]) {
|
||||
const auto & e = trace_events[t * n_traces + trace_idx[t]];
|
||||
uint32_t offset = e.cycles - pd[i].cycles_start;
|
||||
if (offset >= 0x80000000) {
|
||||
trace_idx[t]++;
|
||||
continue;
|
||||
}
|
||||
if (offset > op_duration) {
|
||||
break;
|
||||
}
|
||||
bool is_stop = (e.info & 0x8000) != 0;
|
||||
uint16_t info = e.info & 0x7FFF;
|
||||
GGML_LOG_DEBUG("ggml-hex: %s trace-op %s: thread %u event %s info %u %s %u\n",
|
||||
shm_buf->sess->c_name(), ops[i].op_name().c_str(), t, htp_event_name(e.id), info, is_stop ? "stop" : "start", e.cycles);
|
||||
trace_idx[t]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
char evt_str[256] = "";
|
||||
if (opt_profile == 3) {
|
||||
sprintf(evt_str, " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]",
|
||||
rsp.n_traces[0], rsp.n_traces[1], rsp.n_traces[2], rsp.n_traces[3],
|
||||
rsp.n_traces[4], rsp.n_traces[5], rsp.n_traces[6], rsp.n_traces[7],
|
||||
rsp.n_traces[8], rsp.n_traces[9], rsp.n_traces[10]);
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u%s\n",
|
||||
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec, evt_str);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -3901,6 +3991,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
|
||||
const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
|
||||
const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL");
|
||||
const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE");
|
||||
const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER");
|
||||
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
|
||||
const char * str_etm = getenv("GGML_HEXAGON_ETM");
|
||||
@@ -3939,6 +4030,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
|
||||
opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
|
||||
opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll;
|
||||
opt_optrace = str_optrace ? strtoul(str_optrace, NULL, 0) : (opt_opbatch * 128);
|
||||
opt_profile = str_profile ? atoi(str_profile) : 0;
|
||||
opt_etm = str_etm ? atoi(str_etm) : 0;
|
||||
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
|
||||
|
||||
@@ -37,8 +37,8 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
)
|
||||
|
||||
|
||||
@@ -339,6 +339,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
if (ir0 >= ir1) return;
|
||||
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
dma_queue * dma = octx->ctx->dma[ith];
|
||||
|
||||
const uint32_t DK = nek0;
|
||||
@@ -615,6 +618,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hex-profile.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -88,6 +90,7 @@ typedef struct {
|
||||
uint32_t pop_idx;
|
||||
uint32_t capacity;
|
||||
uint32_t idx_mask;
|
||||
struct htp_thread_trace * trace;
|
||||
} dma_queue;
|
||||
|
||||
dma_queue * dma_queue_create(size_t capacity);
|
||||
@@ -152,6 +155,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
|
||||
q->dptr[q->push_idx] = dptr;
|
||||
|
||||
if (size) {
|
||||
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = (dma_descriptor_2d *) desc;
|
||||
} else {
|
||||
@@ -202,6 +206,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
|
||||
q->dptr[q->push_idx] = dptr;
|
||||
|
||||
if (nrows) {
|
||||
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = desc;
|
||||
} else {
|
||||
@@ -223,10 +228,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
|
||||
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
|
||||
|
||||
// Wait for desc to complete
|
||||
while (!desc->done) {
|
||||
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
|
||||
dmpoll();
|
||||
if (!desc->done) {
|
||||
while (!desc->done) {
|
||||
dmpoll();
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_DMA, q->pop_idx);
|
||||
|
||||
dptr = q->dptr[q->pop_idx];
|
||||
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
#ifndef HEX_PROFILE_H
|
||||
#define HEX_PROFILE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <qurt.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define HTP_TRACE_EVT_START 0
|
||||
#define HTP_TRACE_EVT_STOP 1
|
||||
|
||||
#ifndef HEX_NUM_PMU_COUNTERS
|
||||
#define HEX_NUM_PMU_COUNTERS 8
|
||||
#endif
|
||||
|
||||
static inline void hex_get_pmu(uint32_t counters[]) {
|
||||
#if __HVX_ARCH__ >= 79
|
||||
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
|
||||
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
|
||||
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
|
||||
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
|
||||
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
|
||||
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
|
||||
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
|
||||
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
|
||||
#else
|
||||
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
|
||||
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
|
||||
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
|
||||
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
|
||||
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
|
||||
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
|
||||
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
|
||||
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
|
||||
#endif
|
||||
}
|
||||
|
||||
struct htp_thread_trace {
|
||||
uint32_t count;
|
||||
uint32_t max_events;
|
||||
struct htp_trace_desc * events;
|
||||
};
|
||||
|
||||
static inline void htp_trace_event(struct htp_thread_trace * tr, uint16_t id, uint16_t info, uint32_t type) {
|
||||
if (tr && tr->events && tr->count < tr->max_events) {
|
||||
uint32_t idx = tr->count;
|
||||
tr->events[idx].id = id;
|
||||
tr->events[idx].info = info | (type == HTP_TRACE_EVT_STOP ? 0x8000 : 0);
|
||||
tr->events[idx].cycles = (uint32_t) hex_get_cycles();
|
||||
tr->count++;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void htp_trace_event_start(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
|
||||
htp_trace_event(tr, id, info, HTP_TRACE_EVT_START);
|
||||
}
|
||||
|
||||
static inline void htp_trace_event_stop(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
|
||||
htp_trace_event(tr, id, info, HTP_TRACE_EVT_STOP);
|
||||
}
|
||||
|
||||
#endif /* HEX_PROFILE_H */
|
||||
@@ -107,31 +107,4 @@ static inline void hex_pause() {
|
||||
asm volatile(" pause(#255)\n");
|
||||
}
|
||||
|
||||
#ifndef HEX_NUM_PMU_COUNTERS
|
||||
#define HEX_NUM_PMU_COUNTERS 8
|
||||
#endif
|
||||
|
||||
static inline void hex_get_pmu(uint32_t counters[]) {
|
||||
#if __HVX_ARCH__ >= 79
|
||||
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
|
||||
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
|
||||
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
|
||||
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
|
||||
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
|
||||
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
|
||||
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
|
||||
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
|
||||
#else
|
||||
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
|
||||
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
|
||||
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
|
||||
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
|
||||
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
|
||||
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
|
||||
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
|
||||
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
|
||||
// qurt_pmu_get_pmucnt(counters);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif /* HEX_UTILS_H */
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#include "ggml-common.h"
|
||||
#include "hex-dma.h"
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hmx-profile.h"
|
||||
#include "hex-profile.h"
|
||||
#include "hmx-queue.h"
|
||||
#include "hmx-utils.h"
|
||||
#include "htp-ctx.h"
|
||||
@@ -367,8 +367,11 @@ static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data)
|
||||
return;
|
||||
}
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK,
|
||||
(int) args->src_stride, start, end);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
}
|
||||
|
||||
static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) {
|
||||
@@ -408,8 +411,11 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
|
||||
|
||||
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV,
|
||||
(int) args->src_stride, (int) args->n_col_tiles, start, end);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
}
|
||||
|
||||
static void fa_phase_v_interleave(struct hmx_fa_context * factx,
|
||||
@@ -462,6 +468,9 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
|
||||
return;
|
||||
}
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
|
||||
const struct htp_tensor * q = args->q;
|
||||
const uint32_t q_start = args->q_start;
|
||||
const uint32_t kv_head = args->kv_head;
|
||||
@@ -515,6 +524,7 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
|
||||
}
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
}
|
||||
|
||||
static void fa_phase_q_load(struct hmx_fa_context * factx,
|
||||
@@ -566,6 +576,9 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
|
||||
return;
|
||||
}
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
|
||||
const struct htp_tensor * dst = args->dst;
|
||||
const __fp16 * o_tile_src = args->o_tile_src;
|
||||
const uint32_t q_start = args->q_start;
|
||||
@@ -611,6 +624,7 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
|
||||
}
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
}
|
||||
|
||||
static void fa_phase_o_store(struct hmx_fa_context * factx,
|
||||
@@ -680,6 +694,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
|
||||
return;
|
||||
}
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
|
||||
|
||||
// Per-thread row scratch: thread i uses bufs at offset i * 2 * stride
|
||||
const size_t row_buf_stride = factx->row_buf_stride;
|
||||
HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride;
|
||||
@@ -950,6 +967,7 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
|
||||
factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v;
|
||||
factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v;
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
|
||||
}
|
||||
|
||||
// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads).
|
||||
@@ -1245,6 +1263,7 @@ static __attribute__((noinline)) void fa_compute_slopes(
|
||||
// ============================================================================
|
||||
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[HTP_MAX_NTHREADS] : NULL;
|
||||
const struct htp_tensor * q = octx->src[0];
|
||||
const struct htp_tensor * k = octx->src[1];
|
||||
const struct htp_tensor * v = octx->src[2];
|
||||
@@ -1422,19 +1441,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
// Profiling timers
|
||||
TIMER_DEFINE(total);
|
||||
TIMER_DEFINE(q_load);
|
||||
TIMER_DEFINE(kv_dma);
|
||||
TIMER_DEFINE(k_interleave);
|
||||
TIMER_DEFINE(v_interleave);
|
||||
TIMER_DEFINE(qk_dot);
|
||||
TIMER_DEFINE(softmax);
|
||||
TIMER_DEFINE(o_update);
|
||||
TIMER_DEFINE(o_norm);
|
||||
TIMER_DEFINE(o_store);
|
||||
|
||||
TIMER_START(total);
|
||||
|
||||
// ======== DMA setup ========
|
||||
dma_queue * const dma = ctx->dma[0];
|
||||
@@ -1474,12 +1480,10 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS;
|
||||
|
||||
// ---- Load Q block [g_br, D] -> tiles, interleaving G heads ----
|
||||
TIMER_START(q_load);
|
||||
if (n_rows_g < g_br) {
|
||||
hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes);
|
||||
}
|
||||
fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g);
|
||||
TIMER_STOP(q_load);
|
||||
|
||||
// ---- Initialize per-block state ----
|
||||
hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes);
|
||||
@@ -1558,10 +1562,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
|
||||
|
||||
// Wait for current KV DMA
|
||||
TIMER_START(kv_dma);
|
||||
dma_queue_pop(dma); // K
|
||||
dma_queue_pop(dma); // V
|
||||
TIMER_STOP(kv_dma);
|
||||
|
||||
// Push mask DMA for this block (single 2D DMA when broadcast)
|
||||
bool has_mask_dma = false;
|
||||
@@ -1583,10 +1585,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
ou_job.DV = DV;
|
||||
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
|
||||
}
|
||||
|
||||
TIMER_START(k_interleave);
|
||||
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
|
||||
TIMER_STOP(k_interleave);
|
||||
|
||||
// ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ----
|
||||
qk_job.q_tiles = factx.vtcm_q_tiles;
|
||||
@@ -1597,15 +1596,11 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
qk_job.n_dot_tiles = DK / 32;
|
||||
qk_job.n_tiles_per_bc = n_tiles_per_bc;
|
||||
qk_job.hmx_scales = factx.vtcm_hmx_scales_qk;
|
||||
TIMER_START(qk_dot);
|
||||
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job));
|
||||
|
||||
// DMA push next block (non-blocking, before worker_pool)
|
||||
DMA_PREFETCH_KV(kv_blk + 1);
|
||||
|
||||
TIMER_START(v_interleave);
|
||||
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
|
||||
TIMER_STOP(v_interleave);
|
||||
|
||||
// Pop and swap previous block's output update (deferred HMX pop)
|
||||
if (kv_blk > 0) {
|
||||
@@ -1615,7 +1610,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
|
||||
// Pop current block's dot product job
|
||||
hmx_queue_pop(hmx_q);
|
||||
TIMER_STOP(qk_dot);
|
||||
|
||||
// ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ----
|
||||
// Pop mask DMA before softmax (ensures VTCM buffer is ready)
|
||||
@@ -1641,10 +1635,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
|
||||
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
|
||||
sargs.slopes = factx.vtcm_slopes;
|
||||
|
||||
TIMER_START(softmax);
|
||||
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
|
||||
TIMER_STOP(softmax);
|
||||
|
||||
buf_idx = 1 - buf_idx;
|
||||
} // end KV block loop (pipeline)
|
||||
@@ -1664,11 +1655,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
ou_job.n_row_tiles_g_br = n_row_tiles_g_br;
|
||||
ou_job.n_tiles_per_bc = n_tiles_per_bc;
|
||||
ou_job.DV = DV;
|
||||
|
||||
TIMER_START(o_update);
|
||||
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
|
||||
hmx_queue_pop(hmx_q);
|
||||
TIMER_STOP(o_update);
|
||||
|
||||
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
|
||||
}
|
||||
@@ -1683,23 +1671,14 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const uint32_t kv_start = kv_blk * Bc;
|
||||
const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start);
|
||||
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
|
||||
|
||||
TIMER_START(kv_dma);
|
||||
dma_queue_pop(dma); // K
|
||||
dma_queue_pop(dma); // V
|
||||
TIMER_STOP(kv_dma);
|
||||
|
||||
bool has_mask_dma = false;
|
||||
MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma);
|
||||
DMA_PREFETCH_KV(kv_blk + 1);
|
||||
|
||||
// K interleave (multi-thread HVX)
|
||||
TIMER_START(k_interleave);
|
||||
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
|
||||
TIMER_STOP(k_interleave);
|
||||
|
||||
// QK dot (inline HMX on main thread)
|
||||
TIMER_START(qk_dot);
|
||||
{
|
||||
const size_t n_dot_tiles = (size_t) (DK / 32);
|
||||
const __fp16 * restrict q_base = factx.vtcm_q_tiles;
|
||||
@@ -1709,6 +1688,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
__builtin_assume(n_col_tiles > 0);
|
||||
__builtin_assume(n_dot_tiles > 0);
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk);
|
||||
for (size_t r = 0; r < n_row_tiles; ++r) {
|
||||
for (size_t c = 0; c < n_col_tiles; ++c) {
|
||||
@@ -1724,8 +1704,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
Q6_mxmem_AR_after_hf(out_tile, 0);
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
TIMER_STOP(qk_dot);
|
||||
|
||||
// Pop mask DMA
|
||||
MASK_DMA_POP(has_mask_dma);
|
||||
@@ -1751,21 +1731,9 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
|
||||
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
|
||||
sargs.slopes = factx.vtcm_slopes;
|
||||
|
||||
TIMER_START(softmax);
|
||||
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
|
||||
TIMER_STOP(softmax);
|
||||
|
||||
// V interleave (multi-thread HVX)
|
||||
TIMER_START(v_interleave);
|
||||
// FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout
|
||||
// stride to match o_update's v_tile access. Using per-block n_col_tiles
|
||||
// misplaces DV_tile 1..3 in the last partial KV block.
|
||||
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
|
||||
TIMER_STOP(v_interleave);
|
||||
|
||||
// O update (inline HMX on main thread)
|
||||
TIMER_START(o_update);
|
||||
{
|
||||
const size_t DV_tiles = (size_t) (DV / 32);
|
||||
const __fp16 * restrict d_base = factx.vtcm_d_tiles;
|
||||
@@ -1777,6 +1745,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
__builtin_assume(n_col_tiles > 0);
|
||||
__builtin_assume(DV_tiles > 0);
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
|
||||
for (size_t r = 0; r < n_row_tiles; ++r) {
|
||||
for (size_t c = 0; c < DV_tiles; ++c) {
|
||||
@@ -1798,16 +1767,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
Q6_mxmem_AR_after_hf(o_tile_out, 0);
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
|
||||
}
|
||||
TIMER_STOP(o_update);
|
||||
|
||||
buf_idx = 1 - buf_idx;
|
||||
} // end KV block loop (fallback)
|
||||
}
|
||||
|
||||
// ---- Final normalization: O = diag(1/l) @ O ----
|
||||
TIMER_START(o_norm);
|
||||
{
|
||||
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
|
||||
|
||||
@@ -1830,6 +1798,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
__builtin_assume(n_row_tiles > 0);
|
||||
__builtin_assume(DV_tiles > 0);
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
|
||||
for (size_t r = 0; r < n_row_tiles; ++r) {
|
||||
for (size_t c = 0; c < DV_tiles; ++c) {
|
||||
@@ -1842,14 +1811,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
Q6_mxmem_AR_after_hf(o_out, 0);
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
}
|
||||
TIMER_STOP(o_norm);
|
||||
|
||||
// ---- Store O block ----
|
||||
TIMER_START(o_store);
|
||||
fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g);
|
||||
TIMER_STOP(o_store);
|
||||
|
||||
#undef MASK_DMA_PUSH
|
||||
#undef MASK_DMA_POP
|
||||
@@ -1865,14 +1832,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
}
|
||||
|
||||
TIMER_STOP(total);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total),
|
||||
TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave));
|
||||
FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax),
|
||||
TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store));
|
||||
#endif
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
#include "hmx-ops.h"
|
||||
#include "hmx-utils.h"
|
||||
#include "hmx-queue.h"
|
||||
#include "hmx-profile.h"
|
||||
#include "hex-profile.h"
|
||||
|
||||
#include "vtcm-utils.h"
|
||||
|
||||
@@ -430,6 +430,7 @@ typedef struct {
|
||||
int n_tasks;
|
||||
int n_k_tiles;
|
||||
struct fastdiv_values n_k_tiles_div;
|
||||
struct htp_thread_trace * traces;
|
||||
} x4x2_dequantize_state_t;
|
||||
|
||||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||||
@@ -533,11 +534,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(
|
||||
\
|
||||
static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \
|
||||
int start = task_id * state->n_tiles_per_task; \
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \
|
||||
} \
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
|
||||
}
|
||||
|
||||
DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
|
||||
@@ -657,11 +661,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(
|
||||
|
||||
static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
}
|
||||
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
|
||||
@@ -717,11 +724,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
|
||||
|
||||
static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
}
|
||||
|
||||
static void convert_f16_weight_to_fp16_tiles_task(
|
||||
@@ -773,11 +783,14 @@ static void convert_f16_weight_to_fp16_tiles_task(
|
||||
|
||||
static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
convert_f16_weight_to_fp16_tiles_task(state, start, end);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
}
|
||||
|
||||
static void quantize_f32_weight_to_fp16_tiles_task(
|
||||
@@ -833,11 +846,14 @@ static void quantize_f32_weight_to_fp16_tiles_task(
|
||||
|
||||
static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
quantize_f32_weight_to_fp16_tiles_task(state, start, end);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
}
|
||||
|
||||
|
||||
@@ -868,6 +884,7 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
|
||||
state.weight_type = weight_type;
|
||||
state.n_k_tiles = n_k_tiles;
|
||||
state.n_k_tiles_div = n_k_tiles_div;
|
||||
state.traces = ctx ? ctx->trace : NULL;
|
||||
|
||||
if (state.n_tasks == 1 || n_threads == 1) {
|
||||
dequant_worker_fn(1, 0, &state);
|
||||
@@ -985,10 +1002,13 @@ typedef struct {
|
||||
int n_chunks_per_task;
|
||||
int n_cols;
|
||||
int n; // DDR row stride (total output columns)
|
||||
struct htp_thread_trace * traces;
|
||||
} output_transfer_task_state_t;
|
||||
|
||||
static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||||
output_transfer_task_state_t *st = (output_transfer_task_state_t *) data;
|
||||
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
|
||||
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
|
||||
int chunk_idx = task_id * st->n_chunks_per_task;
|
||||
@@ -998,6 +1018,7 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void
|
||||
const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols;
|
||||
transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
|
||||
}
|
||||
|
||||
static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src,
|
||||
@@ -1015,6 +1036,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
|
||||
state.vtcm_src = vtcm_src;
|
||||
state.n_cols = n_cols;
|
||||
state.n = n;
|
||||
state.traces = ctx ? ctx->trace : NULL;
|
||||
|
||||
if (state.n_tasks == 1 || n_threads == 1) {
|
||||
transfer_output_chunk_worker_fn(1, 0, &state);
|
||||
@@ -1086,10 +1108,13 @@ typedef struct {
|
||||
int n_chunks_per_task;
|
||||
int k_block;
|
||||
int k_stride;
|
||||
struct htp_thread_trace * traces;
|
||||
} activation_transfer_task_state_t;
|
||||
|
||||
static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||||
activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data;
|
||||
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
|
||||
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
|
||||
// one chunk: one row
|
||||
@@ -1100,6 +1125,7 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i,
|
||||
const float *src = st->src + chunk_idx * st->k_stride;
|
||||
transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
|
||||
}
|
||||
|
||||
static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) {
|
||||
@@ -1117,6 +1143,7 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *
|
||||
state.src = src;
|
||||
state.k_block = k_block;
|
||||
state.k_stride = k_stride;
|
||||
state.traces = ctx ? ctx->trace : NULL;
|
||||
|
||||
if (state.n_tasks == 1 || n_threads == 1) {
|
||||
transfer_activation_chunk_worker_fn(1, 0, &state);
|
||||
@@ -1245,13 +1272,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
|
||||
FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu",
|
||||
m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget);
|
||||
|
||||
TIMER_DEFINE(activation_load);
|
||||
TIMER_DEFINE(weight_load);
|
||||
TIMER_DEFINE(hmx_core);
|
||||
TIMER_DEFINE(output_store);
|
||||
|
||||
TIMER_DEFINE(total);
|
||||
TIMER_START(total);
|
||||
|
||||
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
|
||||
|
||||
@@ -1370,7 +1391,12 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
|
||||
|
||||
// C: HMX Compute (Synchronous)
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
|
||||
{
|
||||
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
|
||||
// D: Output Store
|
||||
float *output_chunk = dst + (mr * n + nc);
|
||||
@@ -1380,18 +1406,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
}
|
||||
|
||||
TIMER_STOP(total);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n);
|
||||
if (!use_pipeline) {
|
||||
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
|
||||
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
|
||||
size_t weight_size = (size_t)n * row_stride;
|
||||
float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load);
|
||||
FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth);
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1523,13 +1538,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
m_chunk_n_rows, n_chunk_n_cols,
|
||||
(size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
TIMER_DEFINE(activation_load);
|
||||
TIMER_DEFINE(weight_load);
|
||||
TIMER_DEFINE(hmx_core);
|
||||
TIMER_DEFINE(output_store);
|
||||
TIMER_DEFINE(total);
|
||||
|
||||
TIMER_START(total);
|
||||
|
||||
const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16);
|
||||
const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16);
|
||||
@@ -1549,7 +1558,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
// contiguous rows into a VTCM scratch buffer first, then HVX
|
||||
// converts from the contiguous VTCM buffer. This avoids L2 cache
|
||||
// thrashing from HVX loads at large strides.
|
||||
TIMER_START(activation_load);
|
||||
for (int g = 0; g < group_size; ++g) {
|
||||
const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride;
|
||||
__fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
|
||||
@@ -1569,7 +1577,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
params->k, params->act_stride, ctx->n_threads);
|
||||
}
|
||||
}
|
||||
TIMER_STOP(activation_load);
|
||||
|
||||
void *buf_curr = vtcm_scratch0;
|
||||
void *buf_next = vtcm_scratch1;
|
||||
@@ -1584,7 +1591,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols);
|
||||
const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
|
||||
|
||||
TIMER_START(weight_load);
|
||||
{
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
|
||||
@@ -1601,24 +1607,22 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
0, n_cols);
|
||||
hex_swap_ptr(&buf_curr, &buf_next);
|
||||
}
|
||||
TIMER_STOP(weight_load);
|
||||
|
||||
// Reuse the interleaved weight for every q_head in this GQA group
|
||||
for (int g = 0; g < group_size; ++g) {
|
||||
TIMER_START(hmx_core);
|
||||
{
|
||||
const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
|
||||
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles,
|
||||
params->k / 32);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
TIMER_STOP(hmx_core);
|
||||
|
||||
TIMER_START(output_store);
|
||||
{
|
||||
float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc;
|
||||
transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads);
|
||||
}
|
||||
TIMER_STOP(output_store);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1627,14 +1631,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
|
||||
TIMER_STOP(total);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total),
|
||||
params->m, params->k, params->n, group_size);
|
||||
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
|
||||
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1668,6 +1665,7 @@ typedef struct {
|
||||
size_t nb12;
|
||||
int start_row;
|
||||
int cne1;
|
||||
struct htp_thread_trace *traces;
|
||||
} activation_transfer_gathered_task_state_t;
|
||||
|
||||
typedef struct {
|
||||
@@ -1684,6 +1682,7 @@ typedef struct {
|
||||
size_t dst_nb2;
|
||||
int start_row;
|
||||
int cne1;
|
||||
struct htp_thread_trace *traces;
|
||||
} output_transfer_scattered_task_state_t;
|
||||
|
||||
static void transfer_activation_chunk_fp32_to_fp16_gathered(
|
||||
@@ -1780,6 +1779,9 @@ static void transfer_activation_chunk_fp32_to_fp16_gathered(
|
||||
|
||||
static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||||
activation_transfer_gathered_task_state_t *st = data;
|
||||
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
|
||||
|
||||
int chunk_idx = i;
|
||||
int chunk_size = st->n_chunks_per_task;
|
||||
int start_row = st->start_row + chunk_idx * chunk_size;
|
||||
@@ -1791,6 +1793,7 @@ static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigne
|
||||
st->matrix_rows, st->cur_a, st->mapping_stride,
|
||||
st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
|
||||
}
|
||||
|
||||
static void transfer_activation_chunk_gathered_threaded(
|
||||
@@ -1830,6 +1833,7 @@ static void transfer_activation_chunk_gathered_threaded(
|
||||
.nb12 = nb12,
|
||||
.start_row = start_row,
|
||||
.cne1 = cne1,
|
||||
.traces = ctx ? ctx->trace : NULL,
|
||||
};
|
||||
|
||||
if (actual_threads <= 1) {
|
||||
@@ -1895,6 +1899,9 @@ static void transfer_output_chunk_fp16_to_fp32_scattered(
|
||||
|
||||
static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||||
output_transfer_scattered_task_state_t *st = data;
|
||||
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
|
||||
|
||||
int chunk_idx = i;
|
||||
int chunk_size = st->n_chunks_per_task;
|
||||
int start_row = st->start_row + chunk_idx * chunk_size;
|
||||
@@ -1906,6 +1913,7 @@ static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned i
|
||||
st->matrix_rows, st->cur_a, st->mapping_stride,
|
||||
st->dst_nb1, st->dst_nb2, st->cne1);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
|
||||
}
|
||||
|
||||
static void transfer_output_chunk_scattered_threaded(
|
||||
@@ -1942,6 +1950,7 @@ static void transfer_output_chunk_scattered_threaded(
|
||||
.dst_nb2 = dst_nb2,
|
||||
.start_row = start_row,
|
||||
.cne1 = cne1,
|
||||
.traces = ctx ? ctx->trace : NULL,
|
||||
};
|
||||
|
||||
if (actual_threads <= 1) {
|
||||
@@ -2053,7 +2062,12 @@ int hmx_matmul_id_2d_f32(struct htp_context *ctx,
|
||||
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
|
||||
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
|
||||
{
|
||||
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
|
||||
transfer_output_chunk_scattered_threaded(
|
||||
ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols,
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
// Conditional fine-grained profiling macros for HMX operations.
|
||||
//
|
||||
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
|
||||
// header) to instrument sub-operation latencies with HAP qtimer. When the
|
||||
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
|
||||
// overhead.
|
||||
//
|
||||
// Usage:
|
||||
// TIMER_DEFINE(my_phase); // declare accumulator variable
|
||||
// TIMER_START(my_phase); // snapshot start time
|
||||
// ... work ...
|
||||
// TIMER_STOP(my_phase); // accumulate elapsed ticks
|
||||
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
|
||||
|
||||
#ifndef HMX_PROFILE_H
|
||||
#define HMX_PROFILE_H
|
||||
|
||||
#include <HAP_perf.h>
|
||||
|
||||
// #define ENABLE_PROFILE_TIMERS
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
|
||||
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
|
||||
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
|
||||
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
|
||||
#else
|
||||
# define TIMER_DEFINE(name)
|
||||
# define TIMER_START(name)
|
||||
# define TIMER_STOP(name)
|
||||
# define TIMER_US(name) 0LL
|
||||
#endif
|
||||
|
||||
#endif // HMX_PROFILE_H
|
||||
@@ -44,7 +44,9 @@ static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
|
||||
case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;
|
||||
default:
|
||||
hmx_lock(q);
|
||||
htp_trace_event_start(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
|
||||
d->func(d->data);
|
||||
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
#include "hex-profile.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@@ -47,6 +48,7 @@ struct hmx_queue {
|
||||
void * stack;
|
||||
uint32_t hap_rctx;
|
||||
bool hmx_locked;
|
||||
struct htp_thread_trace * trace;
|
||||
};
|
||||
|
||||
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "hex-dma.h"
|
||||
#include "hmx-queue.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hex-profile.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include <assert.h>
|
||||
@@ -70,6 +71,7 @@ struct htp_context {
|
||||
bool hmx_enabled;
|
||||
bool etm;
|
||||
uint32_t profiler;
|
||||
struct htp_thread_trace trace[HTP_MAX_NTHREADS + 1];
|
||||
|
||||
uint8_t * vtcm_base;
|
||||
size_t vtcm_size;
|
||||
|
||||
@@ -146,10 +146,36 @@ struct htp_op_desc {
|
||||
uint16_t dst; // Output tensor index
|
||||
};
|
||||
|
||||
#ifndef HTP_MAX_NTHREADS
|
||||
#define HTP_MAX_NTHREADS 10
|
||||
#endif
|
||||
|
||||
#define HTP_TRACE_MAX_EVENTS 256
|
||||
|
||||
enum htp_profiler_mode {
|
||||
HTP_PROF_DISABLED = 0,
|
||||
HTP_PROF_BASIC = 1,
|
||||
HTP_PROF_PMU = 2,
|
||||
HTP_PROF_TRACE = 3,
|
||||
};
|
||||
|
||||
enum htp_trace_event_id {
|
||||
HTP_TRACE_EVT_DMA = 0,
|
||||
|
||||
HTP_TRACE_EVT_HVX_COMP = 20,
|
||||
HTP_TRACE_EVT_HVX_A_QUANT = 21,
|
||||
HTP_TRACE_EVT_HVX_A_PREP = 22,
|
||||
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
|
||||
HTP_TRACE_EVT_HVX_W_PREP = 24,
|
||||
HTP_TRACE_EVT_HVX_O_PROC = 25,
|
||||
|
||||
HTP_TRACE_EVT_HMX_COMP = 40,
|
||||
};
|
||||
|
||||
struct htp_trace_desc {
|
||||
uint32_t cycles; // lower 32-bits of cycle counter
|
||||
uint16_t id; // Event ID
|
||||
uint16_t info; // bit 15: is_stop. bits 14-0: tile/chunk index or other metadata.
|
||||
};
|
||||
|
||||
#define HTP_PROF_PMU_NCNT 8
|
||||
@@ -158,8 +184,8 @@ enum htp_profiler_mode {
|
||||
struct htp_prof_desc {
|
||||
uint32_t opcode; // GGML/HTP Op
|
||||
uint32_t usecs; // Number of usec
|
||||
uint32_t cycles; // Number of cycles
|
||||
uint32_t pad; // Unused
|
||||
uint32_t cycles_start; // Start cycle counter
|
||||
uint32_t cycles_stop; // Stop cycle counter
|
||||
uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters
|
||||
};
|
||||
|
||||
@@ -168,7 +194,7 @@ struct htp_opbatch_req {
|
||||
uint32_t n_bufs; // Number of buffers
|
||||
uint32_t n_tensors; // Number of tensors
|
||||
uint32_t n_ops; // Number of ops
|
||||
uint32_t flags; // unused
|
||||
uint32_t n_traces; // Number of trace descriptors per thread
|
||||
uint32_t pad; // unused
|
||||
// struct htp_buf_desc bufs[]; -- dspqueue buf 0
|
||||
// struct htp_tensor tensors[]; -- dspqueue buf 0
|
||||
@@ -181,7 +207,8 @@ struct htp_opbatch_rsp {
|
||||
uint32_t n_bufs; // Number of buffers
|
||||
uint32_t n_tensors; // Number of tensors
|
||||
uint32_t n_ops; // Number of op profile descriptors
|
||||
uint32_t pad; // unused
|
||||
uint32_t n_traces[HTP_MAX_NTHREADS + 1];
|
||||
uint8_t pad[8]; // align to 8 bytes
|
||||
// struct htp_prof_desc profs[]; -- dspqueue buf 0
|
||||
};
|
||||
|
||||
|
||||
@@ -400,7 +400,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
ctx->hmx_queue = NULL;
|
||||
if (use_hmx) {
|
||||
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
|
||||
if (!ctx->hmx_queue) {
|
||||
if (ctx->hmx_queue) {
|
||||
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
|
||||
} else {
|
||||
FARF(ERROR, "hmx-queue-create failed");
|
||||
ctx->hmx_enabled = false;
|
||||
}
|
||||
@@ -425,6 +427,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
ctx->n_threads = n_hvx;
|
||||
for (int i = 0; i < ctx->n_threads; i++) {
|
||||
ctx->dma[i] = dma_queue_create(256); // queue depth
|
||||
if (ctx->dma[i]) {
|
||||
ctx->dma[i]->trace = &ctx->trace[i];
|
||||
}
|
||||
}
|
||||
|
||||
ctx->ddr_spad_size = 512 * 1024; // 512 KB
|
||||
@@ -502,7 +507,8 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) {
|
||||
|
||||
struct profile_data {
|
||||
uint64_t usecs;
|
||||
uint64_t cycles;
|
||||
uint64_t cycles_start;
|
||||
uint64_t cycles_stop;
|
||||
uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS];
|
||||
};
|
||||
|
||||
@@ -512,8 +518,9 @@ static inline void profile_start(uint32_t mode, struct profile_data * d) {
|
||||
hex_get_pmu(d->pmu_counters);
|
||||
// fallthrough
|
||||
case HTP_PROF_BASIC:
|
||||
case HTP_PROF_TRACE:
|
||||
d->usecs = HAP_perf_get_qtimer_count();
|
||||
d->cycles = hex_get_cycles();
|
||||
d->cycles_start = hex_get_cycles();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -530,8 +537,9 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
|
||||
}
|
||||
// fallthrough
|
||||
case HTP_PROF_BASIC:
|
||||
case HTP_PROF_TRACE:
|
||||
d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
|
||||
d->cycles = hex_get_cycles() - d->cycles;
|
||||
d->cycles_stop = hex_get_cycles();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -845,14 +853,15 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
const uint32_t t_size = sizeof(struct htp_tensor) * n_tens;
|
||||
const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops;
|
||||
const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops;
|
||||
const uint32_t tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(struct htp_trace_desc);
|
||||
|
||||
if (dbuf.size < b_size + t_size + o_size + p_size) {
|
||||
FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size);
|
||||
if (dbuf.size < b_size + t_size + o_size + p_size + tr_size) {
|
||||
FARF(ERROR, "invalid opbatch memory block size %u (req %u)", dbuf.size, b_size + t_size + o_size + p_size + tr_size);
|
||||
break;
|
||||
}
|
||||
|
||||
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id,
|
||||
n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size);
|
||||
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u n-traces %u : m-size %u b-size %u t-size %u o-size %u", req.id,
|
||||
n_bufs, n_tens, n_ops, req.n_traces, dbuf.size, b_size, t_size, o_size);
|
||||
|
||||
// Setup descriptor pointers
|
||||
uint8_t * m_ptr = dbuf.ptr;
|
||||
@@ -869,6 +878,20 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
octx->n_threads = ctx->n_threads;
|
||||
octx->ctx = ctx;
|
||||
|
||||
if (ctx->profiler == HTP_PROF_TRACE) {
|
||||
memset(ctx->trace, 0, sizeof(ctx->trace));
|
||||
struct htp_trace_desc * trace_events = (struct htp_trace_desc *) (m_ptr + p_size);
|
||||
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
ctx->trace[t].events = &trace_events[t * req.n_traces];
|
||||
ctx->trace[t].max_events = req.n_traces;
|
||||
}
|
||||
} else {
|
||||
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
ctx->trace[t].events = NULL;
|
||||
ctx->trace[t].max_events = 0;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i=0; i < n_ops; i++) {
|
||||
struct profile_data prof;
|
||||
|
||||
@@ -886,7 +909,8 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
if (ctx->profiler) {
|
||||
pds[i].opcode = ops[i].opcode;
|
||||
pds[i].usecs = prof.usecs;
|
||||
pds[i].cycles = prof.cycles;
|
||||
pds[i].cycles_start = prof.cycles_start;
|
||||
pds[i].cycles_stop = prof.cycles_stop;
|
||||
for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) {
|
||||
pds[i].pmu[j] = prof.pmu_counters[j];
|
||||
}
|
||||
@@ -899,6 +923,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
rsp.n_bufs = n_bufs;
|
||||
rsp.n_tensors = n_tens;
|
||||
rsp.n_ops = n_ops;
|
||||
memset(rsp.pad, 0, sizeof(rsp.pad));
|
||||
if (ctx->profiler == HTP_PROF_TRACE) {
|
||||
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
rsp.n_traces[t] = ctx->trace[t].count;
|
||||
}
|
||||
} else {
|
||||
memset(rsp.n_traces, 0, sizeof(rsp.n_traces));
|
||||
}
|
||||
|
||||
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
|
||||
|
||||
|
||||
@@ -3350,6 +3350,7 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
|
||||
|
||||
static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -3411,10 +3412,12 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
|
||||
float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
||||
|
||||
const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
|
||||
for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
|
||||
const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
|
||||
mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3430,6 +3433,7 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
|
||||
// src1 tensor is already in VTCM spad
|
||||
static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
|
||||
@@ -3477,6 +3481,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Process src1 columns in pairs (2×2 tiling)
|
||||
uint32_t ir1 = 0;
|
||||
for (; ir1 + 1 < src1_nrows; ir1 += 2) {
|
||||
@@ -3494,6 +3500,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
|
||||
}
|
||||
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
||||
@@ -3511,12 +3519,14 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
src0_stride, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
#pragma unroll(2)
|
||||
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
|
||||
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
|
||||
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
|
||||
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
@@ -3530,6 +3540,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
// q8x4x2 src1 tensor is already in VTCM spad
|
||||
static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const uint32_t src0_nrows = ne01;
|
||||
|
||||
@@ -3581,7 +3592,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
// Process src0 rows
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
@@ -3599,7 +3612,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
||||
src0_stride, src0_row_size, 2);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
ir0 += 2;
|
||||
}
|
||||
if (ir0 < src0_end_row) {
|
||||
@@ -3607,7 +3622,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
||||
src0_stride, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
ir0 += 1;
|
||||
}
|
||||
} else {
|
||||
@@ -3627,7 +3644,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
// Process src0 rows
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
@@ -3645,7 +3664,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
||||
src0_stride, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3669,6 +3690,7 @@ struct mmid_row_mapping {
|
||||
// src1 tensor is already in VTCM spad
|
||||
static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
@@ -3735,6 +3757,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
for (uint32_t cid = 0; cid < cne1; ++cid) {
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
|
||||
const int rm1 = row_mapping.i1; // expert idx
|
||||
@@ -3746,6 +3769,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
@@ -3764,6 +3788,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
src0_row_size_padded, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
for (uint32_t cid = 0; cid < cne1; ++cid) {
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
|
||||
const int rm1 = row_mapping.i1; // expert idx
|
||||
@@ -3775,6 +3800,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3789,6 +3815,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
// src1 tensor is already in VTCM spad
|
||||
static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
@@ -3847,7 +3874,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
// Process src0 rows
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
@@ -3865,7 +3894,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
|
||||
src0_row_size_padded, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4147,6 +4178,7 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui
|
||||
static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4163,6 +4195,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = src->nb[1];
|
||||
@@ -4189,6 +4222,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
|
||||
|
||||
FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
|
||||
@@ -4219,6 +4253,7 @@ static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y,
|
||||
static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4235,6 +4270,7 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = src->nb[1];
|
||||
@@ -4260,11 +4296,13 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
|
||||
|
||||
FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4281,6 +4319,7 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = ne0 * sizeof(float);
|
||||
@@ -4301,11 +4340,13 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4322,6 +4363,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = ne0 * sizeof(float);
|
||||
@@ -4342,12 +4384,14 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
// TODO just a plain copy that should be done via the DMA during the Op setup
|
||||
static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4364,6 +4408,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = ne0 * sizeof(float);
|
||||
@@ -4384,6 +4429,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import argparse
|
||||
import statistics
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
@@ -25,12 +26,47 @@ COL_MAP = {
|
||||
}
|
||||
|
||||
op_pattern = re.compile(
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?"
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
)
|
||||
|
||||
trace_pattern = re.compile(
|
||||
r"trace-op\s+(?P<op_name>[A-Z_0-9+]+):\s+thread\s+(?P<thread>\d+)\s+event\s+(?P<event>[A-Z_0-9\-]+)\s+info\s+(?P<info>\d+)\s+(?P<state>start|stop)\s+(?P<cycles>\d+)"
|
||||
)
|
||||
|
||||
logger = logging.getLogger("ggml-hexagon-profile")
|
||||
|
||||
|
||||
def normalize_event_name(evt_type):
|
||||
if evt_type == "HVX_COMP":
|
||||
return "V-COMP"
|
||||
if evt_type == "HMX_COMP":
|
||||
return "M-COMP"
|
||||
|
||||
# Strip HVX_ or HMX_ prefixes
|
||||
name = evt_type
|
||||
if name.startswith("HVX_") or name.startswith("HMX_"):
|
||||
name = name[4:]
|
||||
return name.replace("_", "-")
|
||||
|
||||
|
||||
class CycleUnwrapper:
|
||||
def __init__(self):
|
||||
self.last_raw = None
|
||||
self.high_part = 0
|
||||
|
||||
def unwrap(self, raw):
|
||||
if self.last_raw is None:
|
||||
self.last_raw = raw
|
||||
return raw
|
||||
diff = raw - self.last_raw
|
||||
if diff < -0x80000000:
|
||||
self.high_part += 0x100000000
|
||||
elif diff > 0x80000000:
|
||||
self.high_part -= 0x100000000
|
||||
self.last_raw = raw
|
||||
return raw + self.high_part
|
||||
|
||||
|
||||
def parse_log(file_path, pmu_index=None):
|
||||
try:
|
||||
if file_path != "-":
|
||||
@@ -41,35 +77,211 @@ def parse_log(file_path, pmu_index=None):
|
||||
logger.error(f"file '{file_path}' not found.")
|
||||
sys.exit(1)
|
||||
|
||||
all_ops = []
|
||||
all_ops: List[Dict[str, Any]] = []
|
||||
current_op: Optional[Dict[str, Any]] = None
|
||||
|
||||
timestamp_pattern = re.compile(r"^(?P<min>\d+)\.(?P<sec>\d+)\.(?P<ms>\d+)\.(?P<us>\d+)\s+[A-Z]\s+")
|
||||
unwrapper = CycleUnwrapper()
|
||||
|
||||
for line in f:
|
||||
match = op_pattern.search(line)
|
||||
if not match: continue
|
||||
ts_match = timestamp_pattern.match(line)
|
||||
abs_usec = 0
|
||||
if ts_match:
|
||||
abs_usec = (
|
||||
(int(ts_match.group('min')) * 60 + int(ts_match.group('sec'))) * 1000000
|
||||
+ int(ts_match.group('ms')) * 1000
|
||||
+ int(ts_match.group('us'))
|
||||
)
|
||||
|
||||
pmu_raw = match.group('pmu')
|
||||
pmu_val = None
|
||||
if pmu_raw and pmu_index is not None:
|
||||
try:
|
||||
pmu_list = [int(x.strip()) for x in pmu_raw.split(',')]
|
||||
if len(pmu_list) > pmu_index:
|
||||
pmu_val = pmu_list[pmu_index]
|
||||
except (ValueError, IndexError):
|
||||
pmu_val = None
|
||||
op_match = op_pattern.search(line)
|
||||
if op_match:
|
||||
pmu_raw = op_match.group('pmu')
|
||||
pmu_val = None
|
||||
if pmu_raw and pmu_index is not None:
|
||||
try:
|
||||
pmu_list = [int(x.strip()) for x in pmu_raw.split(',')]
|
||||
if len(pmu_list) > pmu_index:
|
||||
pmu_val = pmu_list[pmu_index]
|
||||
except (ValueError, IndexError):
|
||||
pmu_val = None
|
||||
|
||||
all_ops.append({
|
||||
'name': match.group('op_name'),
|
||||
'dims': match.group('dims').strip(),
|
||||
'types': match.group('types').strip(),
|
||||
'usec': int(match.group('usec')),
|
||||
'cycles': int(match.group('cycles')),
|
||||
'pmu_val': pmu_val
|
||||
})
|
||||
evt_raw = op_match.group('evt')
|
||||
evt_val = None
|
||||
if evt_raw:
|
||||
try:
|
||||
evt_val = [int(x.strip()) for x in evt_raw.split(',')]
|
||||
except ValueError:
|
||||
evt_val = None
|
||||
|
||||
cycles_start_raw = op_match.group('start')
|
||||
unwrapped_cycles_start = None
|
||||
if cycles_start_raw:
|
||||
unwrapped_cycles_start = unwrapper.unwrap(int(cycles_start_raw))
|
||||
|
||||
idx = line.find("profile-op ")
|
||||
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
|
||||
|
||||
current_op = {
|
||||
'name': op_match.group('op_name'),
|
||||
'dims': op_match.group('dims').strip(),
|
||||
'types': op_match.group('types').strip(),
|
||||
'op_text': op_text,
|
||||
'usec': int(op_match.group('usec')),
|
||||
'cycles': int(op_match.group('cycles')),
|
||||
'cycles_start': int(cycles_start_raw) if cycles_start_raw else None,
|
||||
'unwrapped_cycles_start': unwrapped_cycles_start,
|
||||
'pmu_val': pmu_val,
|
||||
'evt_val': evt_val,
|
||||
'abs_usec': abs_usec,
|
||||
'trace_events': []
|
||||
}
|
||||
all_ops.append(current_op)
|
||||
continue
|
||||
|
||||
trace_match = trace_pattern.search(line)
|
||||
if trace_match and current_op:
|
||||
if trace_match.group('op_name') == current_op['name']:
|
||||
raw_cyc = int(trace_match.group('cycles'))
|
||||
current_op['trace_events'].append({
|
||||
'thread': int(trace_match.group('thread')),
|
||||
'event': trace_match.group('event'),
|
||||
'info': int(trace_match.group('info')),
|
||||
'cycles': raw_cyc,
|
||||
'unwrapped_cycles': unwrapper.unwrap(raw_cyc),
|
||||
'state': trace_match.group('state')
|
||||
})
|
||||
|
||||
f.close()
|
||||
|
||||
return all_ops
|
||||
|
||||
|
||||
def print_ascii_timeline(op_name, dims, types, usec, cycles, events, evt_val=None):
|
||||
evt_str = ""
|
||||
if evt_val:
|
||||
evt_str = " - evt [" + ",".join(str(x) for x in evt_val) + "]"
|
||||
logger.info("=" * 100)
|
||||
logger.info(f"{op_name} ({dims} : {types}) - {usec} usec {cycles} cycles{evt_str}")
|
||||
logger.info("=" * 100)
|
||||
|
||||
events = sorted(events, key=lambda e: e['cycles'])
|
||||
if not events:
|
||||
logger.info(" No trace events recorded.")
|
||||
return
|
||||
|
||||
min_cycles = events[0]['cycles']
|
||||
|
||||
logger.info("Cycles %-30s" % "EventDetails" + " ".join(f"T{i:<2}" for i in range(10)) + " HMX")
|
||||
logger.info("-" * 100)
|
||||
|
||||
thread_stacks = [[] for _ in range(11)]
|
||||
|
||||
for e in events:
|
||||
t = e['thread']
|
||||
if t < 0 or t > 10:
|
||||
continue
|
||||
|
||||
if e['cycles'] >= min_cycles:
|
||||
rel_cycles = e['cycles'] - min_cycles
|
||||
else:
|
||||
rel_cycles = (e['cycles'] + 0x100000000) - min_cycles
|
||||
|
||||
state = e['state']
|
||||
evt_type = e['event']
|
||||
|
||||
# Determine char representing the event
|
||||
norm_evt = normalize_event_name(evt_type)
|
||||
char = '?'
|
||||
if norm_evt == 'V-COMP':
|
||||
char = 'V'
|
||||
elif norm_evt == 'M-COMP':
|
||||
char = 'H'
|
||||
elif norm_evt == 'A-QUANT':
|
||||
char = 'Q'
|
||||
elif norm_evt == 'A-PREP':
|
||||
char = 'A'
|
||||
elif norm_evt == 'W-DEQUANT':
|
||||
char = 'D'
|
||||
elif norm_evt == 'O-PROC':
|
||||
char = 'O'
|
||||
elif norm_evt == 'W-PREP':
|
||||
char = 'P'
|
||||
elif norm_evt == 'DMA':
|
||||
char = 'M'
|
||||
|
||||
if state == 'start':
|
||||
thread_stacks[t].append(char)
|
||||
elif state == 'stop':
|
||||
if thread_stacks[t]:
|
||||
if thread_stacks[t][-1] == char:
|
||||
thread_stacks[t].pop()
|
||||
elif char in thread_stacks[t]:
|
||||
thread_stacks[t].remove(char)
|
||||
else:
|
||||
thread_stacks[t].pop()
|
||||
|
||||
cols = []
|
||||
for i in range(11):
|
||||
if thread_stacks[i]:
|
||||
cols.append(f"[{thread_stacks[i][-1]}]")
|
||||
else:
|
||||
cols.append(" | ")
|
||||
|
||||
evt_desc = f"T{t}: {evt_type} {state} ({e['info']})"
|
||||
logger.info(f"{rel_cycles:10d} %-30s" % evt_desc + " ".join(cols[:10]) + " " + cols[10])
|
||||
logger.info("-" * 100)
|
||||
|
||||
|
||||
def print_ascii_summary(op_name, dims, types, usec, cycles, events, evt_val=None):
|
||||
evt_str = ""
|
||||
if evt_val:
|
||||
evt_str = " - evt [" + ",".join(str(x) for x in evt_val) + "]"
|
||||
logger.info("=" * 100)
|
||||
logger.info(f"{op_name} ({dims} : {types}) - {usec} usec {cycles} cycles{evt_str}")
|
||||
logger.info("=" * 100)
|
||||
|
||||
events = sorted(events, key=lambda e: e['cycles'])
|
||||
if not events:
|
||||
logger.info(" No trace events recorded.")
|
||||
return
|
||||
|
||||
active_starts = {}
|
||||
thread_totals = defaultdict(lambda: defaultdict(int))
|
||||
|
||||
for e in events:
|
||||
t = e['thread']
|
||||
evt = e['event']
|
||||
info = e['info']
|
||||
cyc = e['cycles']
|
||||
state = e['state']
|
||||
|
||||
key = (t, evt, info)
|
||||
if state == 'start':
|
||||
active_starts[key] = cyc
|
||||
elif state == 'stop':
|
||||
if key in active_starts:
|
||||
start_cyc = active_starts[key]
|
||||
del active_starts[key]
|
||||
|
||||
if cyc >= start_cyc:
|
||||
dur = cyc - start_cyc
|
||||
else:
|
||||
dur = (cyc + 0x100000000) - start_cyc
|
||||
|
||||
norm_evt = normalize_event_name(evt)
|
||||
thread_totals[t][norm_evt] += dur
|
||||
|
||||
for t in sorted(thread_totals.keys()):
|
||||
thread_name = f"Thread {t} (HVX)" if t != 10 else "Thread 10 (HMX)"
|
||||
sorted_evts = sorted(thread_totals[t].items(), key=lambda item: item[0])
|
||||
|
||||
evt_strs = []
|
||||
for evt, dur in sorted_evts:
|
||||
pct = (dur / cycles * 100) if cycles > 0 else 0
|
||||
evt_strs.append(f"{evt} {dur} ({pct:.1f}%)")
|
||||
|
||||
logger.info(f" {thread_name:<16}: " + " | ".join(evt_strs))
|
||||
|
||||
|
||||
def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
|
||||
if not ops:
|
||||
logger.info("No valid records found.")
|
||||
@@ -115,7 +327,6 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
|
||||
|
||||
# Sorting logic
|
||||
actual_sort_key = COL_MAP[sort_col][2]
|
||||
# We sort numeric fields descending, strings (op/dims) ascending
|
||||
is_numeric = actual_sort_key.startswith("_") or actual_sort_key == "count"
|
||||
sorted_groups = sorted(group_stats, key=lambda x: x[actual_sort_key], reverse=is_numeric)[:top_n]
|
||||
|
||||
@@ -132,7 +343,7 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
|
||||
if "pmu" in col_name and pmu_name:
|
||||
header_text = header_text.replace("PMU", pmu_name)
|
||||
|
||||
natural_width = max([len(row[data_key]) for row in sorted_groups] + [len(header_text)])
|
||||
natural_width = max([len(str(row[data_key])) for row in sorted_groups] + [len(header_text)])
|
||||
target_width = width_overrides.get(col_name, natural_width)
|
||||
|
||||
if target_width == 0:
|
||||
@@ -152,7 +363,7 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
|
||||
for group in sorted_groups:
|
||||
row_vals = []
|
||||
for i, key in enumerate(final_keys):
|
||||
val = group[key]
|
||||
val = str(group[key])
|
||||
if len(val) > final_widths[i]:
|
||||
val = val[:final_widths[i] - 3] + "..."
|
||||
row_vals.append(f"{val:<{final_widths[i]}}")
|
||||
@@ -167,12 +378,18 @@ def main():
|
||||
parser.add_argument("--pmu-index", type=int)
|
||||
parser.add_argument("--pmu-name", type=str)
|
||||
parser.add_argument("--width", action='append', default=['dims:40'], help="Override column width, e.g. --width dims:50")
|
||||
parser.add_argument("--timeline", type=str, nargs='?', const='summary', choices=["summary", "diagram"],
|
||||
help="Output ASCII art event summary or timing diagram (default: summary)")
|
||||
parser.add_argument("--filter", type=str, help="Regex filter matching against the original profile-op line")
|
||||
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--head", type=int, help="Limit to first N ops")
|
||||
group.add_argument("--tail", type=int, help="Limit to last N ops")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
|
||||
# Sort validation: can't sort by PMU if index isn't provided
|
||||
if "pmu" in args.sort and args.pmu_index is None:
|
||||
logger.error(f"Cannot sort by '{args.sort}' without --pmu-index.")
|
||||
sys.exit(1)
|
||||
@@ -188,7 +405,33 @@ def main():
|
||||
|
||||
final_pmu_name = (args.pmu_name or f"#{args.pmu_index}") if args.pmu_index is not None else None
|
||||
ops = parse_log(args.logfile, pmu_index=args.pmu_index)
|
||||
generate_report(ops, args.top, overrides, args.sort, pmu_name=final_pmu_name)
|
||||
|
||||
if args.filter:
|
||||
try:
|
||||
filter_re = re.compile(args.filter)
|
||||
except re.error as e:
|
||||
logger.error(f"Invalid regex filter: {e}")
|
||||
sys.exit(1)
|
||||
ops = [op for op in ops if filter_re.search(op['op_text'])]
|
||||
|
||||
if args.head is not None:
|
||||
ops = ops[:args.head]
|
||||
elif args.tail is not None:
|
||||
ops = ops[-args.tail:]
|
||||
|
||||
if args.timeline:
|
||||
logger.info(f"\n# ASCII Timing {args.timeline.capitalize()}\n")
|
||||
printed_cnt = 0
|
||||
for op in ops:
|
||||
if args.timeline == "summary":
|
||||
print_ascii_summary(op['name'], op['dims'], op['types'], op['usec'], op['cycles'], op['trace_events'], op.get('evt_val'))
|
||||
elif args.timeline == "diagram":
|
||||
print_ascii_timeline(op['name'], op['dims'], op['types'], op['usec'], op['cycles'], op['trace_events'], op.get('evt_val'))
|
||||
printed_cnt += 1
|
||||
if printed_cnt >= args.top:
|
||||
break
|
||||
else:
|
||||
generate_report(ops, args.top, overrides, args.sort, pmu_name=final_pmu_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Executable
+463
@@ -0,0 +1,463 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
import statistics
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
logger = logging.getLogger("ggml-hexagon-trace")
|
||||
|
||||
op_pattern = re.compile(
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+(?P<strides>[\d:x\s\->!]+)\s+:\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
)
|
||||
|
||||
trace_pattern = re.compile(
|
||||
r"trace-op\s+(?P<op_name>[A-Z_0-9+]+):\s+thread\s+(?P<thread>\d+)\s+event\s+(?P<event>[A-Z_0-9\-]+)\s+info\s+(?P<info>\d+)\s+(?P<state>start|stop)\s+(?P<cycles>\d+)"
|
||||
)
|
||||
|
||||
|
||||
def normalize_event_name(evt_type):
|
||||
if evt_type == "HVX_COMP":
|
||||
return "V-COMP"
|
||||
if evt_type == "HMX_COMP":
|
||||
return "M-COMP"
|
||||
name = evt_type
|
||||
if name.startswith("HVX_") or name.startswith("HMX_"):
|
||||
name = name[4:]
|
||||
return name.replace("_", "-")
|
||||
|
||||
|
||||
class CycleUnwrapper:
|
||||
def __init__(self):
|
||||
self.last_raw = None
|
||||
self.high_part = 0
|
||||
|
||||
def unwrap(self, raw):
|
||||
if self.last_raw is None:
|
||||
self.last_raw = raw
|
||||
return raw
|
||||
diff = raw - self.last_raw
|
||||
if diff < -0x80000000:
|
||||
self.high_part += 0x100000000
|
||||
elif diff > 0x80000000:
|
||||
self.high_part -= 0x100000000
|
||||
self.last_raw = raw
|
||||
return raw + self.high_part
|
||||
|
||||
|
||||
def parse_log(file_path):
|
||||
try:
|
||||
if file_path != "-":
|
||||
f = open(file_path, 'r', encoding='utf-8', errors='ignore')
|
||||
else:
|
||||
f = os.fdopen(0, 'r', encoding='utf-8', errors='ignore')
|
||||
except FileNotFoundError:
|
||||
logger.error(f"file '{file_path}' not found.")
|
||||
sys.exit(1)
|
||||
|
||||
all_ops: List[Dict[str, Any]] = []
|
||||
current_op: Optional[Dict[str, Any]] = None
|
||||
unwrapper = CycleUnwrapper()
|
||||
line_idx = 0
|
||||
|
||||
for line in f:
|
||||
line_idx += 1
|
||||
op_match = op_pattern.search(line)
|
||||
if op_match:
|
||||
cycles_start_raw = op_match.group('start')
|
||||
unwrapped_cycles_start = None
|
||||
if cycles_start_raw:
|
||||
unwrapped_cycles_start = unwrapper.unwrap(int(cycles_start_raw))
|
||||
|
||||
idx = line.find("profile-op ")
|
||||
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
|
||||
|
||||
current_op = {
|
||||
'name': op_match.group('op_name'),
|
||||
'dims': op_match.group('dims').strip() if op_match.group('dims') else '',
|
||||
'types': op_match.group('types').strip() if op_match.group('types') else '',
|
||||
'strides': op_match.group('strides').strip() if op_match.group('strides') else '',
|
||||
'op_text': op_text,
|
||||
'usec': int(op_match.group('usec')),
|
||||
'cycles': int(op_match.group('cycles')),
|
||||
'cycles_start': int(cycles_start_raw) if cycles_start_raw else None,
|
||||
'unwrapped_cycles_start': unwrapped_cycles_start,
|
||||
'trace_events': [],
|
||||
'line_num': line_idx
|
||||
}
|
||||
all_ops.append(current_op)
|
||||
continue
|
||||
|
||||
trace_match = trace_pattern.search(line)
|
||||
if trace_match and current_op:
|
||||
if trace_match.group('op_name') == current_op['name']:
|
||||
raw_cyc = int(trace_match.group('cycles'))
|
||||
current_op['trace_events'].append({
|
||||
'thread': int(trace_match.group('thread')),
|
||||
'event': trace_match.group('event'),
|
||||
'info': int(trace_match.group('info')),
|
||||
'cycles': raw_cyc,
|
||||
'unwrapped_cycles': unwrapper.unwrap(raw_cyc),
|
||||
'state': trace_match.group('state')
|
||||
})
|
||||
|
||||
f.close()
|
||||
return all_ops
|
||||
|
||||
# --- Simple protobuf encoder ---
|
||||
|
||||
|
||||
def write_varint(val):
|
||||
if val < 0:
|
||||
val = (1 << 64) + val
|
||||
res = bytearray()
|
||||
while True:
|
||||
towrite = val & 0x7f
|
||||
val >>= 7
|
||||
if val > 0:
|
||||
res.append(towrite | 0x80)
|
||||
else:
|
||||
res.append(towrite)
|
||||
break
|
||||
return bytes(res)
|
||||
|
||||
|
||||
def pb_field(num, wire, data):
|
||||
return write_varint((num << 3) | wire) + data
|
||||
|
||||
|
||||
def pb_varint(num, val):
|
||||
return pb_field(num, 0, write_varint(val))
|
||||
|
||||
|
||||
def pb_length_delimited(num, data):
|
||||
return pb_field(num, 2, write_varint(len(data)) + data)
|
||||
|
||||
|
||||
def pb_string(num, text):
|
||||
return pb_length_delimited(num, text.encode('utf-8'))
|
||||
|
||||
|
||||
# Message Encoders
|
||||
def make_process_descriptor(pid, name):
|
||||
return pb_varint(1, pid) + pb_string(6, name)
|
||||
|
||||
|
||||
def make_thread_descriptor(pid, tid, name, sort_index=None):
|
||||
payload = pb_varint(1, pid) + pb_varint(2, tid) + pb_string(5, name)
|
||||
if sort_index is not None:
|
||||
payload += pb_varint(3, sort_index)
|
||||
return payload
|
||||
|
||||
|
||||
def make_track_descriptor(uuid, name=None, parent_uuid=None, thread=None, process=None, sibling_merge_behavior=None, child_ordering=None, sibling_order_rank=None):
|
||||
payload = pb_varint(1, uuid)
|
||||
if name is not None:
|
||||
payload += pb_string(2, name)
|
||||
if parent_uuid is not None:
|
||||
payload += pb_varint(5, parent_uuid)
|
||||
if process is not None:
|
||||
payload += pb_length_delimited(3, process)
|
||||
if thread is not None:
|
||||
payload += pb_length_delimited(4, thread)
|
||||
if sibling_merge_behavior is not None:
|
||||
payload += pb_varint(15, sibling_merge_behavior)
|
||||
if child_ordering is not None:
|
||||
payload += pb_varint(11, child_ordering)
|
||||
if sibling_order_rank is not None:
|
||||
payload += pb_varint(12, sibling_order_rank)
|
||||
return payload
|
||||
|
||||
|
||||
def make_debug_annotation(name, string_val=None, int_val=None):
|
||||
payload = pb_string(10, name)
|
||||
if string_val is not None:
|
||||
payload += pb_string(6, string_val)
|
||||
elif int_val is not None:
|
||||
payload += pb_varint(4, int_val)
|
||||
return payload
|
||||
|
||||
|
||||
def make_track_event(event_type, track_uuid, name=None, category=None, debug_annotations=None):
|
||||
payload = pb_varint(9, event_type)
|
||||
payload += pb_varint(11, track_uuid)
|
||||
if name is not None:
|
||||
payload += pb_string(23, name)
|
||||
if category is not None:
|
||||
payload += pb_string(22, category)
|
||||
if debug_annotations is not None:
|
||||
for da in debug_annotations:
|
||||
payload += pb_length_delimited(4, da)
|
||||
return payload
|
||||
|
||||
|
||||
def make_trace_packet(timestamp, track_event=None, track_descriptor=None, seq_id=1):
|
||||
payload = pb_varint(8, timestamp)
|
||||
payload += pb_varint(10, seq_id)
|
||||
if track_event is not None:
|
||||
payload += pb_length_delimited(11, track_event)
|
||||
if track_descriptor is not None:
|
||||
payload += pb_length_delimited(60, track_descriptor)
|
||||
return payload
|
||||
|
||||
|
||||
def write_trace_packet_to_file(f, packet_bytes):
|
||||
# Write as field 1 of top-level Trace message
|
||||
f.write(pb_length_delimited(1, packet_bytes))
|
||||
|
||||
# --- End Protobuf Encoder ---
|
||||
|
||||
|
||||
def generate_perfetto_trace(filtered_ops, output_path):
|
||||
if not filtered_ops:
|
||||
logger.warning("No operators found after filtering.")
|
||||
return
|
||||
|
||||
# Compute average frequency
|
||||
frequencies = []
|
||||
for op in filtered_ops:
|
||||
if op['usec'] > 0 and op['cycles'] > 0:
|
||||
frequencies.append(op['cycles'] / op['usec'])
|
||||
avg_freq_mhz = statistics.mean(frequencies) if frequencies else 1000.0
|
||||
if avg_freq_mhz <= 0:
|
||||
avg_freq_mhz = 1000.0
|
||||
|
||||
# Assign start and end cycles to each operator
|
||||
for op in filtered_ops:
|
||||
op['start_cycles'] = op['unwrapped_cycles_start']
|
||||
op['end_cycles'] = op['start_cycles'] + op['cycles']
|
||||
|
||||
global_min_cyc = min(op['start_cycles'] for op in filtered_ops if op['start_cycles'] is not None)
|
||||
|
||||
# Process events
|
||||
completed_events = []
|
||||
for op in filtered_ops:
|
||||
events = op['trace_events']
|
||||
if not events:
|
||||
continue
|
||||
events = sorted(events, key=lambda e: e['unwrapped_cycles'])
|
||||
|
||||
active_starts = {}
|
||||
for e in events:
|
||||
t = e['thread']
|
||||
evt = e['event']
|
||||
info = e['info']
|
||||
state = e['state']
|
||||
cyc = e['unwrapped_cycles']
|
||||
|
||||
key = (t, evt, info)
|
||||
if state == 'start':
|
||||
active_starts[key] = cyc
|
||||
elif state == 'stop':
|
||||
if key in active_starts:
|
||||
start_cyc = active_starts[key]
|
||||
del active_starts[key]
|
||||
completed_events.append({
|
||||
'thread': t,
|
||||
'event': evt,
|
||||
'info': info,
|
||||
'start_cyc': start_cyc,
|
||||
'end_cyc': cyc,
|
||||
'op_name': op['name']
|
||||
})
|
||||
|
||||
completed_events.sort(key=lambda e: e['start_cyc'])
|
||||
|
||||
# Convert event times to microseconds and apply clamp rounded to 1ns resolution (3 decimals)
|
||||
for e in completed_events:
|
||||
start_us = (e['start_cyc'] - global_min_cyc) / avg_freq_mhz
|
||||
dur_us = (e['end_cyc'] - e['start_cyc']) / avg_freq_mhz
|
||||
e['ts_ns'] = int(round(start_us * 1000))
|
||||
e['dur_ns'] = int(round(max(dur_us, 0.1) * 1000))
|
||||
|
||||
# Allocate slots (sub-tracks) to prevent overlaps on same virtual track
|
||||
active_slots = defaultdict(list)
|
||||
for e in completed_events:
|
||||
t = e['thread']
|
||||
evt = e['event']
|
||||
ts = e['ts_ns']
|
||||
dur = e['dur_ns']
|
||||
|
||||
norm_evt = normalize_event_name(evt)
|
||||
if norm_evt == "DMA":
|
||||
track_key = (t, "DMA")
|
||||
elif t == 10:
|
||||
track_key = (t, "HMX")
|
||||
else:
|
||||
track_key = (t, "HVX")
|
||||
|
||||
slots = active_slots[track_key]
|
||||
allocated_slot = -1
|
||||
for idx, slot_end_ns in enumerate(slots):
|
||||
if ts >= slot_end_ns:
|
||||
slots[idx] = ts + dur
|
||||
allocated_slot = idx
|
||||
break
|
||||
if allocated_slot == -1:
|
||||
slots.append(ts + dur)
|
||||
allocated_slot = len(slots) - 1
|
||||
e['slot'] = allocated_slot
|
||||
|
||||
# Generate Track IDs and track definitions
|
||||
used_tracks = {}
|
||||
for e in completed_events:
|
||||
t = e['thread']
|
||||
evt = e['event']
|
||||
slot = e['slot']
|
||||
|
||||
norm_evt = normalize_event_name(evt)
|
||||
if norm_evt == "DMA":
|
||||
track_evt = "DMA"
|
||||
evt_id = 1
|
||||
elif t == 10:
|
||||
track_evt = "HMX"
|
||||
evt_id = 3
|
||||
else:
|
||||
track_evt = "HVX"
|
||||
evt_id = 2
|
||||
|
||||
t_sort = 1 if t == 10 else t + 2
|
||||
# Unique UUID for each sub-track
|
||||
if t == 10:
|
||||
uuid = 20 # HMX thread track UUID
|
||||
else:
|
||||
uuid = int(t_sort * 1000000 + evt_id * 1000 + slot)
|
||||
e['uuid'] = uuid
|
||||
used_tracks[uuid] = (t, track_evt, slot)
|
||||
|
||||
with open(output_path, "wb") as f:
|
||||
# Define Process with EXPLICIT child sorting
|
||||
proc_desc = make_process_descriptor(1, "HTP NPU")
|
||||
proc_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(1, process=proc_desc, child_ordering=3))
|
||||
write_trace_packet_to_file(f, proc_packet)
|
||||
|
||||
# Define Operators Track (UUID = 2) as a thread track at rank 1, tid 8
|
||||
op_thread_desc = make_thread_descriptor(1, 8, "Ops", sort_index=1)
|
||||
op_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(2, parent_uuid=1, thread=op_thread_desc))
|
||||
write_trace_packet_to_file(f, op_packet)
|
||||
|
||||
# Define HMX Thread Track (UUID = 20) at rank 2, tid 9
|
||||
hmx_thread_desc = make_thread_descriptor(1, 9, "HMX", sort_index=2)
|
||||
hmx_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(20, parent_uuid=1, thread=hmx_thread_desc))
|
||||
write_trace_packet_to_file(f, hmx_packet)
|
||||
|
||||
# Define Thread Tracks (T0, T1, ..., T9)
|
||||
unique_threads = sorted(list(set(t for (t, _, _) in used_tracks.values() if t != 10)))
|
||||
for t in unique_threads:
|
||||
thread_uuid = 10 + t
|
||||
thread_name = f"T{t}"
|
||||
# Sort order starts from index 3 (T0 -> 3, T1 -> 4, etc.)
|
||||
sort_index = 3 + t
|
||||
tid = 10 + t
|
||||
thread_desc = make_thread_descriptor(1, tid, thread_name, sort_index=sort_index)
|
||||
thread_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(
|
||||
thread_uuid,
|
||||
parent_uuid=1,
|
||||
thread=thread_desc,
|
||||
sibling_order_rank=sort_index,
|
||||
child_ordering=3 # Explicit child sorting for sub-tracks
|
||||
))
|
||||
write_trace_packet_to_file(f, thread_packet)
|
||||
|
||||
# Define Track descriptors for sub-tracks parented to thread tracks
|
||||
for uuid in sorted(used_tracks.keys()):
|
||||
if uuid == 20:
|
||||
continue
|
||||
t, evt, slot = used_tracks[uuid]
|
||||
name = f"T{t} {evt}"
|
||||
rank = 0 if evt == "HVX" else 1
|
||||
parent_thread_uuid = 10 + t
|
||||
# Sibling merge behavior: 1 (SIBLING_MERGE_BEHAVIOR_BY_TRACK_NAME)
|
||||
track_desc = make_track_descriptor(
|
||||
uuid=uuid,
|
||||
name=name,
|
||||
parent_uuid=parent_thread_uuid,
|
||||
sibling_merge_behavior=1,
|
||||
sibling_order_rank=rank
|
||||
)
|
||||
track_packet = make_trace_packet(0, track_descriptor=track_desc)
|
||||
write_trace_packet_to_file(f, track_packet)
|
||||
|
||||
# Emit Operators
|
||||
last_op_end_ns = 0
|
||||
for op in filtered_ops:
|
||||
op_start_ns = int(round(((op['start_cycles'] - global_min_cyc) / avg_freq_mhz) * 1000))
|
||||
op_dur_ns = int(round((op['cycles'] / avg_freq_mhz) * 1000))
|
||||
if op_start_ns < last_op_end_ns:
|
||||
op_start_ns = last_op_end_ns
|
||||
clamped_dur = max(op_dur_ns, 100) # Clamp to 100ns (0.1us)
|
||||
|
||||
# Debug annotations for Ops
|
||||
debug_annots = []
|
||||
if 'line_num' in op:
|
||||
debug_annots.append(make_debug_annotation("line", int_val=op['line_num']))
|
||||
if 'strides' in op and op['strides']:
|
||||
debug_annots.append(make_debug_annotation("strides", string_val=op['strides']))
|
||||
|
||||
# Slice Begin
|
||||
evt_begin = make_track_event(1, 2, name=f"{op['name']} ({op['dims']})", category="operator", debug_annotations=debug_annots)
|
||||
packet_begin = make_trace_packet(op_start_ns, track_event=evt_begin)
|
||||
write_trace_packet_to_file(f, packet_begin)
|
||||
|
||||
# Slice End
|
||||
evt_end = make_track_event(2, 2)
|
||||
packet_end = make_trace_packet(op_start_ns + clamped_dur, track_event=evt_end)
|
||||
write_trace_packet_to_file(f, packet_end)
|
||||
|
||||
last_op_end_ns = op_start_ns + clamped_dur
|
||||
|
||||
# Emit Thread Trace Events
|
||||
for e in completed_events:
|
||||
norm_name = normalize_event_name(e['event'])
|
||||
name = f"DMA {e['info']}" if norm_name == "DMA" else norm_name
|
||||
|
||||
# Slice Begin
|
||||
evt_begin = make_track_event(1, e['uuid'], name=name, category="trace")
|
||||
packet_begin = make_trace_packet(e['ts_ns'], track_event=evt_begin)
|
||||
write_trace_packet_to_file(f, packet_begin)
|
||||
|
||||
# Slice End
|
||||
evt_end = make_track_event(2, e['uuid'])
|
||||
packet_end = make_trace_packet(e['ts_ns'] + e['dur_ns'], track_event=evt_end)
|
||||
write_trace_packet_to_file(f, packet_end)
|
||||
|
||||
logger.info(f"Successfully generated Perfetto trace at {output_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert Hexagon Op profile logs to native Perfetto Protobuf traces.")
|
||||
parser.add_argument("logfile", help="Path to hex-log profile file")
|
||||
parser.add_argument("-o", "--output", default="optrace.perfetto-trace", help="Output trace file path (default: optrace.perfetto-trace)")
|
||||
parser.add_argument("--filter", type=str, help="Regex filter matching against the original profile-op line")
|
||||
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--head", type=int, help="Limit to first N ops")
|
||||
group.add_argument("--tail", type=int, help="Limit to last N ops")
|
||||
|
||||
args = parser.parse_args()
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
|
||||
ops = parse_log(args.logfile)
|
||||
|
||||
if args.filter:
|
||||
try:
|
||||
filter_re = re.compile(args.filter)
|
||||
except re.error as e:
|
||||
logger.error(f"Invalid regex filter: {e}")
|
||||
sys.exit(1)
|
||||
ops = [op for op in ops if filter_re.search(op['op_text'])]
|
||||
|
||||
if args.head is not None:
|
||||
ops = ops[:args.head]
|
||||
elif args.tail is not None:
|
||||
ops = ops[-args.tail:]
|
||||
|
||||
generate_perfetto_trace(ops, args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+23
-4
@@ -20,6 +20,7 @@ set(LLAMA_UI_GZIP "" CACHE STRING "Apply gzip compress to assets to save ban
|
||||
|
||||
set(DIST_DIR "${UI_BINARY_DIR}/dist")
|
||||
set(SRC_DIST_DIR "${UI_SOURCE_DIR}/dist")
|
||||
set(WORK_DIR "${UI_BINARY_DIR}/ui-src")
|
||||
set(STAMP_FILE "${UI_BINARY_DIR}/.ui-stamp")
|
||||
set(UI_CPP "${UI_BINARY_DIR}/ui.cpp")
|
||||
set(UI_H "${UI_BINARY_DIR}/ui.h")
|
||||
@@ -64,6 +65,22 @@ function(npm_build_should_skip out_var)
|
||||
set(${out_var} TRUE PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
function(stage_sources)
|
||||
if(EXISTS "${WORK_DIR}")
|
||||
file(GLOB staged RELATIVE "${WORK_DIR}" "${WORK_DIR}/*")
|
||||
list(REMOVE_ITEM staged "node_modules")
|
||||
foreach(entry ${staged})
|
||||
file(REMOVE_RECURSE "${WORK_DIR}/${entry}")
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
file(COPY "${UI_SOURCE_DIR}/"
|
||||
DESTINATION "${WORK_DIR}"
|
||||
NO_SOURCE_PERMISSIONS
|
||||
PATTERN "node_modules" EXCLUDE
|
||||
)
|
||||
endfunction()
|
||||
|
||||
function(npm_build out_var)
|
||||
set(${out_var} FALSE PARENT_SCOPE)
|
||||
|
||||
@@ -89,14 +106,16 @@ function(npm_build out_var)
|
||||
return()
|
||||
endif()
|
||||
|
||||
stage_sources()
|
||||
|
||||
# npm writes node_modules/.package-lock.json on every successful install,
|
||||
# so a package-lock.json newer than this marker means node_modules is stale
|
||||
set(NPM_MARKER "${UI_SOURCE_DIR}/node_modules/.package-lock.json")
|
||||
set(NPM_MARKER "${WORK_DIR}/node_modules/.package-lock.json")
|
||||
set(need_install FALSE)
|
||||
if(NOT EXISTS "${NPM_MARKER}")
|
||||
set(need_install TRUE)
|
||||
else()
|
||||
file(TIMESTAMP "${UI_SOURCE_DIR}/package-lock.json" lock_ts)
|
||||
file(TIMESTAMP "${WORK_DIR}/package-lock.json" lock_ts)
|
||||
file(TIMESTAMP "${NPM_MARKER}" marker_ts)
|
||||
if(lock_ts STRGREATER marker_ts)
|
||||
set(need_install TRUE)
|
||||
@@ -107,7 +126,7 @@ function(npm_build out_var)
|
||||
message(STATUS "UI: running npm install")
|
||||
execute_process(
|
||||
COMMAND ${NPM_EXECUTABLE} install
|
||||
WORKING_DIRECTORY "${UI_SOURCE_DIR}"
|
||||
WORKING_DIRECTORY "${WORK_DIR}"
|
||||
RESULT_VARIABLE rc
|
||||
ERROR_VARIABLE err
|
||||
)
|
||||
@@ -124,7 +143,7 @@ function(npm_build out_var)
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env "LLAMA_UI_OUT_DIR=${DIST_DIR}" "LLAMA_UI_VERSION=${HF_VERSION}" "LLAMA_BUILD_NUMBER=${LLAMA_BUILD_NUMBER}"
|
||||
${NPM_EXECUTABLE} run build
|
||||
WORKING_DIRECTORY "${UI_SOURCE_DIR}"
|
||||
WORKING_DIRECTORY "${WORK_DIR}"
|
||||
RESULT_VARIABLE rc
|
||||
ERROR_VARIABLE err
|
||||
)
|
||||
|
||||
@@ -6,11 +6,10 @@ Apply LORA adapters to base model and export the resulting model.
|
||||
usage: llama-export-lora [options]
|
||||
|
||||
options:
|
||||
-m, --model model path from which to load base model (default '')
|
||||
--lora FNAME path to LoRA adapter (can be repeated to use multiple adapters)
|
||||
--lora-scaled FNAME S path to LoRA adapter with user defined scaling S (can be repeated to use multiple adapters)
|
||||
-t, --threads N number of threads to use during computation (default: 4)
|
||||
-o, --output FNAME output file (default: 'ggml-lora-merged-f16.gguf')
|
||||
-m, --model FNAME model path from which to load base model
|
||||
--lora FNAME path to LoRA adapter (use comma-separated values to load multiple adapters)
|
||||
--lora-scaled FNAME:SCALE,... path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)
|
||||
-o, --output, --output-file FNAME output file (default: 'ggml-lora-merged-f16.gguf')
|
||||
```
|
||||
|
||||
For example:
|
||||
@@ -22,12 +21,11 @@ For example:
|
||||
--lora lora-open-llama-3b-v2-english2tokipona-chat-LATEST.gguf
|
||||
```
|
||||
|
||||
Multiple LORA adapters can be applied by passing multiple `--lora FNAME` or `--lora-scaled FNAME S` command line parameters:
|
||||
Multiple LORA adapters can be applied by passing comma-separated values to `--lora FNAME` or `--lora-scaled FNAME:SCALE,...`:
|
||||
|
||||
```bash
|
||||
./bin/llama-export-lora \
|
||||
-m your_base_model.gguf \
|
||||
-o your_merged_model.gguf \
|
||||
--lora-scaled lora_task_A.gguf 0.5 \
|
||||
--lora-scaled lora_task_B.gguf 0.5
|
||||
--lora-scaled lora_task_A.gguf:0.5,lora_task_B.gguf:0.5
|
||||
```
|
||||
|
||||
+34
-36
@@ -26,6 +26,13 @@ void mtmd_image_preproc_out::append(const clip_hparams & hparams, clip_image_f32
|
||||
entries.push_back(std::move(img));
|
||||
}
|
||||
|
||||
void mtmd_image_preproc_out::append_overview(const clip_hparams & hparams, const clip_image_u8 & img, bool normalized) {
|
||||
overview.from_u8(img);
|
||||
if (normalized) {
|
||||
overview.normalize(hparams.image_mean, hparams.image_std);
|
||||
}
|
||||
}
|
||||
|
||||
// set of tools to manipulate images
|
||||
// in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv
|
||||
struct img_tool {
|
||||
@@ -607,10 +614,11 @@ private:
|
||||
mtmd_image_preproc_out mtmd_image_preprocessor_llava_uhd::preprocess(const clip_image_u8 & img) {
|
||||
const clip_image_size original_size = img.get_size();
|
||||
auto const inst = get_slice_instructions(original_size);
|
||||
std::vector<clip_image_u8> imgs = slice_image(img, inst);
|
||||
auto sliced = slice_image(img, inst);
|
||||
|
||||
mtmd_image_preproc_out output;
|
||||
output.append(hparams, imgs, true);
|
||||
output.append_overview(hparams, sliced.overview, true);
|
||||
output.append(hparams, sliced.slices, true);
|
||||
output.grid_x = inst.grid_size.width;
|
||||
output.grid_y = inst.grid_size.height;
|
||||
|
||||
@@ -722,22 +730,15 @@ mtmd_image_preprocessor_llava_uhd::slice_instructions mtmd_image_preprocessor_ll
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<clip_image_u8> mtmd_image_preprocessor_llava_uhd::slice_image(const clip_image_u8 & img, const mtmd_image_preprocessor_llava_uhd::slice_instructions & inst, bool overview_first) {
|
||||
std::vector<clip_image_u8> output;
|
||||
mtmd_image_preprocessor_llava_uhd::slice_output mtmd_image_preprocessor_llava_uhd::slice_image(const clip_image_u8 & img, const mtmd_image_preprocessor_llava_uhd::slice_instructions & inst) {
|
||||
slice_output output;
|
||||
|
||||
// resize to overview size
|
||||
clip_image_u8 resized_img;
|
||||
img_tool::resize(img, resized_img, inst.overview_size, hparams.image_resize_algo_ov,
|
||||
img_tool::resize(img, output.overview, inst.overview_size, hparams.image_resize_algo_ov,
|
||||
hparams.image_pad_ov, hparams.image_pad_color_ov);
|
||||
if (overview_first) {
|
||||
output.push_back(resized_img);
|
||||
}
|
||||
|
||||
if (inst.slices.empty()) {
|
||||
// no slices, just return the resized image
|
||||
if (!overview_first) {
|
||||
output.push_back(resized_img);
|
||||
}
|
||||
// no slices, just return the overview image
|
||||
return output;
|
||||
}
|
||||
|
||||
@@ -755,11 +756,7 @@ std::vector<clip_image_u8> mtmd_image_preprocessor_llava_uhd::slice_image(const
|
||||
|
||||
clip_image_u8 img_slice;
|
||||
img_tool::crop(refined_img, img_slice, x, y, w, h);
|
||||
output.push_back(std::move(img_slice));
|
||||
}
|
||||
|
||||
if (!overview_first) {
|
||||
output.push_back(resized_img);
|
||||
output.slices.push_back(std::move(img_slice));
|
||||
}
|
||||
|
||||
return output;
|
||||
@@ -1077,10 +1074,11 @@ mtmd_image_preproc_out mtmd_image_preprocessor_idefics3::preprocess(const clip_i
|
||||
});
|
||||
}
|
||||
}
|
||||
auto imgs = slice_image(img, instructions);
|
||||
auto sliced = slice_image(img, instructions);
|
||||
|
||||
mtmd_image_preproc_out output;
|
||||
output.append(hparams, imgs, true);
|
||||
output.append_overview(hparams, sliced.overview, true);
|
||||
output.append(hparams, sliced.slices, true);
|
||||
output.grid_x = instructions.grid_size.width;
|
||||
output.grid_y = instructions.grid_size.height;
|
||||
return output;
|
||||
@@ -1094,10 +1092,12 @@ mtmd_image_preproc_out mtmd_image_preprocessor_internvl::preprocess(const clip_i
|
||||
GGML_ASSERT(!hparams.image_res_candidates.empty());
|
||||
const clip_image_size original_size = img.get_size();
|
||||
auto const inst = get_slice_instructions(original_size);
|
||||
std::vector<clip_image_u8> imgs = slice_image(img, inst, false);
|
||||
auto sliced = slice_image(img, inst);
|
||||
|
||||
mtmd_image_preproc_out output;
|
||||
output.append(hparams, imgs, true);
|
||||
// InternVL: slices first, then overview
|
||||
output.append(hparams, sliced.slices, true);
|
||||
output.append_overview(hparams, sliced.overview, true);
|
||||
output.grid_x = inst.grid_size.width;
|
||||
output.grid_y = inst.grid_size.height;
|
||||
return output;
|
||||
@@ -1131,9 +1131,10 @@ mtmd_image_preproc_out mtmd_image_preprocessor_deepseekocr::preprocess(const cli
|
||||
img_tool::resize(img, padded, {image_size, image_size}, RESIZE_ALGO_BICUBIC_PILLOW,
|
||||
PAD_NEAREST, hparams.image_pad_color);
|
||||
mtmd_image_preproc_out output;
|
||||
output.append(hparams, padded, true);
|
||||
output.grid_x = 1;
|
||||
output.grid_y = 1;
|
||||
output.append_overview(hparams, padded, true);
|
||||
output.grid_x = 0;
|
||||
output.grid_y = 0;
|
||||
// TODO @ngxson : support slicing for DeepSeek-OCR, to do in another PR
|
||||
return output;
|
||||
}
|
||||
|
||||
@@ -1226,10 +1227,8 @@ mtmd_image_preproc_out mtmd_image_preprocessor_deepseekocr2::preprocess(const cl
|
||||
clip_image_u8 padded;
|
||||
img_tool::resize(img, padded, { base_size, base_size }, RESIZE_ALGO_BICUBIC_PILLOW,
|
||||
PAD_NEAREST, hparams.image_pad_color);
|
||||
output.append(hparams, padded, true);
|
||||
output.entries.back().add_viewsep = true;
|
||||
output.grid_x = 1;
|
||||
output.grid_y = 1;
|
||||
output.append_overview(hparams, padded, true);
|
||||
output.overview.add_viewsep = true;
|
||||
return output;
|
||||
}
|
||||
|
||||
@@ -1447,15 +1446,14 @@ mtmd_image_preproc_out mtmd_image_preprocessor_step3vl::preprocess(const clip_im
|
||||
const auto instructions = build_slice_instructions(hparams, prepared.get_size());
|
||||
|
||||
mtmd_image_preproc_out output;
|
||||
clip_image_f32 overview_f32;
|
||||
// overview (normalized f32, already includes mean/std)
|
||||
img_u8_resize_bilinear_to_f32(
|
||||
prepared,
|
||||
overview_f32,
|
||||
output.overview,
|
||||
hparams.image_size,
|
||||
hparams.image_size,
|
||||
hparams.image_mean,
|
||||
hparams.image_std);
|
||||
output.append(hparams, overview_f32, false);
|
||||
|
||||
if (instructions.slices.empty()) {
|
||||
output.grid_x = 0;
|
||||
@@ -1548,13 +1546,13 @@ mtmd_image_preproc_out mtmd_image_preprocessor_youtuvl::preprocess(const clip_im
|
||||
|
||||
mtmd_image_preproc_out mtmd_image_preprocessor_granite::preprocess(const clip_image_u8 & img) {
|
||||
auto output = mtmd_image_preprocessor_llava_uhd::preprocess(img);
|
||||
if (output.entries.size() == 1) {
|
||||
if (output.entries.size() == 0) {
|
||||
// Single-tile (overview only): append one newline row.
|
||||
output.entries[0].add_newline = true;
|
||||
output.overview.add_newline = true;
|
||||
} else {
|
||||
// Multi-tile: overview gets no newline, grid tiles get one.
|
||||
output.entries[0].add_newline = false;
|
||||
for (size_t i = 1; i < output.entries.size(); ++i) {
|
||||
output.overview.add_newline = false;
|
||||
for (size_t i = 0; i < output.entries.size(); ++i) {
|
||||
output.entries[i].add_newline = true;
|
||||
}
|
||||
}
|
||||
|
||||
+15
-1
@@ -11,11 +11,19 @@
|
||||
struct mtmd_image_preproc_out {
|
||||
std::vector<clip_image_f32> entries;
|
||||
// grid size is required for llava-uhd style models
|
||||
|
||||
clip_image_f32 overview; // overview image (downscaled image)
|
||||
int grid_x = 0;
|
||||
int grid_y = 0;
|
||||
|
||||
void append(const clip_hparams & hparams, const clip_image_u8 & img, bool normalized = true);
|
||||
void append(const clip_hparams & hparams, const std::vector<clip_image_u8> & imgs, bool normalized = true);
|
||||
void append(const clip_hparams & hparams, clip_image_f32 & img, bool normalized = true);
|
||||
|
||||
void append_overview(const clip_hparams & hparams, const clip_image_u8 & img, bool normalized = true);
|
||||
bool has_overview() const {
|
||||
return overview.nx() > 0 || overview.ny() > 0;
|
||||
}
|
||||
};
|
||||
|
||||
// base class, models must inherit from this class
|
||||
@@ -46,6 +54,8 @@ struct mtmd_image_preprocessor {
|
||||
* [overview] --> [slice 1] --> [slice 2]
|
||||
* | |
|
||||
* +--> [slice 3] --> [slice 4]
|
||||
*
|
||||
* NOTE: for the ordering of overview, set "ov_img_first" on the mtmd_context
|
||||
*/
|
||||
struct mtmd_image_preprocessor_llava_uhd : mtmd_image_preprocessor {
|
||||
mtmd_image_preprocessor_llava_uhd(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
|
||||
@@ -67,7 +77,11 @@ struct mtmd_image_preprocessor_llava_uhd : mtmd_image_preprocessor {
|
||||
// LFM2 override this function to implement its custom slicing logic
|
||||
virtual slice_instructions get_slice_instructions(const clip_image_size & original_size);
|
||||
|
||||
std::vector<clip_image_u8> slice_image(const clip_image_u8 & img, const slice_instructions & inst, bool overview_first = true);
|
||||
struct slice_output {
|
||||
clip_image_u8 overview;
|
||||
std::vector<clip_image_u8> slices;
|
||||
};
|
||||
slice_output slice_image(const clip_image_u8 & img, const slice_instructions & inst);
|
||||
|
||||
private:
|
||||
clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false);
|
||||
|
||||
+64
-51
@@ -516,6 +516,7 @@ struct mtmd_context {
|
||||
LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
|
||||
" https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_llava_uhd>(ctx_v);
|
||||
ov_img_first = false;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_STEP3VL:
|
||||
{
|
||||
@@ -539,6 +540,7 @@ struct mtmd_context {
|
||||
img_beg = "<img>";
|
||||
img_end = "</img>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_internvl>(ctx_v);
|
||||
ov_img_first = false;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
{
|
||||
@@ -615,11 +617,13 @@ struct mtmd_context {
|
||||
{
|
||||
img_end = "\n"; // prevent empty batch on llama-server
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_deepseekocr>(ctx_v);
|
||||
ov_img_first = false;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_DEEPSEEKOCR2:
|
||||
{
|
||||
img_end = "\n"; // prevent empty batch on llama-server
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_deepseekocr2>(ctx_v);
|
||||
ov_img_first = false;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_HUNYUANVL:
|
||||
{
|
||||
@@ -640,6 +644,7 @@ struct mtmd_context {
|
||||
img_beg = "<image>";
|
||||
img_end = "";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_granite>(ctx_v);
|
||||
ov_img_first = true;
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error(string_format("%s: unexpected vision projector type %d\n", __func__, proj));
|
||||
@@ -1079,26 +1084,38 @@ struct mtmd_tokenizer {
|
||||
|
||||
// for llava-uhd style, we need to handle grid too
|
||||
// we don't care about overwriting these values for now because the case where bitmaps.size() > 1 is only for frame merging (qwen-vl), not supported by llava-uhd
|
||||
if (tmp_preproc_out.grid_x > 0 && tmp_preproc_out.grid_y > 0) {
|
||||
if ((tmp_preproc_out.grid_x > 0 && tmp_preproc_out.grid_y > 0)
|
||||
|| tmp_preproc_out.has_overview()) {
|
||||
GGML_ASSERT(bitmaps.size() == 1);
|
||||
preproc_out.grid_x = tmp_preproc_out.grid_x;
|
||||
preproc_out.grid_y = tmp_preproc_out.grid_y;
|
||||
preproc_out.overview = std::move(tmp_preproc_out.overview);
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("%s: preproc_out has %zu entries, grid_x = %d, grid_y = %d, has_overview = %d\n",
|
||||
__func__, preproc_out.entries.size(), preproc_out.grid_x, preproc_out.grid_y,
|
||||
preproc_out.has_overview() ? 1 : 0);
|
||||
|
||||
// handle llava-uhd style preprocessing
|
||||
const bool has_tiling_grid = preproc_out.grid_x > 0 && preproc_out.grid_y > 0;
|
||||
// (output either a grid, or overview-only)
|
||||
const bool has_tiling_grid = (preproc_out.grid_x > 0 && preproc_out.grid_y > 0)
|
||||
|| preproc_out.has_overview();
|
||||
|
||||
if (has_tiling_grid) {
|
||||
// [QWEN_VIDEO] we do not support "frame merging" for llama-uhd style, so no batching for now
|
||||
GGML_ASSERT(bitmaps.size() == 1);
|
||||
|
||||
const int n_col = preproc_out.grid_x;
|
||||
const int n_row = preproc_out.grid_y;
|
||||
|
||||
// split batch into chunks of single images
|
||||
// NOTE: preproc_out will be invalidated after this call
|
||||
auto chunks = split_batch_to_chunk(std::move(preproc_out), bitmaps[0]->id);
|
||||
GGML_ASSERT(chunks.size() > 0);
|
||||
|
||||
// NOTE: preproc_out is invalidated after this point, do not use it anymore
|
||||
|
||||
// split_batch_to_chunk must always put the overview image first
|
||||
auto ov_chunk = std::move(chunks.front());
|
||||
chunks.erase(chunks.begin());
|
||||
|
||||
@@ -1125,7 +1142,16 @@ struct mtmd_tokenizer {
|
||||
std::snprintf(buf.get(), sz, ctx->sli_img_start_tmpl.c_str(), y+1, x+1);
|
||||
add_text(std::string(buf.get(), buf.get() + sz - 1), true);
|
||||
}
|
||||
cur.entries.emplace_back(std::move(chunks[y * n_col + x]));
|
||||
|
||||
auto & curr_chunk = chunks[y * n_col + x];
|
||||
auto & curr_batch = curr_chunk.tokens_image->batch_f32;
|
||||
if (curr_batch.entries.size() != 1) {
|
||||
throw std::runtime_error(string_format("%s: expect 1 image in batch_f32", __func__));
|
||||
}
|
||||
|
||||
LOG_DBG("%s: adding slice image at row %d col %d\n", __func__, y, x);
|
||||
cur.entries.emplace_back(std::move(curr_chunk));
|
||||
|
||||
add_text(ctx->tok_sli_img_end);
|
||||
if (!is_last_in_row) {
|
||||
add_text(ctx->tok_sli_img_mid);
|
||||
@@ -1147,6 +1173,11 @@ struct mtmd_tokenizer {
|
||||
|
||||
} else {
|
||||
|
||||
if (preproc_out.entries.size() == 0) {
|
||||
LOG_ERR("%s: no image tokens produced by preprocessor (ref: https://github.com/ggml-org/llama.cpp/pull/24769)\n", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
size_t n_tokens = 0;
|
||||
for (auto & e : preproc_out.entries) {
|
||||
n_tokens += clip_n_output_tokens(ctx->ctx_v, &e);
|
||||
@@ -1303,13 +1334,15 @@ struct mtmd_tokenizer {
|
||||
std::vector<mtmd_input_chunk> split_batch_to_chunk(mtmd_image_preproc_out && preproc_out, const std::string & id) {
|
||||
std::vector<mtmd_input_chunk> chunks;
|
||||
|
||||
for (auto & entry : preproc_out.entries) {
|
||||
auto process_chunk = [&](clip_image_f32 && img) {
|
||||
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
|
||||
image_tokens->nx = clip_n_output_tokens(ctx->ctx_v, &entry);
|
||||
image_tokens->nx = clip_n_output_tokens(ctx->ctx_v, &img);
|
||||
image_tokens->ny = 1;
|
||||
image_tokens->batch_f32.entries.push_back(std::move(entry));
|
||||
image_tokens->batch_f32.entries.push_back(std::move(img));
|
||||
image_tokens->id = id;
|
||||
|
||||
GGML_ASSERT(image_tokens->nx > 0);
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{}, // text tokens
|
||||
@@ -1317,6 +1350,21 @@ struct mtmd_tokenizer {
|
||||
nullptr, // audio tokens
|
||||
};
|
||||
chunks.emplace_back(std::move(chunk));
|
||||
};
|
||||
|
||||
// overview image first
|
||||
auto & overview = preproc_out.overview;
|
||||
if (overview.nx() == 0 || overview.ny() == 0) {
|
||||
throw std::runtime_error(string_format("%s: invalid overview image for llava-uhd style preprocessing\n", __func__));
|
||||
}
|
||||
process_chunk(std::move(preproc_out.overview));
|
||||
|
||||
// then, process slices
|
||||
for (auto & entry : preproc_out.entries) {
|
||||
if (entry.nx() == 0 || entry.ny() == 0) {
|
||||
throw std::runtime_error(string_format("%s: invalid image slice for llava-uhd style preprocessing\n", __func__));
|
||||
}
|
||||
process_chunk(std::move(entry));
|
||||
}
|
||||
|
||||
return chunks;
|
||||
@@ -1390,57 +1438,22 @@ static int32_t mtmd_encode_impl(mtmd_context * ctx, const mtmd_image_tokens * im
|
||||
LOG_ERR("%s: this API does not support non-vision input, please use mtmd_encode_chunk instead\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
auto proj_type = clip_get_projector_type(ctx_clip);
|
||||
|
||||
int n_embd_out = ctx->n_embd_out();
|
||||
auto n_tokens_out = image_tokens->n_tokens();
|
||||
out_embd.resize((size_t)n_embd_out * n_tokens_out);
|
||||
|
||||
bool ok = false;
|
||||
|
||||
if (clip_is_llava(ctx_clip)
|
||||
|| proj_type == PROJECTOR_TYPE_MINICPMV
|
||||
|| proj_type == PROJECTOR_TYPE_GLM_EDGE
|
||||
|| proj_type == PROJECTOR_TYPE_INTERNVL
|
||||
|| proj_type == PROJECTOR_TYPE_DEEPSEEKOCR2
|
||||
|| proj_type == PROJECTOR_TYPE_GRANITE4_VISION) {
|
||||
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
|
||||
const auto & entries = image_tokens->batch_f32.entries;
|
||||
// entries may have different token counts
|
||||
// e.g., DeepSeek-OCR-2: 144 per tile views, 257 for the global view
|
||||
size_t offset = 0;
|
||||
for (size_t i = 0; i < entries.size(); i++) {
|
||||
if (entries[i].is_placeholder()) {
|
||||
LOG_ERR("%s: image tokens batch entry %zu is placeholder\n", __func__, i);
|
||||
return 1;
|
||||
}
|
||||
int n_tokens_per_image = clip_n_output_tokens(ctx_clip, &entries[i]);
|
||||
std::vector<float> tmp_embd((size_t)n_tokens_per_image * n_embd_out);
|
||||
bool ok_i = clip_image_encode(
|
||||
ctx_clip,
|
||||
ctx->n_threads,
|
||||
&entries[i],
|
||||
tmp_embd);
|
||||
if (!ok_i) {
|
||||
LOG_ERR("%s: failed to encode image %zu\n", __func__, i);
|
||||
return 1;
|
||||
}
|
||||
ok = true;
|
||||
std::copy(tmp_embd.begin(), tmp_embd.end(), out_embd.begin() + offset);
|
||||
offset += static_cast<size_t>(n_embd_out) * n_tokens_per_image;
|
||||
}
|
||||
} else {
|
||||
if (image_tokens->is_placeholder()) {
|
||||
LOG_ERR("%s: image tokens batch is placeholder\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
ok = clip_image_batch_encode(
|
||||
ctx_clip,
|
||||
ctx->n_threads,
|
||||
&image_tokens->batch_f32,
|
||||
out_embd);
|
||||
if (image_tokens->is_placeholder()) {
|
||||
LOG_ERR("%s: image tokens batch is placeholder\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool ok = clip_image_batch_encode(
|
||||
ctx_clip,
|
||||
ctx->n_threads,
|
||||
&image_tokens->batch_f32,
|
||||
out_embd);
|
||||
|
||||
return ok ? 0 : 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ add_library(${TARGET} STATIC
|
||||
server-context.h
|
||||
server-tools.cpp
|
||||
server-tools.h
|
||||
server-schema.cpp
|
||||
server-schema.h
|
||||
)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "server-http.h"
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
#include "server-schema.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "common.h"
|
||||
@@ -189,9 +190,10 @@ struct server_slot {
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
|
||||
int64_t t_print_last = 0;
|
||||
int64_t t_start_process_prompt;
|
||||
int64_t t_start_generation;
|
||||
int64_t t_print_last = 0;
|
||||
int32_t n_decoded_last = 0;
|
||||
|
||||
double t_prompt_processing = 0.0; // ms
|
||||
double t_token_generation = 0.0; // ms
|
||||
@@ -470,11 +472,13 @@ struct server_slot {
|
||||
return;
|
||||
}
|
||||
|
||||
const double n_gen_second = 1e3 / (t_token_generation) * (n_decoded);
|
||||
const double n_gen_second_win = 1e6 / (t_now - t_print_last) * (n_decoded - n_decoded_last);
|
||||
|
||||
t_print_last = t_now;
|
||||
n_decoded_last = n_decoded;
|
||||
|
||||
const double n_gen_second = 1e3 / t_token_generation * n_decoded;
|
||||
|
||||
SLT_INF(*this, "n_decoded = %6d, tg = %6.2f t/s\n", n_decoded, n_gen_second);
|
||||
SLT_INF(*this, "n_decoded = %6d, tg = %6.2f t/s, tg_3s = %6.2f t/s\n", n_decoded, n_gen_second, n_gen_second_win);
|
||||
}
|
||||
|
||||
void print_timings_pp() const {
|
||||
@@ -3038,8 +3042,8 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_current = ggml_time_us();
|
||||
slot.t_prompt_processing = (t_current - slot.t_start_process_prompt) / 1e3;
|
||||
const int64_t t_now = ggml_time_us();
|
||||
slot.t_prompt_processing = (t_now - slot.t_start_process_prompt) / 1e3;
|
||||
slot.print_timings_pp();
|
||||
|
||||
// truncate any tokens that are beyond n_past for this slot
|
||||
@@ -3447,17 +3451,19 @@ private:
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
|
||||
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
|
||||
const int64_t t_current = ggml_time_us();
|
||||
const int64_t t_now = ggml_time_us();
|
||||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
if (slot.n_decoded == 1) {
|
||||
slot.t_start_generation = t_current;
|
||||
slot.t_start_generation = t_now;
|
||||
slot.t_print_last = t_now;
|
||||
slot.n_decoded_last = 0;
|
||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||
metrics.on_prompt_eval(slot);
|
||||
}
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
|
||||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
@@ -3551,11 +3557,11 @@ private:
|
||||
slot.spec_draft = std::move(accepted);
|
||||
}
|
||||
|
||||
const int64_t t_current = ggml_time_us();
|
||||
const int64_t t_now = ggml_time_us();
|
||||
|
||||
const auto ids = std::move(slot.spec_draft);
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
|
||||
|
||||
// update how many tokens out of those tested were accepted
|
||||
slot.n_draft_accepted += ids.size() - 1;
|
||||
@@ -3820,7 +3826,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
task.id = rd.get_new_id();
|
||||
|
||||
task.tokens = std::move(inputs[i]);
|
||||
task.params = server_task::params_from_json_cmpl(
|
||||
task.params = server_schema::eval_llama_cmpl_schema(
|
||||
ctx_server.vocab,
|
||||
params,
|
||||
meta->slot_n_ctx,
|
||||
|
||||
@@ -54,7 +54,7 @@ extern char **environ;
|
||||
|
||||
struct server_subproc {
|
||||
std::optional<subprocess_s> sproc; // empty while in DOWNLOADING state
|
||||
std::atomic<bool> stop_download{false}; // flag to signal download cancellation
|
||||
std::atomic<bool> stopped{false}; // set to cancel a download or signal child process exit
|
||||
|
||||
subprocess_s & get() {
|
||||
GGML_ASSERT(sproc.has_value() && "subprocess not initialized");
|
||||
@@ -64,6 +64,22 @@ struct server_subproc {
|
||||
bool is_alive() {
|
||||
return sproc.has_value() && subprocess_alive(&sproc.value());
|
||||
}
|
||||
|
||||
void terminate() {
|
||||
if (!sproc.has_value()) {
|
||||
return;
|
||||
}
|
||||
#if defined(_WIN32)
|
||||
if (sproc->hProcess == NULL) {
|
||||
return;
|
||||
}
|
||||
#else
|
||||
if (sproc->child <= 0) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
subprocess_terminate(&sproc.value());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -902,50 +918,49 @@ void server_models::load(const std::string & name) {
|
||||
});
|
||||
|
||||
std::thread stopping_thread([&]() {
|
||||
// thread to monitor stopping signal OR child crash
|
||||
// thread to monitor explicit stop requests; child crash is signalled via child_proc->stopped
|
||||
auto is_stopping = [this, &name]() {
|
||||
return this->stopping_models.find(name) != this->stopping_models.end();
|
||||
};
|
||||
auto should_wake = [&]() {
|
||||
return is_stopping() || !child_proc->is_alive();
|
||||
};
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(this->mutex);
|
||||
this->cv_stop.wait(lk, should_wake);
|
||||
this->cv_stop.wait(lk, [&]() {
|
||||
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
|
||||
});
|
||||
}
|
||||
// child may have already exited (e.g. crashed) — skip shutdown sequence
|
||||
if (!child_proc->is_alive()) {
|
||||
// child crashed or finished on its own — skip graceful shutdown sequence
|
||||
if (child_proc->stopped.load(std::memory_order_acquire)) {
|
||||
return;
|
||||
}
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
// send interrupt to child process
|
||||
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
|
||||
fflush(stdin_file);
|
||||
// wait to stop gracefully or timeout
|
||||
int64_t start_time = ggml_time_ms();
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lk(this->mutex);
|
||||
if (!is_stopping()) {
|
||||
return; // already stopped
|
||||
if (!is_stopping() || child_proc->stopped.load(std::memory_order_acquire)) {
|
||||
return;
|
||||
}
|
||||
int64_t elapsed = ggml_time_ms() - start_time;
|
||||
if (elapsed >= stop_timeout * 1000) {
|
||||
// timeout, force kill
|
||||
lk.unlock();
|
||||
SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
|
||||
subprocess_terminate(&child_proc->get());
|
||||
child_proc->terminate();
|
||||
return;
|
||||
}
|
||||
this->cv_stop.wait_for(lk, std::chrono::seconds(1));
|
||||
this->cv_stop.wait_for(lk, std::chrono::seconds(1), [&]() {
|
||||
return !is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// we reach here when the child process exits
|
||||
// we reach here when the child process exits (stdout EOF)
|
||||
// note: we cannot join() prior to this point because it will close stdin_file
|
||||
if (log_thread.joinable()) {
|
||||
log_thread.join();
|
||||
}
|
||||
|
||||
// stop the timeout monitoring thread
|
||||
child_proc->stopped.store(true, std::memory_order_release);
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(this->mutex);
|
||||
stopping_models.erase(name);
|
||||
@@ -971,7 +986,7 @@ void server_models::load(const std::string & name) {
|
||||
// old process should have exited already, but just in case, we clean it up here
|
||||
if (old_instance.subproc->is_alive()) {
|
||||
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
|
||||
subprocess_terminate(&old_instance.subproc->get()); // force kill
|
||||
old_instance.subproc->terminate(); // force kill
|
||||
}
|
||||
if (old_instance.th.joinable()) {
|
||||
old_instance.th.join();
|
||||
@@ -1039,7 +1054,7 @@ void server_models::download(common_params_model && model, common_download_opts
|
||||
dl->opts = opts; // copy
|
||||
|
||||
dl->should_stop = [sp = inst.subproc]() {
|
||||
return sp->stop_download.load(std::memory_order_relaxed);
|
||||
return sp->stopped.load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
dl->on_progress = [this, name](const common_download_progress & p) {
|
||||
@@ -1069,7 +1084,7 @@ void server_models::unload(const std::string & name) {
|
||||
if (it != mapping.end()) {
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->stop_download.store(true, std::memory_order_relaxed);
|
||||
it->second.subproc->stopped.store(true, std::memory_order_relaxed);
|
||||
// for convenience, we wait the status change here
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
@@ -1080,7 +1095,7 @@ void server_models::unload(const std::string & name) {
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
|
||||
// special case: if model is in loading state, unloading means force-killing it
|
||||
SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str());
|
||||
subprocess_terminate(&it->second.subproc->get());
|
||||
it->second.subproc->terminate();
|
||||
}
|
||||
cv_stop.notify_all();
|
||||
// status change will be handled by the managing thread
|
||||
@@ -1097,7 +1112,7 @@ void server_models::unload_all() {
|
||||
for (auto & [name, inst] : mapping) {
|
||||
if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
inst.subproc->stop_download.store(true, std::memory_order_relaxed);
|
||||
inst.subproc->stopped.store(true, std::memory_order_relaxed);
|
||||
} else if (inst.meta.is_running()) {
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
stopping_models.insert(name);
|
||||
|
||||
@@ -0,0 +1,635 @@
|
||||
#include "server-schema.h"
|
||||
|
||||
#include "json-schema-to-grammar.h"
|
||||
|
||||
namespace server_schema {
|
||||
|
||||
//
|
||||
// llama.cpp-specific completion schema
|
||||
//
|
||||
|
||||
std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(const common_params & params_base, task_params & params) {
|
||||
std::vector<std::unique_ptr<field>> fields;
|
||||
auto add = [&](field * f) {
|
||||
fields.emplace_back(f);
|
||||
};
|
||||
|
||||
add((new field_bool("timings_per_token", params.timings_per_token))
|
||||
->set_desc("Include prompt processing and text generation speed information in each response"));
|
||||
|
||||
add((new field_bool("stream", params.stream))
|
||||
->set_desc("Allows receiving each predicted token in real-time instead of waiting for the completion to finish"));
|
||||
|
||||
add((new field_nested("stream_options"))
|
||||
->add_subfield((new field_bool("include_usage", params.include_usage))
|
||||
->set_desc("Whether to include usage information in the stream"))
|
||||
->set_desc("Additional options for streaming responses"));
|
||||
|
||||
add((new field_bool("cache_prompt", params.cache_prompt))
|
||||
->set_desc("Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests"));
|
||||
|
||||
add((new field_bool("return_tokens", params.return_tokens))
|
||||
->set_desc("Return the raw generated token ids in the `tokens` field"));
|
||||
|
||||
add((new field_bool("return_progress", params.return_progress))
|
||||
->set_desc("Include prompt processing progress events in stream mode"));
|
||||
|
||||
add((new field_num("n_predict", params.n_predict))
|
||||
->set_hard_limits(-1, INT32_MAX)
|
||||
->add_alias("max_completion_tokens")
|
||||
->add_alias("max_tokens")
|
||||
->set_desc("Set the maximum number of tokens to predict. When 0, no tokens will be generated but the prompt is evaluated into the cache"));
|
||||
|
||||
add((new field_num("n_indent", params.n_indent))
|
||||
->set_hard_limits(0, INT32_MAX)
|
||||
->set_desc("Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks"));
|
||||
|
||||
add((new field_num("n_keep", params.n_keep))
|
||||
->set_hard_limits(-1, INT32_MAX)
|
||||
->set_desc("Specify the number of tokens from the initial prompt to retain when context size is exceeded. Use -1 to retain all tokens from the prompt"));
|
||||
|
||||
add((new field_num("n_discard", params.n_discard))
|
||||
->set_hard_limits(0, INT32_MAX)
|
||||
->set_desc("Number of tokens after n_keep that may be discarded when shifting context (0 = half context)"));
|
||||
|
||||
add((new field_num("n_cmpl", params.n_cmpl))
|
||||
->set_hard_limits(1, params_base.n_parallel)
|
||||
->add_alias("n") // alias "n" as fallback (OpenAI completions API)
|
||||
->set_desc("Number of completions to generate. If the input has multiple prompts, total outputs will be N prompts times n_cmpl"));
|
||||
|
||||
add((new field_num("n_cache_reuse", params.n_cache_reuse))
|
||||
->set_hard_limits(0, INT32_MAX)
|
||||
->set_desc("Min chunk size to attempt reusing from the cache via KV shifting. See --cache-reuse arg"));
|
||||
|
||||
// TODO: implement t_max_prompt_ms
|
||||
// add((new field_num("t_max_prompt_ms", params.t_max_prompt_ms))
|
||||
|
||||
add((new field_num("t_max_predict_ms", params.t_max_predict_ms))
|
||||
->set_hard_limits(-1, std::numeric_limits<int64_t>::max())
|
||||
->set_desc("Set a time limit in milliseconds for the prediction phase. The timeout triggers if generation exceeds this time (measured since the first token) and a newline has been generated. Useful for FIM applications"));
|
||||
|
||||
add((new field_json("response_fields"))
|
||||
->set_desc("A list of response fields to return. Missing fields are omitted without error. Fields with a slash are unnested (e.g. generation_settings/n_predict moves n_predict to the root)")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
ctx.params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||
}));
|
||||
|
||||
|
||||
//
|
||||
// Sampling params
|
||||
//
|
||||
|
||||
add((new field_num("top_k", params.sampling.top_k))
|
||||
->set_limits(0, INT32_MAX)
|
||||
->set_desc("Limit the next token selection to the K most probable tokens (0 = disabled)"));
|
||||
|
||||
add((new field_num("top_p", params.sampling.top_p))
|
||||
->set_limits(0.0f, 1.0f)
|
||||
->set_desc("Limit the next token selection to a subset of tokens with cumulative probability above threshold P (1.0 = disabled)"));
|
||||
|
||||
add((new field_num("min_p", params.sampling.min_p))
|
||||
->set_limits(0.0f, 1.0f)
|
||||
->set_desc("The minimum probability for a token to be considered, relative to the probability of the most likely token (0 = disabled)"));
|
||||
|
||||
add((new field_num("top_n_sigma", params.sampling.top_n_sigma))
|
||||
->set_desc("Keep tokens within n standard deviations of the top token logit (< 0 = disabled)"));
|
||||
|
||||
add((new field_num("xtc_probability", params.sampling.xtc_probability))
|
||||
->set_limits(0.0f, 1.0f)
|
||||
->set_desc("Set the chance for token removal via XTC sampler (0 = disabled)"));
|
||||
|
||||
add((new field_num("xtc_threshold", params.sampling.xtc_threshold))
|
||||
->set_limits(0.0f, 1.0f)
|
||||
->set_desc("Set a minimum probability threshold for tokens to be removed via XTC sampler (> 0.5 disables XTC)"));
|
||||
|
||||
add((new field_num("typical_p", params.sampling.typ_p))
|
||||
// ->set_limits(0.0f, 1.0f) // what's the valid range?
|
||||
->set_desc("Enable locally typical sampling with parameter p (1.0 = disabled)"));
|
||||
|
||||
add((new field_num("temperature", params.sampling.temp))
|
||||
->set_limits(0.0f, std::numeric_limits<float>::infinity())
|
||||
->set_desc("Adjust the randomness of the generated text (0 = greedy)"));
|
||||
|
||||
add((new field_num("dynatemp_range", params.sampling.dynatemp_range))
|
||||
->set_desc("Dynamic temperature range. The final temperature will be in [temperature - range, temperature + range] (0 = disabled)"));
|
||||
|
||||
add((new field_num("dynatemp_exponent", params.sampling.dynatemp_exponent))
|
||||
->set_desc("Dynamic temperature exponent, controls how entropy maps to temperature"));
|
||||
|
||||
add((new field_num("repeat_last_n", params.sampling.penalty_last_n))
|
||||
->set_hard_limits(-1, INT32_MAX)
|
||||
->set_desc("Last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size)"));
|
||||
|
||||
add((new field_num("repeat_penalty", params.sampling.penalty_repeat))
|
||||
->set_desc("Control the repetition of token sequences in the generated text (1.0 = disabled)"));
|
||||
|
||||
add((new field_num("frequency_penalty", params.sampling.penalty_freq))
|
||||
->set_desc("Repeat alpha frequency penalty (0 = disabled)"));
|
||||
|
||||
add((new field_num("presence_penalty", params.sampling.penalty_present))
|
||||
->set_desc("Repeat alpha presence penalty (0 = disabled)"));
|
||||
|
||||
add((new field_num("dry_multiplier", params.sampling.dry_multiplier))
|
||||
->set_desc("Set the DRY (Don't Repeat Yourself) repetition penalty multiplier (0 = disabled)"));
|
||||
|
||||
add((new field_num("dry_base", params.sampling.dry_base))
|
||||
->set_desc("Set the DRY repetition penalty base value (must be >= 1.0, any values < 1.0 will be replaced with the default value)")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
float v = data.at("dry_base").get<float>();
|
||||
ctx.params.sampling.dry_base = (v < 1.0f) ? params_base.sampling.dry_base : v;
|
||||
}));
|
||||
|
||||
add((new field_num("dry_allowed_length", params.sampling.dry_allowed_length))
|
||||
->set_hard_limits(0, INT32_MAX)
|
||||
->set_desc("Tokens that extend repetition beyond this length receive exponentially increasing penalty: multiplier * base ^ (sequence_length - allowed_length)"));
|
||||
|
||||
add((new field_num("dry_penalty_last_n", params.sampling.dry_penalty_last_n))
|
||||
->set_hard_limits(-1, INT32_MAX)
|
||||
->set_desc("How many tokens to scan for repetitions (0 = disabled, -1 = context size)"));
|
||||
|
||||
add((new field_num("mirostat", params.sampling.mirostat))
|
||||
->set_limits(0, 2)
|
||||
->set_desc("Enable Mirostat sampling, controlling perplexity during text generation (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"));
|
||||
|
||||
add((new field_num("mirostat_tau", params.sampling.mirostat_tau))
|
||||
->set_desc("Set the Mirostat target entropy, parameter tau"));
|
||||
|
||||
add((new field_num("mirostat_eta", params.sampling.mirostat_eta))
|
||||
->set_desc("Set the Mirostat learning rate, parameter eta"));
|
||||
|
||||
add((new field_num("adaptive_target", params.sampling.adaptive_target))
|
||||
->set_limits(-std::numeric_limits<float>::max(), 1.0f)
|
||||
->set_desc("Adaptive sampling target entropy (valid range 0.0 to 1.0; negative = disabled)"));
|
||||
|
||||
add((new field_num("adaptive_decay", params.sampling.adaptive_decay))
|
||||
->set_hard_limits(0.0f, 0.99f)
|
||||
->set_desc("EMA decay for adaptive sampling; history approximates 1/(1-decay) tokens"));
|
||||
|
||||
// seed is uint32_t; field_num uses int32_t so use a handler
|
||||
add((new field_num("seed", params.sampling.seed))
|
||||
->set_desc("Set the random number generator (RNG) seed (-1 = random)"));
|
||||
|
||||
add((new field_num("n_probs", params.sampling.n_probs))
|
||||
->add_alias("logprobs") // use "logprobs" if "n_probs" wasn't provided
|
||||
->set_desc("If greater than 0, output the probabilities of top N tokens for each generated token"));
|
||||
|
||||
add((new field_num("min_keep", params.sampling.min_keep))
|
||||
->set_hard_limits(0, INT32_MAX)
|
||||
->set_desc("If greater than 0, force samplers to return at least N possible tokens"));
|
||||
|
||||
add((new field_bool("backend_sampling", params.sampling.backend_sampling))
|
||||
->set_desc("Use backend sampling instead of llama.cpp sampling"));
|
||||
|
||||
add((new field_bool("post_sampling_probs", params.post_sampling_probs))
|
||||
->set_desc("Return probabilities of top n_probs tokens after applying the sampling chain"));
|
||||
|
||||
//
|
||||
// Speculative decoding params
|
||||
//
|
||||
|
||||
// TODO: to keep things simple, we disable speculative parameter adjustments for now
|
||||
#if 0
|
||||
// TODO: for now, be able to adjust only the draft-model based speculative parameters
|
||||
add((new field_num("speculative.n_max", params.speculative.draft.n_max))
|
||||
->set_hard_limits(0, INT32_MAX)
|
||||
->set_desc("Maximum number of tokens to draft during speculative decoding"));
|
||||
|
||||
add((new field_num("speculative.n_min", params.speculative.draft.n_min))
|
||||
->set_hard_limits(0, INT32_MAX)
|
||||
->set_desc("Minimum number of draft tokens to use for speculative decoding");
|
||||
|
||||
add((new field_num("speculative.p_min", params.speculative.draft.p_min))
|
||||
->set_hard_limits(0.0f, 1.0f)
|
||||
->set_desc("Minimum speculative decoding probability for draft tokens (0 = greedy)"));
|
||||
|
||||
add((new field_str("speculative.type"))
|
||||
->set_desc("Speculative decoding method (for debugging and research purposes)")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
ctx.params.speculative.types = { common_speculative_type_from_name(data.at("speculative.type").get<std::string>()) };
|
||||
}));
|
||||
|
||||
add((new field_num("speculative.ngram_size_n", params.speculative.ngram_simple.size_n))
|
||||
->set_desc("Ngram size for lookup in ngram-based speculative decoding"));
|
||||
|
||||
add((new field_num("speculative.ngram_size_m", params.speculative.ngram_simple.size_m))
|
||||
->set_desc("Mgram size for speculative tokens in ngram-based speculative decoding"));
|
||||
|
||||
add((new field_num("speculative.ngram_min_hits", params.speculative.ngram_simple.min_hits))
|
||||
->set_desc("Minimum hits at ngram lookup for mgram to be proposed"));
|
||||
#endif
|
||||
|
||||
add((new field_json("lora"))
|
||||
->set_desc("A list of LoRA adapters to apply to this request. Each entry must have `id` and `scale` fields. Adapters not listed default to scale 0.0")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
const auto & lora = data.at("lora");
|
||||
if (!lora.is_array()) {
|
||||
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
|
||||
}
|
||||
ctx.params.lora = parse_lora_request(lora);
|
||||
}));
|
||||
|
||||
// sequence breakers for DRY
|
||||
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
|
||||
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
|
||||
add((new field_json("dry_sequence_breakers"))
|
||||
->set_desc("Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
ctx.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
|
||||
if (ctx.params.sampling.dry_sequence_breakers.empty()) {
|
||||
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
|
||||
}
|
||||
}));
|
||||
|
||||
// handle both "json_schema" and "grammar"
|
||||
add((new field_json("json_schema"))
|
||||
->add_alias("grammar")
|
||||
->set_desc("Set a JSON schema (json_schema) or GBNF grammar string (grammar) for constrained generation. json_schema takes precedence if both are provided")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
auto & params = ctx.params;
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
|
||||
std::string grammar_str = json_schema_to_grammar(schema);
|
||||
SRV_DBG("Converted grammar: %s\n", grammar_str.c_str());
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, std::move(grammar_str)};
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
} else {
|
||||
std::string grammar_str = json_value(data, "grammar", std::string());
|
||||
if (!grammar_str.empty()) {
|
||||
// grammar_type key is set by the server when converting chat template grammars
|
||||
std::string grammar_type = json_value(data, "grammar_type", std::string());
|
||||
if (grammar_type == "tool_calls") {
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_TOOL_CALLS, std::move(grammar_str)};
|
||||
} else {
|
||||
// explicit grammar from the user (API field "grammar")
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)};
|
||||
}
|
||||
SRV_DBG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(params.sampling.grammar).c_str());
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
add((new field_bool("grammar_lazy", params.sampling.grammar_lazy))
|
||||
->set_desc("Whether to apply grammar constraints lazily, only when triggered (instead of at every step)"));
|
||||
|
||||
//
|
||||
// Chat parser params
|
||||
//
|
||||
|
||||
// TODO: change this to string field instead
|
||||
add((new field_json("chat_format"))
|
||||
->set_desc("Chat format used internally by the server")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
ctx.params.chat_parser_params.format = static_cast<common_chat_format>(data.at("chat_format").get<int>());
|
||||
SRV_INF("Chat format: %s\n", common_chat_format_name(ctx.params.chat_parser_params.format));
|
||||
}));
|
||||
|
||||
add((new field_str("reasoning_format"))
|
||||
->set_desc("Reasoning format for chain-of-thought models")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
auto reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
|
||||
ctx.params.chat_parser_params.reasoning_format = reasoning_format;
|
||||
ctx.params.chat_parser_params.reasoning_in_content = ctx.params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
||||
}));
|
||||
|
||||
add((new field_str("generation_prompt"))
|
||||
->set_desc("Generation prompt appended to the chat template output")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
std::string s = data.at("generation_prompt").get<std::string>();
|
||||
ctx.params.chat_parser_params.generation_prompt = s;
|
||||
ctx.params.sampling.generation_prompt = s;
|
||||
}));
|
||||
|
||||
add((new field_bool("parse_tool_calls", params.chat_parser_params.parse_tool_calls))
|
||||
->set_desc("Whether to parse tool calls from the generated output"));
|
||||
|
||||
add((new field_str("chat_parser"))
|
||||
->set_desc("Chat parser configuration string")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
ctx.params.chat_parser_params.parser.load(data.at("chat_parser").get<std::string>());
|
||||
}));
|
||||
|
||||
add((new field_json("continue_final_message"))
|
||||
->set_desc("Whether to continue the final message of the chat template")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
auto continuation = common_chat_continuation_parse(data.at("continue_final_message"));
|
||||
ctx.params.chat_parser_params.is_continuation = continuation != COMMON_CHAT_CONTINUATION_NONE;
|
||||
}));
|
||||
|
||||
add((new field_bool("echo", params.chat_parser_params.echo))
|
||||
->set_desc("Whether to echo the input tokens in the output"));
|
||||
|
||||
//
|
||||
// Token-level fields (require vocab)
|
||||
//
|
||||
|
||||
add((new field_json("preserved_tokens"))
|
||||
->set_desc("List of token strings that must not be split during tokenization")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
GGML_ASSERT(ctx.vocab != nullptr);
|
||||
for (const auto & t : data.at("preserved_tokens")) {
|
||||
auto ids = common_tokenize(ctx.vocab, t.get<std::string>(), false, true);
|
||||
if (ids.size() == 1) {
|
||||
ctx.params.sampling.preserved_tokens.insert(ids[0]);
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
add((new field_json("grammar_triggers"))
|
||||
->set_desc("List of strings or patterns that trigger grammar-constrained generation")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
GGML_ASSERT(ctx.vocab != nullptr);
|
||||
for (const auto & t : data.at("grammar_triggers")) {
|
||||
server_grammar_trigger ct(t);
|
||||
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
||||
const auto & word = ct.value.value;
|
||||
auto ids = common_tokenize(ctx.vocab, word, false, true);
|
||||
if (ids.size() == 1) {
|
||||
auto token = ids[0];
|
||||
if (std::find(ctx.params.sampling.preserved_tokens.begin(), ctx.params.sampling.preserved_tokens.end(), (llama_token) token) == ctx.params.sampling.preserved_tokens.end()) {
|
||||
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
|
||||
}
|
||||
common_grammar_trigger trigger;
|
||||
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
|
||||
trigger.value = word;
|
||||
trigger.token = token;
|
||||
ctx.params.sampling.grammar_triggers.push_back(std::move(trigger));
|
||||
} else {
|
||||
ctx.params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
|
||||
}
|
||||
} else {
|
||||
ctx.params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
|
||||
}
|
||||
}
|
||||
if (ctx.params.sampling.grammar_lazy && ctx.params.sampling.grammar_triggers.empty()) {
|
||||
throw std::runtime_error("Error: no triggers set for lazy grammar!");
|
||||
}
|
||||
}));
|
||||
|
||||
add((new field_bool("reasoning_control", params.sampling.reasoning_control))
|
||||
->set_desc("Create the budget sampler on demand so reasoning can be ended at runtime"));
|
||||
|
||||
add((new field_num("reasoning_budget_tokens", params.sampling.reasoning_budget_tokens))
|
||||
->set_hard_limits(-1, INT32_MAX)
|
||||
->set_desc("Number of tokens in the reasoning budget (-1 = disabled)"));
|
||||
|
||||
add((new field_str("reasoning_budget_start_tag"))
|
||||
->set_desc("Token string marking the start of the reasoning budget section")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
GGML_ASSERT(ctx.vocab != nullptr);
|
||||
ctx.params.sampling.reasoning_budget_start = common_tokenize(ctx.vocab, data.at("reasoning_budget_start_tag").get<std::string>(), false, true);
|
||||
}));
|
||||
|
||||
add((new field_str("reasoning_budget_end_tag"))
|
||||
->set_desc("Token string marking the end of the reasoning budget section")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
GGML_ASSERT(ctx.vocab != nullptr);
|
||||
std::string end_tag = data.at("reasoning_budget_end_tag").get<std::string>();
|
||||
ctx.params.sampling.reasoning_budget_end = common_tokenize(ctx.vocab, end_tag, false, true);
|
||||
}));
|
||||
|
||||
add((new field_str("reasoning_budget_message"))
|
||||
->set_desc("Message to prepend to the reasoning budget end tag when forcing it")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
GGML_ASSERT(ctx.vocab != nullptr);
|
||||
std::string end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
||||
std::string message = data.at("reasoning_budget_message").get<std::string>();
|
||||
ctx.params.sampling.reasoning_budget_forced = common_tokenize(ctx.vocab, message + end_tag, false, true);
|
||||
}));
|
||||
|
||||
add((new field_json("logit_bias"))
|
||||
->set_desc("Modify the likelihood of specific tokens. Accepts an array of [token, bias] pairs or an object mapping token to bias. Use false as bias to ban a token")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
GGML_ASSERT(ctx.vocab != nullptr);
|
||||
ctx.params.sampling.logit_bias.clear();
|
||||
const auto & logit_bias = data.at("logit_bias");
|
||||
const int n_vocab = llama_vocab_n_tokens(ctx.vocab);
|
||||
auto parse_bias = [](const json & v, float & bias) -> bool {
|
||||
if (v.is_number()) { bias = v.get<float>(); return true; }
|
||||
if (v.is_boolean() && !v.get<bool>()) { bias = -INFINITY; return true; }
|
||||
return false;
|
||||
};
|
||||
if (logit_bias.is_array()) {
|
||||
for (const auto & el : logit_bias) {
|
||||
if (!el.is_array() || el.size() != 2) continue;
|
||||
float bias;
|
||||
if (!parse_bias(el[1], bias)) continue;
|
||||
if (el[0].is_number_integer()) {
|
||||
llama_token tok = el[0].get<llama_token>();
|
||||
if (tok >= 0 && tok < n_vocab) ctx.params.sampling.logit_bias.push_back({tok, bias});
|
||||
} else if (el[0].is_string()) {
|
||||
for (auto tok : common_tokenize(ctx.vocab, el[0].get<std::string>(), false))
|
||||
ctx.params.sampling.logit_bias.push_back({tok, bias});
|
||||
}
|
||||
}
|
||||
} else if (logit_bias.is_object()) {
|
||||
for (const auto & el : logit_bias.items()) {
|
||||
float bias;
|
||||
if (!parse_bias(el.value(), bias)) continue;
|
||||
char * end;
|
||||
llama_token tok = strtol(el.key().c_str(), &end, 10);
|
||||
if (*end == 0) {
|
||||
if (tok >= 0 && tok < n_vocab) ctx.params.sampling.logit_bias.push_back({tok, bias});
|
||||
} else {
|
||||
for (auto t : common_tokenize(ctx.vocab, el.key(), false))
|
||||
ctx.params.sampling.logit_bias.push_back({t, bias});
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
add((new field_bool("ignore_eos", params.sampling.ignore_eos))
|
||||
->set_desc("Ignore the end-of-sequence token and continue generating")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
GGML_ASSERT(ctx.logit_bias_eog != nullptr);
|
||||
ctx.params.sampling.ignore_eos = data.at("ignore_eos").get<bool>();
|
||||
if (ctx.params.sampling.ignore_eos && ctx.logit_bias_eog) {
|
||||
ctx.params.sampling.logit_bias.insert(
|
||||
ctx.params.sampling.logit_bias.end(),
|
||||
ctx.logit_bias_eog->begin(), ctx.logit_bias_eog->end());
|
||||
}
|
||||
}));
|
||||
|
||||
add((new field_json("stop"))
|
||||
->set_desc("Specify stopping strings. Generation stops when one is produced, and the string is not included in the output")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
ctx.params.antiprompt.clear();
|
||||
const auto & stop = data.at("stop");
|
||||
if (stop.is_array()) {
|
||||
for (const auto & word : stop) {
|
||||
if (!word.empty()) ctx.params.antiprompt.push_back(word);
|
||||
}
|
||||
} else if (stop.is_string()) {
|
||||
ctx.params.antiprompt.push_back(stop.get<std::string>());
|
||||
}
|
||||
// fall back to CLI defaults if the request provided no effective stop strings
|
||||
if (ctx.params.antiprompt.empty()) {
|
||||
ctx.params.antiprompt = params_base.antiprompt;
|
||||
}
|
||||
}));
|
||||
|
||||
add((new field_json("samplers"))
|
||||
->set_desc("The order in which samplers are applied. An array of sampler type names, or a single string of sampler chars")
|
||||
->set_handler([&](field_eval_context & ctx, const json & data) {
|
||||
const auto & samplers = data.at("samplers");
|
||||
if (samplers.is_array()) {
|
||||
ctx.params.sampling.samplers = common_sampler_types_from_names(samplers);
|
||||
} else if (samplers.is_string()) {
|
||||
ctx.params.sampling.samplers = common_sampler_types_from_chars(samplers.get<std::string>());
|
||||
}
|
||||
}));
|
||||
|
||||
return fields;
|
||||
}
|
||||
|
||||
task_params eval_llama_cmpl_schema(
|
||||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
const int n_ctx_slot,
|
||||
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||
const json & data) {
|
||||
task_params params;
|
||||
|
||||
// Sampling parameter defaults are loaded from the global server context (but individual requests can still them)
|
||||
params.sampling = params_base.sampling;
|
||||
params.speculative = params_base.speculative;
|
||||
params.n_keep = params_base.n_keep;
|
||||
params.n_predict = params_base.n_predict;
|
||||
params.n_cache_reuse = params_base.n_cache_reuse;
|
||||
params.cache_prompt = params_base.cache_prompt;
|
||||
params.antiprompt = params_base.antiprompt;
|
||||
|
||||
// enabling this will output extra debug information in the HTTP responses from the server
|
||||
params.verbose = params_base.verbosity > 9;
|
||||
|
||||
params.chat_parser_params.reasoning_format = params_base.reasoning_format;
|
||||
|
||||
// create context and schema
|
||||
field_eval_context ctx(params);
|
||||
ctx.vocab = vocab;
|
||||
ctx.logit_bias_eog = &logit_bias_eog;
|
||||
|
||||
auto schema = make_llama_cmpl_schema(params_base, params);
|
||||
|
||||
// eval all fields in the schema
|
||||
for (const auto & f : schema) {
|
||||
f->eval(ctx, data);
|
||||
}
|
||||
|
||||
// post-processing
|
||||
{
|
||||
if (params.sampling.penalty_last_n == -1) {
|
||||
// note: should be the slot's context and not the full context, but it's ok
|
||||
params.sampling.penalty_last_n = n_ctx_slot;
|
||||
}
|
||||
|
||||
if (params.sampling.dry_penalty_last_n == -1) {
|
||||
params.sampling.dry_penalty_last_n = n_ctx_slot;
|
||||
}
|
||||
|
||||
// if "reasoning_format" is not provided, its handler will not be called, we will need to handle it here
|
||||
auto reasoning_format = params.chat_parser_params.reasoning_format;
|
||||
params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
||||
}
|
||||
|
||||
// debugging
|
||||
{
|
||||
auto budget = params.sampling.reasoning_budget_tokens;
|
||||
SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
|
||||
budget, params.sampling.generation_prompt.c_str(),
|
||||
params.sampling.reasoning_budget_start.size(),
|
||||
params.sampling.reasoning_budget_end.size(),
|
||||
params.sampling.reasoning_budget_forced.size());
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
//
|
||||
// eval() implementations
|
||||
//
|
||||
|
||||
static void handle_with_catch(const char * name, std::function<void()> func) {
|
||||
try {
|
||||
func();
|
||||
} catch (const std::exception & e) {
|
||||
throw std::invalid_argument(string_format("Field '%s': %s", name, e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void field_num<T>::eval(field_eval_context & ctx, const json & data) {
|
||||
for (const auto & n : name) {
|
||||
if (data.contains(n)) {
|
||||
handle_with_catch(n, [&]() {
|
||||
if (custom_handler) {
|
||||
custom_handler(ctx, data);
|
||||
} else if (!is_hard_limit) {
|
||||
val = std::max(min, std::min(max, data.at(n).template get<T>()));
|
||||
} else {
|
||||
T tmp = data.at(n).template get<T>();
|
||||
if (tmp < min || tmp > max) {
|
||||
throw std::invalid_argument(std::string("Value must be between ") + std::to_string(min) + " <= value <= " + std::to_string(max) + ", but got " + std::to_string(tmp));
|
||||
}
|
||||
val = tmp;
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void field_str::eval(field_eval_context & ctx, const json & data) {
|
||||
GGML_ASSERT(custom_handler);
|
||||
for (const auto & n : name) {
|
||||
if (data.contains(n)) {
|
||||
handle_with_catch(n, [&]() {
|
||||
custom_handler(ctx, data);
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void field_bool::eval(field_eval_context & ctx, const json & data) {
|
||||
for (const auto & n : name) {
|
||||
if (data.contains(n)) {
|
||||
handle_with_catch(n, [&]() {
|
||||
if (custom_handler) {
|
||||
custom_handler(ctx, data);
|
||||
} else {
|
||||
val = data.at(n).get<bool>();
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void field_json::eval(field_eval_context & ctx, const json & data) {
|
||||
GGML_ASSERT(custom_handler);
|
||||
for (const auto & n : name) {
|
||||
if (data.contains(n)) {
|
||||
handle_with_catch(n, [&]() {
|
||||
custom_handler(ctx, data);
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void field_nested::eval(field_eval_context & ctx, const json & data) {
|
||||
for (const auto & n : name) {
|
||||
if (data.contains(n) && data.at(n).is_object()) {
|
||||
for (auto & f : subfields) {
|
||||
f->eval(ctx, data.at(n));
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace server_schema
|
||||
@@ -0,0 +1,105 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-common.h"
|
||||
#include "server-task.h"
|
||||
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
|
||||
#include <climits>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace server_schema {
|
||||
|
||||
struct field_eval_context {
|
||||
task_params & params;
|
||||
const llama_vocab * vocab = nullptr;
|
||||
const std::vector<llama_logit_bias> * logit_bias_eog = nullptr;
|
||||
field_eval_context(task_params & params) : params(params) {}
|
||||
};
|
||||
|
||||
using field_handler = std::function<void(field_eval_context &, const json &)>;
|
||||
|
||||
struct field {
|
||||
std::vector<const char *> name;
|
||||
const char * desc = "";
|
||||
field_handler custom_handler;
|
||||
field() = default;
|
||||
field(const char * n) : name({n}) {}
|
||||
virtual ~field() = default;
|
||||
field * set_desc(const char * s) {
|
||||
desc = s;
|
||||
return this;
|
||||
}
|
||||
// if 'name' is present, use it, otherwise look for aliases following the order they were added
|
||||
field * add_alias(const char * n) {
|
||||
name.push_back(n);
|
||||
return this;
|
||||
}
|
||||
field * set_handler(field_handler h) { this->custom_handler = h; return this; }
|
||||
virtual void eval(field_eval_context & ctx, const json & data) = 0;
|
||||
};
|
||||
|
||||
template <typename T = int32_t>
|
||||
struct field_num : public field {
|
||||
T & val;
|
||||
T min = std::numeric_limits<T>::lowest();
|
||||
T max = std::numeric_limits<T>::max();
|
||||
bool is_hard_limit = false; // if true, throw error if the value is invalid
|
||||
field_num(const char * n, T & val) : field(n), val(val) {}
|
||||
// limits are inclusive, min <= value <= max
|
||||
field_num * set_limits(T min, T max) {
|
||||
this->min = min;
|
||||
this->max = max;
|
||||
return this;
|
||||
}
|
||||
field_num * set_hard_limits(T min, T max) {
|
||||
set_limits(min, max);
|
||||
is_hard_limit = true;
|
||||
return this;
|
||||
}
|
||||
virtual void eval(field_eval_context & ctx, const json & data) override;
|
||||
};
|
||||
|
||||
struct field_str : public field {
|
||||
field_str(const char * n) : field(n) {}
|
||||
virtual void eval(field_eval_context & ctx, const json & data) override;
|
||||
};
|
||||
|
||||
struct field_bool : public field {
|
||||
bool & val;
|
||||
field_bool(const char * n, bool & val) : field(n), val(val) {}
|
||||
virtual void eval(field_eval_context & ctx, const json & data) override;
|
||||
};
|
||||
|
||||
struct field_json : public field {
|
||||
field_json(const char * n) : field(n) {}
|
||||
virtual void eval(field_eval_context & ctx, const json & data) override;
|
||||
};
|
||||
|
||||
struct field_nested : public field {
|
||||
std::vector<std::unique_ptr<field>> subfields;
|
||||
field_nested(const char * n) : field(n) {}
|
||||
field_nested * add_subfield(field * f) {
|
||||
subfields.emplace_back(std::unique_ptr<field>(f));
|
||||
return this;
|
||||
}
|
||||
virtual void eval(field_eval_context & ctx, const json & data) override;
|
||||
};
|
||||
|
||||
std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(
|
||||
const common_params & params_base,
|
||||
task_params & params);
|
||||
|
||||
task_params eval_llama_cmpl_schema(
|
||||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
const int n_ctx_slot,
|
||||
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||
const json & data);
|
||||
|
||||
} // namespace server_schema
|
||||
@@ -232,396 +232,8 @@ common_chat_msg task_result_state::update_chat_msg(
|
||||
return chat_msg;
|
||||
}
|
||||
|
||||
//
|
||||
// server_task
|
||||
//
|
||||
|
||||
task_params server_task::params_from_json_cmpl(
|
||||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
const int n_ctx_slot,
|
||||
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||
const json & data) {
|
||||
task_params params;
|
||||
|
||||
// Sampling parameter defaults are loaded from the global server context (but individual requests can still them)
|
||||
task_params defaults;
|
||||
defaults.sampling = params_base.sampling;
|
||||
defaults.speculative = params_base.speculative;
|
||||
defaults.n_keep = params_base.n_keep;
|
||||
defaults.n_predict = params_base.n_predict;
|
||||
defaults.n_cache_reuse = params_base.n_cache_reuse;
|
||||
defaults.cache_prompt = params_base.cache_prompt;
|
||||
defaults.antiprompt = params_base.antiprompt;
|
||||
|
||||
// enabling this will output extra debug information in the HTTP responses from the server
|
||||
params.verbose = params_base.verbosity > 9;
|
||||
params.timings_per_token = json_value(data, "timings_per_token", false);
|
||||
|
||||
params.stream = json_value(data, "stream", false);
|
||||
auto stream_opt = json_value(data, "stream_options", json::object());
|
||||
params.include_usage = json_value(stream_opt, "include_usage", false);
|
||||
params.cache_prompt = json_value(data, "cache_prompt", defaults.cache_prompt);
|
||||
params.return_tokens = json_value(data, "return_tokens", false);
|
||||
params.return_progress = json_value(data, "return_progress", false);
|
||||
auto max_tokens = json_value(data, "max_tokens", defaults.n_predict);
|
||||
params.n_predict = json_value(data, "n_predict", json_value(data, "max_completion_tokens", max_tokens));
|
||||
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
|
||||
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
|
||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
params.n_discard = std::max(0, params.n_discard);
|
||||
params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
|
||||
params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse);
|
||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||
|
||||
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
||||
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
||||
params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
|
||||
params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma);
|
||||
params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
|
||||
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
|
||||
params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
|
||||
params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
|
||||
params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
|
||||
params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
|
||||
params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
|
||||
params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
|
||||
params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
|
||||
params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
|
||||
params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
|
||||
params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
|
||||
params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
|
||||
params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
|
||||
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
|
||||
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
|
||||
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
|
||||
params.sampling.adaptive_target = json_value(data, "adaptive_target", defaults.sampling.adaptive_target);
|
||||
params.sampling.adaptive_decay = json_value(data, "adaptive_decay", defaults.sampling.adaptive_decay);
|
||||
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
|
||||
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
|
||||
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
|
||||
params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
|
||||
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
|
||||
|
||||
params.speculative = defaults.speculative;
|
||||
|
||||
// TODO: to keep things simple, we disable speculative parameter adjustments for now
|
||||
#if 0
|
||||
// TODO: for now, be able to adjust only the draft-model based speculative parameters
|
||||
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
|
||||
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
|
||||
params.speculative.draft.p_min = json_value(data, "speculative.p_min", defaults.speculative.draft.p_min);
|
||||
|
||||
params.speculative.draft.n_min = std::min(params.speculative.draft.n_max, params.speculative.draft.n_min);
|
||||
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
|
||||
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
|
||||
|
||||
// for debugging and research purposes
|
||||
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
|
||||
|
||||
params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
|
||||
params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
|
||||
params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);
|
||||
|
||||
params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024);
|
||||
params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024);
|
||||
params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024);
|
||||
#endif
|
||||
|
||||
// Use OpenAI API logprobs only if n_probs wasn't provided
|
||||
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
|
||||
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
|
||||
}
|
||||
|
||||
if (data.contains("lora")) {
|
||||
if (data.at("lora").is_array()) {
|
||||
params.lora = parse_lora_request(data.at("lora"));
|
||||
} else {
|
||||
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
|
||||
}
|
||||
} else {
|
||||
params.lora = {};
|
||||
}
|
||||
|
||||
// TODO: add more sanity checks for the input parameters
|
||||
|
||||
if (params.sampling.penalty_last_n < -1) {
|
||||
throw std::runtime_error("Error: repeat_last_n must be >= -1");
|
||||
}
|
||||
|
||||
if (params.sampling.dry_penalty_last_n < -1) {
|
||||
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
|
||||
}
|
||||
|
||||
if (params.sampling.penalty_last_n == -1) {
|
||||
// note: should be the slot's context and not the full context, but it's ok
|
||||
params.sampling.penalty_last_n = n_ctx_slot;
|
||||
}
|
||||
|
||||
if (params.sampling.dry_penalty_last_n == -1) {
|
||||
params.sampling.dry_penalty_last_n = n_ctx_slot;
|
||||
}
|
||||
|
||||
if (params.sampling.dry_base < 1.0f) {
|
||||
params.sampling.dry_base = defaults.sampling.dry_base;
|
||||
}
|
||||
|
||||
// sequence breakers for DRY
|
||||
{
|
||||
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
|
||||
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
|
||||
|
||||
if (data.contains("dry_sequence_breakers")) {
|
||||
params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
|
||||
if (params.sampling.dry_sequence_breakers.empty()) {
|
||||
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process "json_schema" and "grammar"
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
|
||||
std::string grammar_str = json_schema_to_grammar(schema);
|
||||
SRV_DBG("Converted grammar: %s\n", grammar_str.c_str());
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, std::move(grammar_str)};
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar = defaults.sampling.grammar;
|
||||
|
||||
std::string grammar_str = json_value(data, "grammar", std::string());
|
||||
if (!grammar_str.empty()) {
|
||||
// grammar_type key is set by the server when converting chat template grammars
|
||||
std::string grammar_type = json_value(data, "grammar_type", std::string());
|
||||
if (grammar_type == "tool_calls") {
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_TOOL_CALLS, std::move(grammar_str)};
|
||||
} else {
|
||||
// explicit grammar from the user (API field "grammar")
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)};
|
||||
}
|
||||
SRV_DBG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(params.sampling.grammar).c_str());
|
||||
}
|
||||
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
|
||||
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
|
||||
}
|
||||
|
||||
{
|
||||
auto it = data.find("chat_format");
|
||||
if (it != data.end()) {
|
||||
params.chat_parser_params.format = static_cast<common_chat_format>(it->get<int>());
|
||||
SRV_INF("Chat format: %s\n", common_chat_format_name(params.chat_parser_params.format));
|
||||
} else {
|
||||
params.chat_parser_params.format = defaults.chat_parser_params.format;
|
||||
}
|
||||
common_reasoning_format reasoning_format = params_base.reasoning_format;
|
||||
if (data.contains("reasoning_format")) {
|
||||
reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
|
||||
}
|
||||
params.chat_parser_params.reasoning_format = reasoning_format;
|
||||
params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
||||
params.chat_parser_params.generation_prompt = json_value(data, "generation_prompt", std::string());
|
||||
params.sampling.generation_prompt = params.chat_parser_params.generation_prompt;
|
||||
SRV_DBG("Generation prompt: '%s'\n", params.chat_parser_params.generation_prompt.c_str());
|
||||
params.chat_parser_params.parse_tool_calls = json_value(data, "parse_tool_calls", false);
|
||||
if (data.contains("chat_parser")) {
|
||||
params.chat_parser_params.parser.load(data.at("chat_parser").get<std::string>());
|
||||
}
|
||||
if (data.contains("continue_final_message")) {
|
||||
auto continuation = common_chat_continuation_parse(data.at("continue_final_message"));
|
||||
params.chat_parser_params.is_continuation = continuation != COMMON_CHAT_CONTINUATION_NONE;
|
||||
}
|
||||
params.chat_parser_params.echo = json_value(data, "echo", false);
|
||||
}
|
||||
|
||||
{
|
||||
const auto preserved_tokens = data.find("preserved_tokens");
|
||||
if (preserved_tokens != data.end()) {
|
||||
for (const auto & t : *preserved_tokens) {
|
||||
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
SRV_DBG("Preserved token: %d\n", ids[0]);
|
||||
params.sampling.preserved_tokens.insert(ids[0]);
|
||||
} else {
|
||||
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
|
||||
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
const auto grammar_triggers = data.find("grammar_triggers");
|
||||
if (grammar_triggers != data.end()) {
|
||||
for (const auto & t : *grammar_triggers) {
|
||||
server_grammar_trigger ct(t);
|
||||
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
||||
const auto & word = ct.value.value;
|
||||
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
auto token = ids[0];
|
||||
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
|
||||
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
|
||||
}
|
||||
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
|
||||
common_grammar_trigger trigger;
|
||||
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
|
||||
trigger.value = word;
|
||||
trigger.token = token;
|
||||
params.sampling.grammar_triggers.push_back(std::move(trigger));
|
||||
} else {
|
||||
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
|
||||
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
|
||||
}
|
||||
} else {
|
||||
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
|
||||
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
|
||||
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
|
||||
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
|
||||
} else {
|
||||
throw std::runtime_error("Unknown grammar trigger type");
|
||||
}
|
||||
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
|
||||
throw std::runtime_error("Error: no triggers set for lazy grammar!");
|
||||
}
|
||||
}
|
||||
|
||||
// Parse reasoning budget sampler parameters
|
||||
{
|
||||
const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
|
||||
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
|
||||
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
||||
const auto message = json_value(data, "reasoning_budget_message", std::string());
|
||||
params.sampling.reasoning_budget_tokens = budget;
|
||||
params.sampling.reasoning_control = json_value(data, "reasoning_control", false);
|
||||
|
||||
if (!start_tag.empty()) {
|
||||
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
|
||||
}
|
||||
if (!end_tag.empty()) {
|
||||
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
|
||||
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
|
||||
|
||||
SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
|
||||
budget, params.sampling.generation_prompt.c_str(),
|
||||
params.sampling.reasoning_budget_start.size(),
|
||||
params.sampling.reasoning_budget_end.size(),
|
||||
params.sampling.reasoning_budget_forced.size());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
params.sampling.logit_bias.clear();
|
||||
|
||||
const auto & logit_bias = data.find("logit_bias");
|
||||
if (logit_bias != data.end() && logit_bias->is_array()) {
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
for (const auto & el : *logit_bias) {
|
||||
// TODO: we may want to throw errors here, in case "el" is incorrect
|
||||
if (el.is_array() && el.size() == 2) {
|
||||
float bias;
|
||||
if (el[1].is_number()) {
|
||||
bias = el[1].get<float>();
|
||||
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
||||
bias = -INFINITY;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (el[0].is_number_integer()) {
|
||||
llama_token tok = el[0].get<llama_token>();
|
||||
if (tok >= 0 && tok < n_vocab) {
|
||||
params.sampling.logit_bias.push_back({tok, bias});
|
||||
}
|
||||
} else if (el[0].is_string()) {
|
||||
auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
|
||||
for (auto tok : toks) {
|
||||
params.sampling.logit_bias.push_back({tok, bias});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (logit_bias != data.end() && logit_bias->is_object()) {
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
for (const auto & el : logit_bias->items()) {
|
||||
float bias;
|
||||
const auto & key = el.key();
|
||||
const auto & value = el.value();
|
||||
if (value.is_number()) {
|
||||
bias = value.get<float>();
|
||||
} else if (value.is_boolean() && !value.get<bool>()) {
|
||||
bias = -INFINITY;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
char *end;
|
||||
llama_token tok = strtol(key.c_str(), &end, 10);
|
||||
if (*end == 0) {
|
||||
if (tok >= 0 && tok < n_vocab) {
|
||||
params.sampling.logit_bias.push_back({tok, bias});
|
||||
}
|
||||
} else {
|
||||
auto toks = common_tokenize(vocab, key, false);
|
||||
for (auto tok : toks) {
|
||||
params.sampling.logit_bias.push_back({tok, bias});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
|
||||
if (params.sampling.ignore_eos) {
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
logit_bias_eog.begin(), logit_bias_eog.end());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
params.antiprompt.clear();
|
||||
|
||||
const auto & stop = data.find("stop");
|
||||
if (stop != data.end() && stop->is_array()) {
|
||||
for (const auto & word : *stop) {
|
||||
if (!word.empty()) {
|
||||
params.antiprompt.push_back(word);
|
||||
}
|
||||
}
|
||||
}
|
||||
// set reverse prompt from cli args if not set in the request
|
||||
if (params.antiprompt.empty()) {
|
||||
params.antiprompt = defaults.antiprompt;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto samplers = data.find("samplers");
|
||||
if (samplers != data.end()) {
|
||||
if (samplers->is_array()) {
|
||||
params.sampling.samplers = common_sampler_types_from_names(*samplers);
|
||||
} else if (samplers->is_string()){
|
||||
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
|
||||
}
|
||||
} else {
|
||||
params.sampling.samplers = defaults.sampling.samplers;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.n_cmpl > params_base.n_parallel) {
|
||||
throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
//
|
||||
// result_timings
|
||||
//
|
||||
|
||||
|
||||
@@ -210,13 +210,6 @@ struct server_task {
|
||||
}
|
||||
}
|
||||
|
||||
static task_params params_from_json_cmpl(
|
||||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
const int n_ctx_slot,
|
||||
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||
const json & data);
|
||||
|
||||
// utility function
|
||||
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
||||
std::unordered_set<int> ids(tasks.size());
|
||||
|
||||
@@ -79,7 +79,7 @@
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
||||
<div
|
||||
class="pointer-events-none flex items-center justify-center gap-0.75 pl-2 opacity-0 group-hover:pointer-events-auto group-hover:opacity-100"
|
||||
class="pointer-events-none flex items-center justify-center gap-0.75 pl-2 opacity-0 group-hover:pointer-events-auto group-hover:opacity-100 [@media(pointer:coarse)]:pointer-events-auto [@media(pointer:coarse)]:opacity-100"
|
||||
onclick={(e) => e.stopPropagation()}
|
||||
>
|
||||
{#if isFav}
|
||||
@@ -113,12 +113,16 @@
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-5 items-center justify-center">
|
||||
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
{:else if isFailed}
|
||||
<div class="flex w-4 items-center justify-center">
|
||||
<CircleAlert class="h-3.5 w-3.5 text-red-500 group-hover:hidden" />
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<CircleAlert
|
||||
class="h-3.5 w-3.5 text-red-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
/>
|
||||
|
||||
<div class="hidden group-hover:flex">
|
||||
<div class="hidden group-hover:flex [@media(pointer:coarse)]:flex">
|
||||
<ActionIcon
|
||||
iconSize="h-2.5 w-2.5"
|
||||
icon={RotateCw}
|
||||
@@ -130,15 +134,17 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else if isSleeping}
|
||||
<div class="flex w-4 items-center justify-center">
|
||||
<span class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden"></span>
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
|
||||
<div class="hidden group-hover:flex">
|
||||
<div class="hidden group-hover:flex [@media(pointer:coarse)]:flex">
|
||||
<ActionIcon
|
||||
iconSize="h-2.5 w-2.5"
|
||||
icon={PowerOff}
|
||||
tooltip="Unload model"
|
||||
class="h-3 w-3 text-red-500 hover:text-red-600"
|
||||
class="h-3 w-3 text-red-500 hover:text-red-600 [@media(pointer:coarse)]:text-amber-500 [@media(pointer:coarse)]:hover:text-amber-600"
|
||||
onclick={(e) => {
|
||||
e?.stopPropagation();
|
||||
modelsStore.unloadModel(option.model);
|
||||
@@ -147,30 +153,34 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else if isLoaded}
|
||||
<div class="flex w-4 items-center justify-center">
|
||||
<span class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden"></span>
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
|
||||
<div class="hidden group-hover:flex">
|
||||
<div class="hidden group-hover:flex [@media(pointer:coarse)]:flex">
|
||||
<ActionIcon
|
||||
iconSize="h-2.5 w-2.5"
|
||||
icon={PowerOff}
|
||||
tooltip="Unload model"
|
||||
class="h-3 w-3 text-red-500 hover:text-red-600"
|
||||
class="h-3 w-3 text-red-500 hover:text-red-600 [@media(pointer:coarse)]:text-green-500 [@media(pointer:coarse)]:hover:text-green-600"
|
||||
onclick={() => modelsStore.unloadModel(option.model)}
|
||||
stopPropagationOnClick
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex w-4 items-center justify-center">
|
||||
<span class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden"></span>
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
|
||||
<div class="hidden group-hover:flex">
|
||||
<div class="hidden group-hover:flex [@media(pointer:coarse)]:flex">
|
||||
<ActionIcon
|
||||
iconSize="h-2.5 w-2.5"
|
||||
icon={Power}
|
||||
tooltip="Load model"
|
||||
class="h-3 w-3"
|
||||
class="h-3 w-3 [@media(pointer:coarse)]:text-muted-foreground"
|
||||
onclick={() => modelsStore.loadModel(option.model)}
|
||||
stopPropagationOnClick
|
||||
/>
|
||||
|
||||
@@ -66,7 +66,7 @@
|
||||
<button
|
||||
type="button"
|
||||
class={[
|
||||
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 max-sm:px-3 max-sm:py-2 text-xs max-sm:text-sm shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
!ms.isCurrentModelInCache
|
||||
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
|
||||
: forceForegroundText
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
<script module lang="ts">
|
||||
import { defineMeta } from '@storybook/addon-svelte-csf';
|
||||
import ModelsSelectorList from '$lib/components/app/models/ModelsSelectorList.svelte';
|
||||
import ModelsSelectorOption from '$lib/components/app/models/ModelsSelectorOption.svelte';
|
||||
import type { GroupedModelOptions, ModelItem } from '$lib/components/app/models/utils';
|
||||
import { modelsStore } from '$lib/stores/models.svelte';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
|
||||
const { Story } = defineMeta({
|
||||
title: 'Components/ModelsSelector',
|
||||
parameters: {
|
||||
layout: 'centered'
|
||||
}
|
||||
});
|
||||
|
||||
const mockModel = (id: string, name: string, orgName?: string, tags?: string[]): ModelOption => ({
|
||||
id,
|
||||
name,
|
||||
model: orgName ? `${orgName}/${name}` : name,
|
||||
capabilities: [],
|
||||
parsedId: {
|
||||
raw: orgName ? `${orgName}/${name}` : name,
|
||||
orgName: orgName ?? null,
|
||||
modelName: name,
|
||||
params: null,
|
||||
activatedParams: null,
|
||||
quantization: null,
|
||||
tags: tags ?? []
|
||||
},
|
||||
tags
|
||||
});
|
||||
|
||||
const mockRouterEntry = (modelName: string, status: ServerModelStatus): ApiModelDataEntry => ({
|
||||
id: modelName,
|
||||
object: 'model',
|
||||
owned_by: 'llamacpp',
|
||||
created: Date.now(),
|
||||
in_cache: true,
|
||||
path: `/models/${modelName}`,
|
||||
status: { value: status }
|
||||
});
|
||||
</script>
|
||||
|
||||
<script lang="ts">
|
||||
let selectedModel = $state<string | null>(null);
|
||||
let activeId = $state<string | null>(null);
|
||||
|
||||
function mockModelsStore() {
|
||||
modelsStore.favoriteModelIds = new Set(['qwen2.5-7b', 'llama3.2-3b']);
|
||||
|
||||
// Mock router models with various statuses for ModelLoadedStates story
|
||||
modelsStore.routerModels = [
|
||||
mockRouterEntry('meta/Model (loading)', ServerModelStatus.LOADING),
|
||||
mockRouterEntry('meta/Model (loaded)', ServerModelStatus.LOADED),
|
||||
mockRouterEntry('meta/Model (sleeping)', ServerModelStatus.SLEEPING),
|
||||
mockRouterEntry('meta/Model (failed)', ServerModelStatus.FAILED)
|
||||
];
|
||||
}
|
||||
|
||||
mockModelsStore();
|
||||
|
||||
const loadedModels: ModelItem[] = [
|
||||
{ option: mockModel('llama3.1-8b', 'Llama-3.1-8B-Instruct', 'meta'), flatIndex: 0 },
|
||||
{ option: mockModel('mistral-7b', 'Mistral-7B-v0.3', 'mistralai'), flatIndex: 1 }
|
||||
];
|
||||
|
||||
const favoriteModels: ModelItem[] = [
|
||||
{ option: mockModel('qwen2.5-7b', 'Qwen2.5-7B-Instruct', 'Qwen'), flatIndex: 2 },
|
||||
{ option: mockModel('llama3.2-3b', 'Llama-3.2-3B-Instruct', 'meta'), flatIndex: 3 }
|
||||
];
|
||||
|
||||
const availableModels: ModelItem[] = [
|
||||
{
|
||||
option: mockModel('deepseek-coder-6.7b', 'DeepSeek-Coder-6.7B', 'deepseek', ['coding']),
|
||||
flatIndex: 4
|
||||
},
|
||||
{ option: mockModel('gemma-2-9b', 'Gemma-2-9B-IT', 'google'), flatIndex: 5 },
|
||||
{ option: mockModel('phi-3-mini', 'Phi-3-mini-4k', 'microsoft'), flatIndex: 6 },
|
||||
{ option: mockModel('codellama-7b', 'CodeLlama-7B', 'codellama', ['coding']), flatIndex: 7 },
|
||||
{ option: mockModel('neural-chat-7b', 'Neural-Chat-7B-v3-3', 'intel'), flatIndex: 8 }
|
||||
];
|
||||
|
||||
const groupedOptions: GroupedModelOptions = {
|
||||
loaded: loadedModels,
|
||||
favorites: favoriteModels,
|
||||
available: [
|
||||
{
|
||||
orgName: 'deepseek',
|
||||
items: [availableModels[0]]
|
||||
},
|
||||
{
|
||||
orgName: 'google',
|
||||
items: [availableModels[1]]
|
||||
},
|
||||
{
|
||||
orgName: 'microsoft',
|
||||
items: [availableModels[2]]
|
||||
},
|
||||
{
|
||||
orgName: 'codellama',
|
||||
items: [availableModels[3]]
|
||||
},
|
||||
{
|
||||
orgName: 'intel',
|
||||
items: [availableModels[4]]
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
function handleSelect(modelId: string) {
|
||||
const opt = [...loadedModels, ...favoriteModels, ...availableModels].find(
|
||||
(m) => m.option.id === modelId
|
||||
);
|
||||
if (opt) {
|
||||
selectedModel = opt.option.model;
|
||||
activeId = modelId;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<Story name="List">
|
||||
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
|
||||
<ModelsSelectorList
|
||||
groups={groupedOptions}
|
||||
currentModel={selectedModel}
|
||||
{activeId}
|
||||
onSelect={handleSelect}
|
||||
onInfoClick={(modelName) => console.log('Info clicked:', modelName)}
|
||||
/>
|
||||
</div>
|
||||
</Story>
|
||||
|
||||
<Story name="SingleLoaded">
|
||||
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
|
||||
<ModelsSelectorList
|
||||
groups={{
|
||||
loaded: [loadedModels[0]],
|
||||
favorites: [],
|
||||
available: []
|
||||
}}
|
||||
currentModel={null}
|
||||
activeId={null}
|
||||
onSelect={handleSelect}
|
||||
onInfoClick={(modelName) => console.log('Info clicked:', modelName)}
|
||||
/>
|
||||
</div>
|
||||
</Story>
|
||||
|
||||
<Story name="WithFavoritesOnly">
|
||||
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
|
||||
<ModelsSelectorList
|
||||
groups={{
|
||||
loaded: [],
|
||||
favorites: favoriteModels,
|
||||
available: []
|
||||
}}
|
||||
currentModel={null}
|
||||
activeId={null}
|
||||
onSelect={handleSelect}
|
||||
onInfoClick={(modelName) => console.log('Info clicked:', modelName)}
|
||||
/>
|
||||
</div>
|
||||
</Story>
|
||||
|
||||
<Story name="ModelLoadedStates">
|
||||
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
|
||||
<div class="px-2 py-2 text-[13px] font-semibold text-muted-foreground/70 select-none">
|
||||
Server model states
|
||||
</div>
|
||||
<ModelsSelectorOption
|
||||
option={mockModel('model-idle', 'Model (idle)', 'meta')}
|
||||
isSelected={false}
|
||||
isHighlighted={false}
|
||||
isFav={false}
|
||||
hideOrgName={true}
|
||||
onSelect={() => {}}
|
||||
onMouseEnter={() => {}}
|
||||
onKeyDown={() => {}}
|
||||
/>
|
||||
<ModelsSelectorOption
|
||||
option={mockModel('model-loading', 'Model (loading)', 'meta')}
|
||||
isSelected={false}
|
||||
isHighlighted={false}
|
||||
isFav={false}
|
||||
hideOrgName={true}
|
||||
onSelect={() => {}}
|
||||
onMouseEnter={() => {}}
|
||||
onKeyDown={() => {}}
|
||||
/>
|
||||
<ModelsSelectorOption
|
||||
option={mockModel('model-loaded', 'Model (loaded)', 'meta')}
|
||||
isSelected={false}
|
||||
isHighlighted={false}
|
||||
isFav={false}
|
||||
hideOrgName={true}
|
||||
onSelect={() => {}}
|
||||
onMouseEnter={() => {}}
|
||||
onKeyDown={() => {}}
|
||||
/>
|
||||
<ModelsSelectorOption
|
||||
option={mockModel('model-sleeping', 'Model (sleeping)', 'meta')}
|
||||
isSelected={false}
|
||||
isHighlighted={false}
|
||||
isFav={false}
|
||||
hideOrgName={true}
|
||||
onSelect={() => {}}
|
||||
onMouseEnter={() => {}}
|
||||
onKeyDown={() => {}}
|
||||
/>
|
||||
<ModelsSelectorOption
|
||||
option={mockModel('model-failed', 'Model (failed)', 'meta')}
|
||||
isSelected={false}
|
||||
isHighlighted={false}
|
||||
isFav={false}
|
||||
hideOrgName={true}
|
||||
onSelect={() => {}}
|
||||
onMouseEnter={() => {}}
|
||||
onKeyDown={() => {}}
|
||||
/>
|
||||
</div>
|
||||
</Story>
|
||||
|
||||
<Story name="ModelSelectedStates">
|
||||
<div class="w-80 rounded-lg border border-border bg-popover p-2 shadow-md">
|
||||
<div class="px-2 py-2 text-[13px] font-semibold text-muted-foreground/70 select-none">
|
||||
Selection states
|
||||
</div>
|
||||
<ModelsSelectorOption
|
||||
option={mockModel('normal-model', 'Normal Model', 'meta')}
|
||||
isSelected={false}
|
||||
isHighlighted={false}
|
||||
isFav={false}
|
||||
hideOrgName={true}
|
||||
onSelect={() => {}}
|
||||
onMouseEnter={() => {}}
|
||||
onKeyDown={() => {}}
|
||||
/>
|
||||
<ModelsSelectorOption
|
||||
option={mockModel('selected-model', 'Selected Model', 'meta')}
|
||||
isSelected={true}
|
||||
isHighlighted={false}
|
||||
isFav={false}
|
||||
hideOrgName={true}
|
||||
onSelect={() => {}}
|
||||
onMouseEnter={() => {}}
|
||||
onKeyDown={() => {}}
|
||||
/>
|
||||
<ModelsSelectorOption
|
||||
option={mockModel('highlighted-model', 'Highlighted Model', 'meta')}
|
||||
isSelected={false}
|
||||
isHighlighted={true}
|
||||
isFav={false}
|
||||
hideOrgName={true}
|
||||
onSelect={() => {}}
|
||||
onMouseEnter={() => {}}
|
||||
onKeyDown={() => {}}
|
||||
/>
|
||||
<ModelsSelectorOption
|
||||
option={mockModel('fav-model', 'Favorite Model', 'Qwen')}
|
||||
isSelected={false}
|
||||
isHighlighted={false}
|
||||
isFav={true}
|
||||
hideOrgName={true}
|
||||
onSelect={() => {}}
|
||||
onMouseEnter={() => {}}
|
||||
onKeyDown={() => {}}
|
||||
/>
|
||||
</div>
|
||||
</Story>
|
||||
Reference in New Issue
Block a user