mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 09:37:42 +02:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 312cf03328 | |||
| f4049ad735 | |||
| 5e8910a0db | |||
| fe00a84b4b | |||
| 7ab321d40d | |||
| 7533a7d509 | |||
| a69d54f990 | |||
| cf23ee2447 | |||
| 892e3c333a |
+42
-67
@@ -933,17 +933,12 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
|
||||
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
|
||||
|
||||
if (has_reasoning_content && has_tool_calls) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["thinking"] = msg.at("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
for (auto msg : inputs.messages) {
|
||||
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
|
||||
msg["thinking"] = msg.at("reasoning_content");
|
||||
msg.erase("content");
|
||||
}
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
|
||||
auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
@@ -969,45 +964,31 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
"<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && has_tools;
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
const std::string END = "<|end|>";
|
||||
const std::string START = "<|start|>";
|
||||
const std::string MESSAGE = "<|message|>";
|
||||
const std::string CHANNEL = "<|channel|>";
|
||||
const std::string CONSTRAIN = "<|constrain|>";
|
||||
const std::string START_ASSISTANT = START + "assistant";
|
||||
const std::string CHANNEL_ANALYSIS = CHANNEL + "analysis";
|
||||
const std::string CHANNEL_COMMENTARY = CHANNEL + "commentary";
|
||||
const std::string CHANNEL_FINAL = CHANNEL + "final";
|
||||
auto start = p.rule("start", p.literal("<|start|>assistant"));
|
||||
auto end = p.rule("end", p.literal("<|end|>"));
|
||||
auto content = p.rule("message-content", p.until("<|end|>"));
|
||||
auto channel = p.literal("<|channel|>") + (p.literal("commentary") | p.literal("analysis"));
|
||||
auto constrain_type = p.chars("[A-Za-z0-9_-]", 1, -1);
|
||||
|
||||
auto the_end = END | p.end();
|
||||
auto analysis = p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
|
||||
auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end);
|
||||
auto final_msg = p.rule("final", p.literal("<|channel|>final<|message|>") + p.content(content));
|
||||
auto any = p.rule("any", preamble | analysis);
|
||||
|
||||
const std::string analysis_header = CHANNEL_ANALYSIS + MESSAGE;
|
||||
auto segment_content = p.until(END);
|
||||
auto analysis_segment = extract_reasoning ?
|
||||
p.literal(analysis_header) + p.reasoning(segment_content) + p.until(END) + the_end :
|
||||
p.content(analysis_header + p.until(END) + the_end);
|
||||
if (has_response_format) {
|
||||
auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
|
||||
auto response_format = p.rule("response-format",
|
||||
p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
|
||||
auto channel_header_content = p.until_one_of({ " to=functions.", MESSAGE });
|
||||
auto content_header = p.choice({ p.literal(CHANNEL_COMMENTARY), p.literal(CHANNEL_FINAL) });
|
||||
auto content_segment = p.rule("content-segment", content_header + channel_header_content + MESSAGE +
|
||||
p.content(segment_content) + the_end);
|
||||
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
auto final_header = p.literal(CHANNEL_FINAL);
|
||||
auto constraint = p.optional(p.space() + p.literal(CONSTRAIN) + channel_header_content);
|
||||
return p.optional(analysis_segment) + final_header + constraint + MESSAGE +
|
||||
p.content(p.schema(p.json(), "response-format", inputs.json_schema));
|
||||
return response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format);
|
||||
}
|
||||
|
||||
auto segment = p.optional(START_ASSISTANT + p.space()) + p.choice({ content_segment, analysis_segment });
|
||||
auto contents = p.optional(segment + p.repeat(p.optional(p.space()) + segment, 0, -1)) + p.end();
|
||||
|
||||
// Tool call parser
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
auto tool_choice = p.choice();
|
||||
|
||||
@@ -1016,42 +997,37 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
std::string name = function.at("name");
|
||||
const auto & params = function.at("parameters");
|
||||
|
||||
// Tool call can appear as:
|
||||
// 1. In role header: " to=functions.NAME<|channel|>..."
|
||||
// 2. In channel: "<|channel|>(analysis|commentary) to=functions.NAME..."
|
||||
auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name));
|
||||
|
||||
auto channel = p.literal(CHANNEL_COMMENTARY) | p.literal(CHANNEL_ANALYSIS);
|
||||
auto constraint = p.space() + p.optional(p.literal(CONSTRAIN) + channel_header_content);
|
||||
auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name));
|
||||
auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
|
||||
auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params));
|
||||
|
||||
// Pattern 1: recipient in role header
|
||||
// " to=functions.NAME<|channel|>(analysis|commentary)[constraint]<|message|>ARGS"
|
||||
auto tool_in_role = p.tool(p.tool_open(func_name + channel) + constraint + MESSAGE + args);
|
||||
// recipient in role header
|
||||
// <|start|>assistant to=functions.NAME<|channel|>(commentary|analysis)[constraint]<|message|>ARGS
|
||||
auto tool_in_role = p.tool(p.tool_open(func_name + channel + constraint + p.literal("<|message|>")) + args);
|
||||
|
||||
// Pattern 2: recipient in channel header
|
||||
// "<|channel|>(analysis|commentary) to=functions.NAME[constraint]<|message|>ARGS"
|
||||
auto tool_in_channel = p.tool(channel + p.tool_open(func_name + constraint + MESSAGE) + args);
|
||||
// recipient in channel header
|
||||
// <|channel|>(commentary|analysis) to=functions.NAME[constraint]<|message|>ARGS
|
||||
auto tool_in_channel = p.tool(p.tool_open(channel + func_name + constraint + p.literal("<|message|>")) + args);
|
||||
|
||||
tool_choice |= tool_in_role | tool_in_channel;
|
||||
tool_choice |= p.rule("tool-" + name, tool_in_role | tool_in_channel);
|
||||
});
|
||||
|
||||
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
|
||||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
auto tool_call = p.trigger_rule("tool-call", tool_choice);
|
||||
|
||||
auto role_start = p.optional(p.space() + p.literal(START_ASSISTANT));
|
||||
auto tool_call = p.rule("tool-call", p.repeat(role_start + tool_choice, min_calls, max_calls) + p.end());
|
||||
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
|
||||
return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call);
|
||||
}
|
||||
|
||||
return p.choice({ p.trigger_rule("single-tool", tool_call), p.trigger_rule("tools", p.one_or_more(segment) + tool_call) });
|
||||
return tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg));
|
||||
}
|
||||
|
||||
return contents;
|
||||
return final_msg | (any + p.zero_or_more(start + any) + start + final_msg);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
@@ -1062,10 +1038,9 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "(?:<\\|end\\|>)(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
"(?:<\\|start\\|>assistant\\s*)?(<\\|channel\\|>(?:commentary|analysis)\\s+to=functions)" }
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^\\s+to$" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(\\s+to)" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(<\\|channel\\|>(?:commentary|analysis)\\s+to)" }
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -1067,7 +1067,7 @@ common_init_result::common_init_result(common_params & params) :
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// load and optionally apply lora adapters (must be loaded before context creation)
|
||||
// load and optionally apply lora adapters
|
||||
for (auto & la : params.lora_adapters) {
|
||||
llama_adapter_lora_ptr lora;
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
|
||||
@@ -126,7 +126,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
||||
if (err == hipSuccess) {
|
||||
// hipMemAdviseSetCoarseGrain is an optional performance hint;
|
||||
// ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
|
||||
cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
|
||||
(void)cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
|
||||
(void)hipGetLastError(); // clear any error
|
||||
}
|
||||
|
||||
|
||||
@@ -2362,6 +2362,27 @@ static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs,
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
// CONT is just a contiguous copy — reuse CPY op
|
||||
req->op = HTP_OP_CPY;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_REPEAT;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_GET_ROWS;
|
||||
|
||||
@@ -2449,12 +2470,33 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
break;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
|
||||
switch (ggml_get_unary_op(t)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
req->op = HTP_OP_UNARY_SILU;
|
||||
supported = true;
|
||||
} else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) {
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU:
|
||||
req->op = HTP_OP_UNARY_GELU;
|
||||
supported = true;
|
||||
break;
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
req->op = HTP_OP_UNARY_SIGMOID;
|
||||
supported = true;
|
||||
break;
|
||||
case GGML_UNARY_OP_NEG:
|
||||
req->op = HTP_OP_UNARY_NEG;
|
||||
supported = true;
|
||||
break;
|
||||
case GGML_UNARY_OP_EXP:
|
||||
req->op = HTP_OP_UNARY_EXP;
|
||||
supported = true;
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
req->op = HTP_OP_UNARY_SOFTPLUS;
|
||||
supported = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -2640,16 +2682,28 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
|
||||
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
switch (ggml_get_glu_op(node)) {
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_SWIGLU_OAI:
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
@@ -2676,6 +2730,14 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_CONT:
|
||||
ggml_hexagon_dispatch_op<init_cont_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_REPEAT:
|
||||
ggml_hexagon_dispatch_op<init_repeat_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
|
||||
break;
|
||||
@@ -3006,6 +3068,39 @@ static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess,
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
GGML_UNUSED(sess);
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
|
||||
// CONT is same-type only, supports f32 and f16
|
||||
if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
GGML_UNUSED(sess);
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
// Support f32 and f16
|
||||
if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
|
||||
|
||||
// src and dst must be the same type
|
||||
if (src0->type != dst->type) return false;
|
||||
|
||||
// dst dims must be multiples of src dims
|
||||
if (dst->ne[0] % src0->ne[0] != 0) return false;
|
||||
if (dst->ne[1] % src0->ne[1] != 0) return false;
|
||||
if (dst->ne[2] % src0->ne[2] != 0) return false;
|
||||
if (dst->ne[3] % src0->ne[3] != 0) return false;
|
||||
|
||||
// require contiguous tensors (no transposition)
|
||||
if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(dev->context);
|
||||
|
||||
@@ -3063,21 +3158,32 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
break;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
{
|
||||
const auto unary_op = ggml_get_unary_op(op);
|
||||
if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) {
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
supp = ggml_hexagon_supported_activations(sess, op);
|
||||
}
|
||||
break;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
{
|
||||
const auto glu_op = ggml_get_glu_op(op);
|
||||
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
|
||||
switch (ggml_get_glu_op(op)) {
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_SWIGLU_OAI:
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
supp = ggml_hexagon_supported_activations(sess, op);
|
||||
}
|
||||
break;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
supp = ggml_hexagon_supported_rope(sess, op);
|
||||
break;
|
||||
@@ -3098,6 +3204,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_cpy(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_CONT:
|
||||
supp = ggml_hexagon_supported_cont(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_REPEAT:
|
||||
supp = ggml_hexagon_supported_repeat(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
supp = ggml_hexagon_supported_argsort(sess, op);
|
||||
break;
|
||||
|
||||
@@ -30,6 +30,7 @@ add_library(${HTP_LIB} SHARED
|
||||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
cpy-ops.c
|
||||
repeat-ops.c
|
||||
argsort-ops.c
|
||||
ssm-conv.c
|
||||
)
|
||||
|
||||
@@ -53,6 +53,10 @@ enum htp_op {
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_UNARY_SILU,
|
||||
HTP_OP_UNARY_GELU,
|
||||
HTP_OP_UNARY_SIGMOID,
|
||||
HTP_OP_UNARY_EXP,
|
||||
HTP_OP_UNARY_NEG,
|
||||
HTP_OP_UNARY_SOFTPLUS,
|
||||
HTP_OP_GLU_SWIGLU,
|
||||
HTP_OP_GLU_SWIGLU_OAI,
|
||||
HTP_OP_GLU_GEGLU,
|
||||
@@ -69,6 +73,7 @@ enum htp_op {
|
||||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
HTP_OP_SSM_CONV,
|
||||
HTP_OP_REPEAT,
|
||||
INVALID
|
||||
};
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
|
||||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_repeat(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
int op_ssm_conv(struct htp_ops_context * octx);
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <assert.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
#include "hvx-types.h"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hvx-floor.h"
|
||||
@@ -16,8 +17,8 @@
|
||||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x41a00000) // 20.0
|
||||
#define EXP_RANGE_L (0xc1a00000) // -20.0
|
||||
#define EXP_RANGE_R (0x42B16666) // 88.7
|
||||
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
||||
HVX_Vector z_qf32_v;
|
||||
@@ -47,12 +48,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
||||
|
||||
HVX_Vector temp_v = in_vec;
|
||||
|
||||
// Clamp inputs to (-20.0, 20.0)
|
||||
// Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow
|
||||
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
|
||||
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
|
||||
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
|
||||
|
||||
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
|
||||
epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
|
||||
@@ -69,12 +70,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
||||
// normalize before every QFloat's vmpy
|
||||
x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
|
||||
|
||||
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
|
||||
|
||||
// z = x * x;
|
||||
z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
|
||||
z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
|
||||
|
||||
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
|
||||
|
||||
// y = E4 + E5 * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_5);
|
||||
y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
|
||||
@@ -145,7 +146,7 @@ static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max
|
||||
return Q6_V_vmux_QVV(pred0, inf, out);
|
||||
}
|
||||
|
||||
static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
|
||||
static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
@@ -162,7 +163,7 @@ static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict
|
||||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.02f; // log(INF)
|
||||
static const float kMaxExp = 88.7f;
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define HVX_SIGMOID_H
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hvx-inverse.h"
|
||||
|
||||
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
||||
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
||||
|
||||
@@ -516,6 +516,39 @@ static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req,
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = op_repeat(&octx);
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
@@ -1090,6 +1123,10 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
|
||||
case HTP_OP_SQR:
|
||||
case HTP_OP_SQRT:
|
||||
case HTP_OP_UNARY_NEG:
|
||||
case HTP_OP_UNARY_EXP:
|
||||
case HTP_OP_UNARY_SIGMOID:
|
||||
case HTP_OP_UNARY_SOFTPLUS:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
@@ -1175,6 +1212,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
proc_cpy_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_REPEAT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad repeat-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_repeat_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_ARGSORT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad argsort-req buffer list");
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_repeat_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
uint32_t nr0;
|
||||
uint32_t nr1;
|
||||
uint32_t nr2;
|
||||
uint32_t nr3;
|
||||
|
||||
uint32_t nrows_per_thread;
|
||||
uint32_t total_dst_rows; // ne1 * ne2 * ne3
|
||||
|
||||
size_t type_size;
|
||||
};
|
||||
|
||||
static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data;
|
||||
struct htp_ops_context * octx = rctx->octx;
|
||||
const struct htp_tensor * src = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const uint32_t ne00 = src->ne[0];
|
||||
const uint32_t ne01 = src->ne[1];
|
||||
const uint32_t ne02 = src->ne[2];
|
||||
const uint32_t ne03 = src->ne[3];
|
||||
|
||||
const uint32_t nb00 = src->nb[0];
|
||||
const uint32_t nb01 = src->nb[1];
|
||||
const uint32_t nb02 = src->nb[2];
|
||||
const uint32_t nb03 = src->nb[3];
|
||||
|
||||
const uint32_t ne0 = dst->ne[0];
|
||||
const uint32_t ne1 = dst->ne[1];
|
||||
const uint32_t ne2 = dst->ne[2];
|
||||
const uint32_t ne3 = dst->ne[3];
|
||||
|
||||
const uint32_t nb0 = dst->nb[0];
|
||||
const uint32_t nb1 = dst->nb[1];
|
||||
const uint32_t nb2 = dst->nb[2];
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
const uint32_t nr0 = rctx->nr0;
|
||||
const uint32_t nr1 = rctx->nr1;
|
||||
const uint32_t nr2 = rctx->nr2;
|
||||
const uint32_t nr3 = rctx->nr3;
|
||||
|
||||
const size_t row_bytes = ne00 * rctx->type_size;
|
||||
|
||||
const uint32_t row_start = rctx->nrows_per_thread * ith;
|
||||
const uint32_t row_end = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows);
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) {
|
||||
// Decompose flat dst row index into (i1, i2, i3)
|
||||
const uint32_t i1 = dst_row % ne1;
|
||||
const uint32_t i2 = (dst_row / ne1) % ne2;
|
||||
const uint32_t i3 = dst_row / (ne1 * ne2);
|
||||
|
||||
// Map to source indices (tiling)
|
||||
const uint32_t k1 = i1 % ne01;
|
||||
const uint32_t k2 = i2 % ne02;
|
||||
const uint32_t k3 = i3 % ne03;
|
||||
|
||||
const uint8_t * src_row = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03;
|
||||
uint8_t * dst_base = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
|
||||
// Tile along dimension 0
|
||||
for (uint32_t i0 = 0; i0 < nr0; i0++) {
|
||||
uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0;
|
||||
memcpy(dst_ptr, src_row, row_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "repeat %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n",
|
||||
ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
int op_repeat(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
// Validate that dst dims are multiples of src dims
|
||||
if (dst->ne[0] % src0->ne[0] != 0 ||
|
||||
dst->ne[1] % src0->ne[1] != 0 ||
|
||||
dst->ne[2] % src0->ne[2] != 0 ||
|
||||
dst->ne[3] % src0->ne[3] != 0) {
|
||||
FARF(ERROR, "repeat: dst dims must be multiples of src dims\n");
|
||||
return HTP_STATUS_INVAL_PARAMS;
|
||||
}
|
||||
|
||||
size_t type_size;
|
||||
switch (src0->type) {
|
||||
case HTP_TYPE_F32: type_size = 4; break;
|
||||
case HTP_TYPE_F16: type_size = 2; break;
|
||||
default:
|
||||
FARF(ERROR, "repeat: unsupported type %u\n", src0->type);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows);
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
struct htp_repeat_context rctx = {
|
||||
.octx = octx,
|
||||
.nr0 = dst->ne[0] / src0->ne[0],
|
||||
.nr1 = dst->ne[1] / src0->ne[1],
|
||||
.nr2 = dst->ne[2] / src0->ne[2],
|
||||
.nr3 = dst->ne[3] / src0->ne[3],
|
||||
.nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads,
|
||||
.total_dst_rows = total_dst_rows,
|
||||
.type_size = type_size,
|
||||
};
|
||||
|
||||
FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n",
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3);
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
@@ -195,7 +195,7 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
|
||||
const float max) {
|
||||
hvx_sub_scalar_f32(spad, src, max, num_elems);
|
||||
|
||||
hvx_exp_f32(spad, dst, num_elems, false);
|
||||
hvx_exp_f32(dst, spad, num_elems, false);
|
||||
|
||||
float sum = hvx_reduce_sum_f32(dst, num_elems);
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-exp.h"
|
||||
#include "hvx-sigmoid.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
@@ -166,6 +168,75 @@ static void sqrt_f32(const float * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void neg_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
hvx_scale_f32_aa(dst_local, src_local, row_elems, -1.0f);
|
||||
}
|
||||
}
|
||||
|
||||
static void exp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
hvx_exp_f32(dst_local, src_local, row_elems, false);
|
||||
}
|
||||
}
|
||||
|
||||
static void sigmoid_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
hvx_sigmoid_f32_aa(dst_local, src_local, row_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static void softplus_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
// softplus(x) = log(1 + exp(x))
|
||||
// Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size));
|
||||
float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size));
|
||||
|
||||
for (uint32_t i = 0; i < row_elems; i++) {
|
||||
float x = src_f[i];
|
||||
// For x > 20: softplus(x) ≈ x (avoids exp overflow)
|
||||
dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
||||
struct htp_ops_context * octx = uctx->octx;
|
||||
@@ -247,6 +318,18 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_UNARY_NEG:
|
||||
neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_UNARY_EXP:
|
||||
exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_UNARY_SIGMOID:
|
||||
sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_UNARY_SOFTPLUS:
|
||||
softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -295,6 +378,18 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
case HTP_OP_SQRT:
|
||||
op_type = "sqrt-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_NEG:
|
||||
op_type = "neg-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_EXP:
|
||||
op_type = "exp-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_SIGMOID:
|
||||
op_type = "sigmoid-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_SOFTPLUS:
|
||||
op_type = "softplus-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
|
||||
@@ -7646,20 +7646,14 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
||||
return true;
|
||||
}
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
if (k < 2048) {
|
||||
if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
|
||||
// Intel Windows proprietary driver MMVQ performance is worse than fp16, see
|
||||
// https://github.com/ggml-org/llama.cpp/issues/17628
|
||||
return false;
|
||||
}
|
||||
|
||||
if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
|
||||
// Intel Windows proprietary driver tuning
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
if (k < 2048) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (src0_type) {
|
||||
|
||||
+1
-3
@@ -21,9 +21,7 @@ struct llama_sampler_deleter {
|
||||
};
|
||||
|
||||
struct llama_adapter_lora_deleter {
|
||||
void operator()(llama_adapter_lora *) {
|
||||
// llama_adapter_lora_free is deprecated
|
||||
}
|
||||
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
|
||||
|
||||
+2
-4
@@ -636,7 +636,6 @@ extern "C" {
|
||||
|
||||
// Load a LoRA adapter from file
|
||||
// The adapter is valid as long as the associated model is not freed
|
||||
// All adapters must be loaded before context creation
|
||||
LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init(
|
||||
struct llama_model * model,
|
||||
const char * path_lora);
|
||||
@@ -660,9 +659,8 @@ extern "C" {
|
||||
LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
|
||||
|
||||
// Manually free a LoRA adapter
|
||||
// NOTE: loaded adapters will be free when the associated model is deleted
|
||||
LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter),
|
||||
"adapters are now freed together with the associated model");
|
||||
// NOTE: loaded adapters that are not manually freed will be freed when the associated model is deleted
|
||||
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
|
||||
|
||||
// Get the invocation tokens if the current lora is an alora
|
||||
LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
|
||||
|
||||
+12
-3
@@ -418,7 +418,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
}
|
||||
|
||||
llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
|
||||
llama_adapter_lora * adapter = new llama_adapter_lora();
|
||||
llama_adapter_lora * adapter = new llama_adapter_lora(model);
|
||||
|
||||
try {
|
||||
llama_adapter_lora_init_impl(*model, path_lora, *adapter);
|
||||
@@ -471,8 +471,17 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
|
||||
return snprintf(buf, buf_size, "%s", it->second.c_str());
|
||||
}
|
||||
|
||||
void llama_adapter_lora_free(llama_adapter_lora *) {
|
||||
// deprecated: adapters are freed by llama_model's destructor
|
||||
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
|
||||
if (adapter == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (adapter->model != nullptr) {
|
||||
adapter->model->loras.erase(adapter);
|
||||
adapter->model = nullptr;
|
||||
}
|
||||
|
||||
delete adapter;
|
||||
}
|
||||
|
||||
uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {
|
||||
|
||||
+3
-1
@@ -61,6 +61,8 @@ struct llama_adapter_lora_weight {
|
||||
};
|
||||
|
||||
struct llama_adapter_lora {
|
||||
llama_model * model = nullptr;
|
||||
|
||||
// map tensor name to lora_a_b
|
||||
std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
|
||||
|
||||
@@ -75,7 +77,7 @@ struct llama_adapter_lora {
|
||||
// activated lora (aLoRA)
|
||||
std::vector<llama_token> alora_invocation_tokens;
|
||||
|
||||
llama_adapter_lora() = default;
|
||||
explicit llama_adapter_lora(llama_model * model) : model(model) {}
|
||||
~llama_adapter_lora() = default;
|
||||
|
||||
llama_adapter_lora_weight * get_weight(ggml_tensor * w);
|
||||
|
||||
@@ -1165,9 +1165,11 @@ bool llama_context::set_adapter_cvec(
|
||||
int32_t il_end) {
|
||||
LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
|
||||
|
||||
// TODO: should we reserve?
|
||||
bool res = cvec->apply(model, data, len, n_embd, il_start, il_end);
|
||||
|
||||
return cvec->apply(model, data, len, n_embd, il_start, il_end);
|
||||
sched_need_reserve = true;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
||||
|
||||
@@ -89,6 +89,7 @@ struct test_context {
|
||||
cparams.n_batch = 512;
|
||||
cparams.samplers = configs.data();
|
||||
cparams.n_samplers = configs.size();
|
||||
cparams.kv_unified = true;
|
||||
|
||||
// If n_seq_max is not specified, calculate it from configs
|
||||
if (n_seq_max < 0) {
|
||||
|
||||
+14
-46
@@ -2448,7 +2448,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
|
||||
// Analysis channel (reasoning) with final channel (content)
|
||||
tst.test(
|
||||
"<|channel|>analysis<|message|>I'm\nthinking<|end|>\n<|channel|>final<|message|>Hello, world!\nWhat's "
|
||||
"<|channel|>analysis<|message|>I'm\nthinking<|end|><|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's "
|
||||
"up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist_thoughts)
|
||||
@@ -2461,15 +2461,6 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_reasoning("I'm\nthinking")
|
||||
.run();
|
||||
|
||||
// Reasoning format none - reasoning stays in content
|
||||
tst.test(
|
||||
"<|channel|>analysis<|message|>I'm\nthinking<|end|>\n<|channel|>final<|message|>Hello, world!\nWhat's "
|
||||
"up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_NONE)
|
||||
.expect_content(
|
||||
"<|channel|>analysis<|message|>I'm\nthinking<|end|>Hello, world!\nWhat's up?")
|
||||
.run();
|
||||
|
||||
// Tool call with recipient in role header: " to=functions.NAME<|channel|>analysis<|message|>JSON"
|
||||
tst.test(" to=functions.special_function<|channel|>analysis<|message|>{\"arg1\": 1}")
|
||||
.tools({ special_function_tool })
|
||||
@@ -2496,37 +2487,16 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
|
||||
// Tool call with reasoning + content (analysis first, then tool call)
|
||||
tst.test(
|
||||
"<|channel|>analysis<|message|>I'm\nthinking<|end|>\n"
|
||||
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
|
||||
"<|start|>assistant to=functions.special_function<|channel|>analysis<|message|>{\"arg1\": 1}")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_thoughts)
|
||||
.run();
|
||||
|
||||
// Tool calling with extra channel before
|
||||
// Complex tool calling
|
||||
tst.test(
|
||||
"<|channel|>analysis<|message|>I'm\nthinking<|end|><|start|>assistant<|channel|>commentary"
|
||||
" to=functions.special_function <|message|>{\"arg1\": 1}")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_thoughts)
|
||||
.run();
|
||||
|
||||
// Reasoning after final channel
|
||||
// Tool calling after final channel
|
||||
tst.test(
|
||||
"<|channel|>final<|message|><|end|>"
|
||||
"<|start|>assistant<|channel|>analysis<|message|>Thinking about edit..."
|
||||
)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect_reasoning("Thinking about edit...")
|
||||
.expect_content("")
|
||||
.run();
|
||||
|
||||
// Tool calling after final channel
|
||||
tst.test(
|
||||
"<|channel|>final<|message|><|end|>"
|
||||
"<|start|>assistant<|channel|>analysis<|message|>Thinking about edit...<|end|>"
|
||||
"<|channel|>analysis<|message|>Thinking about edit...<|end|>"
|
||||
"<|start|>assistant<|channel|>commentary to=functions.edit <|constrain|>json"
|
||||
"<|message|>{\"oldString\": \"if (part < railCount - 1) {\", \"newString\": \"if (part < 4) {\", \"replaceAll\": false}"
|
||||
)
|
||||
@@ -2561,19 +2531,17 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
})
|
||||
.run();
|
||||
|
||||
// Parallel tool calls
|
||||
// Structured output
|
||||
tst.test(
|
||||
" to=functions.special_function<|channel|>analysis<|message|>{\"arg1\": 1}\n"
|
||||
"<|start|>assistant to=functions.special_function_with_opt<|channel|>analysis<|message|>{\"arg1\": 1, "
|
||||
"\"arg2\": 2}")
|
||||
.parallel_tool_calls(true)
|
||||
.tools({
|
||||
special_function_tool, special_function_tool_with_optional_param
|
||||
})
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
|
||||
})
|
||||
"<|channel|>analysis<|message|>I need to output the invoice details in JSON<|end|>"
|
||||
"<|start|>assistant<|channel|>final <|constrain|>json"
|
||||
"<|message|>"
|
||||
R"({"amount": 123.45, "date": "2025-12-03"})"
|
||||
)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.json_schema(invoice_schema)
|
||||
.expect_reasoning("I need to output the invoice details in JSON")
|
||||
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
|
||||
.run();
|
||||
}
|
||||
|
||||
|
||||
+21
-5
@@ -1897,8 +1897,9 @@ import sys
|
||||
from datetime import datetime
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
tmpl = json.loads(sys.argv[1])
|
||||
vars_json = json.loads(sys.argv[2])
|
||||
merged_input = json.loads(sys.stdin.buffer.read().decode("utf-8"))
|
||||
tmpl = merged_input["tmpl"]
|
||||
vars_json = merged_input["vars"]
|
||||
|
||||
env = SandboxedEnvironment(
|
||||
trim_blocks=True,
|
||||
@@ -1921,8 +1922,9 @@ sys.stdout.buffer.write(result.encode())
|
||||
static void test_template_py(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
|
||||
t.test(name, [&tmpl, &vars, &expect](testing & t) {
|
||||
// Prepare arguments
|
||||
std::string tmpl_json = json(tmpl).dump();
|
||||
std::string vars_json = vars.dump();
|
||||
json merged;
|
||||
merged["tmpl"] = json(tmpl);
|
||||
merged["vars"] = vars;
|
||||
|
||||
#ifdef _WIN32
|
||||
const char * python_executable = "python.exe";
|
||||
@@ -1930,7 +1932,7 @@ static void test_template_py(testing & t, const std::string & name, const std::s
|
||||
const char * python_executable = "python3";
|
||||
#endif
|
||||
|
||||
const char * command_line[] = {python_executable, "-c", py_script.c_str(), tmpl_json.c_str(), vars_json.c_str(), NULL};
|
||||
const char * command_line[] = {python_executable, "-c", py_script.c_str(), NULL};
|
||||
|
||||
struct subprocess_s subprocess;
|
||||
int options = subprocess_option_combined_stdout_stderr
|
||||
@@ -1944,6 +1946,20 @@ static void test_template_py(testing & t, const std::string & name, const std::s
|
||||
t.assert_true("subprocess creation", false);
|
||||
return;
|
||||
}
|
||||
FILE * p_stdin = subprocess_stdin(&subprocess);
|
||||
|
||||
// Write input
|
||||
std::string input = merged.dump();
|
||||
auto written = fwrite(input.c_str(), 1, input.size(), p_stdin);
|
||||
if (written != input.size()) {
|
||||
t.log("Failed to write complete input to subprocess stdin");
|
||||
t.assert_true("subprocess stdin write", false);
|
||||
subprocess_destroy(&subprocess);
|
||||
return;
|
||||
}
|
||||
fflush(p_stdin);
|
||||
fclose(p_stdin); // Close stdin to signal EOF to the Python process
|
||||
subprocess.stdin_file = nullptr;
|
||||
|
||||
// Read output
|
||||
std::string output;
|
||||
|
||||
Binary file not shown.
@@ -57,7 +57,6 @@
|
||||
// Handle ?q= parameter - create new conversation and send message
|
||||
if (qParam !== null) {
|
||||
await conversationsStore.createConversation();
|
||||
await chatStore.sendMessage(qParam);
|
||||
clearUrlParams();
|
||||
} else if (modelParam || newChatParam === 'true') {
|
||||
clearUrlParams();
|
||||
|
||||
Reference in New Issue
Block a user