Compare commits

...

10 Commits

Author SHA1 Message Date
Piotr Wilkin (ilintar) 09cedfd699 chat: harden caps check (#24973) 2026-06-25 02:49:22 +02:00
Max Krasnyansky 8be759e6f7 hexagon: MUL_MAT and MUL_MAT_ID rework : 32x32 tiled weight repack, kernel-params, cached graphs (#24954)
* hex-mm: new weight layout and fusion updates

* hvx-mm: unroll the new tiled vec_dots to optimize hvx register util

* hex-mm: optimize dyn.quant format for q8_0 and q8_1 to reduce overhead in vec_dots.

* hvx-mm: parallel quantizer per block for large rows

* hvx-mm: simplify and futher optimize dyn.quant and vec_dots

* hvx-mm: keep intermediate per tile accumulators in fp16

* hmx-mm: optimize weight dequant by aligning the repacked tiles with the DMA

* hmx-mm: remove qweight scratch and just use vtcm_weight

* hmx-mm: remove all unused and obsolete code

* hmx-mm: the new tiled repack format is here to stay -- rename all x4x2 to _tiled

* hmx-mm: improve activation processing with dma prefetch

* hex-mm: fix hmx/hvx fallback logic and MUL_MAT_ID allocation (unbreaks OLMoE)

* hex-mm: align the weight tiles with dma just like we did in hmx-mm

* hex-mm: factor out common mm bits into htp/matmul-ops.h

* hex-mm: start moving mm kernel selection to the host

* hex-mm: move all of the matmul param compute into the host

* hmx-mm: restore pipelined mode

* hmx-mm: unroll the dequant functions to optimize register usage

* hmx-mm: further improve activation process

* hex-mm: use vtcm_seq_alloc for all vtcm allocations and define more common functions

* hex-mm: improve mm optimizer to acount for number of activation threads

* hex-mm: fix matmul-id kernel params selection (unbreaks OLMoE and LFM)

* hexagon: remove support for arch < v73 since HMX is now required for most use-cases

* hex-mm: cleanup naming for consistency

* hex-mm: make sure matmul fusion accounts for vtcm allocation

* hex-mm: minor cleanup for kernel_params definition

* hex-mm: replace hardcoded limits with proper checks for vtcm requirements

* hex-mm: add support for non-tiled mm as a fallback option and factor out hvx kernels into separate header

* hex-mm: remove unused functions

* hex-mm: add shorthand for MM_SELECT in run-tool script

* hvx-mm: factor out hvx/hmx microkernels and unify matmul entry and dispatch

* hex-mm: further cleanup matmul fallback path

* hex-mm: refactor matmul entry point and dispatch a bit further

* hexagon: update cmake build to enable hmx for everything

* hex-ops: optimize kernel_param updates and include summary in the logs

* hex-mm: add support for GGML_HEXAGON_MM_SELECT

* hex-mm: add hex-common header

* hex-mm: pass correct number of tasks to workpool

* hex-mm: add proper checks for no-work in dyn.quant tasks

* hex-mm: convert all quantizers into a macro

* hex-mm: fix hvx-flat fallback to pass all MUL_MAT tests

* hex-mm: vectorize q8_1 quantizer

* hex-mm: improve fused ffn mm stride handling

* hex-mm: consistent use of n_threads and pipeline in kernel_params

* hexagon: minor formatting

* hex-mm: update MUL_MAT_ID kernel_param handling to make sure host/npu are in sync

* hvx-mm: go back to accumulating in fp32 in tiled hvx kernels, more accurate and same perf

* hvx-mm: unroll the loops and remove masking that is not needed for tiled accums

* hmx-mm: optimize activation processing (slit loops, some unrolling, etc)

* hmx-mm: minor optimization for output processing

* hex-mm: consistent use of uint32_t and size_t in mm kernels

* hex-mm: remove legacy restrictions for rows to be multiple of 256

* hexagon: replace sprintf with snprintf

* hex-mm: relax hardcoded nrows checks and rely on VTCM size requirements

* hexagon: minor alignment fix

* hexagon: fix trailing spaces

* hex-mm: relax padding from 256 to 128 (leftovers)

* hex-mm: remove redundant checks for weight align to 128

we always use 2D dma for the weights and align them properly

* hmx-mm: MUL_MAT_ID better work distribution between hvx threads and hmx tracing

* hex-mm: specialize per-token mmid activation handling

* hex-profile: update python scripts to handle kernel-params section in the logging output

* hex-mm: move n_prefetch (aka dma_depth) into kernel params and remove unused fields

* hex-trace: use easier to parse format, simply and fix post-proc scripts

* hmx-mm: relax 32 row limit for output processing which helps utilization

* hmx-mm: use start-chunk idx for tracing info

* hmx-mm: parameterize activation dma pipeline

* hexagon: add support for simple graph caching to avoid recomputing kernel-params

* hex-mm: remove left-over repack functions

* hex-mm: tighten n_prefetch asserts

* hex-mm: remove duplicate round/align_up helper

* hexagon: cleanup common header used in host/npu

* hexagon: update early wakeup threshold

* hmx-mm: define cost constants and update solver to assume that repacked ne[1] is padded to 32

* hmx-mm: make precompute_matmul a bit more readable (split into smaller functions, etc)

* hex-mm: remove n_threads constraint

* hex-mm: minor formatting updates

* hex-mm: remove obsolete profiling logs

* hex-mm: restore hardcode gate to refuse lm-head to avoid repacking that tensor
2026-06-24 12:14:25 -07:00
Saba Fallah 894bb27af3 mtmd: model: unlimited-ocr: converter + parity test (#24969) 2026-06-24 18:20:22 +02:00
Xuan-Son Nguyen fb401045cc common: remove unused json-partial (#24968) 2026-06-24 18:12:16 +02:00
Wagner Bruna 51eae8cfca vulkan: allow reducing the graph submission batches to avoid timeouts (#24872) 2026-06-24 16:29:24 +02:00
liminfei-amd 1191758c5d vulkan: fail the build when a shader fails to compile (#24450)
* vulkan-shaders-gen: fail the build when a shader fails to compile

vulkan-shaders-gen did not detect shader-compile subprocess failures, so a
broken libggml-vulkan could be produced while the build reported success and
the breakage only surfaced at run time. execute_command() discarded the child
exit code (POSIX waitpid passed nullptr for status; the Windows branch never
called GetExitCodeProcess) and string_to_spv decided success only from whether
stderr was empty, so a non-zero exit with empty stderr, or a subprocess that
failed to launch, was treated as success.

Return the child exit code from execute_command() (WEXITSTATUS on POSIX,
GetExitCodeProcess on Windows), treat a non-zero exit or non-empty stderr or a
launch exception as a failure, and record it in an atomic flag. main() checks
the flag after process_shaders() and returns EXIT_FAILURE before writing the
output files, so the build stops instead of emitting a broken backend.

Fixes #24393

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>

* vulkan-shaders-gen: simplify compile_failed access and drop unreachable return

Address review feedback on #24450:
- Access the std::atomic<bool> compile_failed directly (= / implicit bool)
  instead of .store()/.load(); the flag stays atomic because the worker
  threads in process_shaders() set it concurrently.
- Remove the unreachable trailing return -1 in execute_command(): on POSIX the
  child _exit()s after execvp and the parent returns (fork()<0 throws); on
  Windows the block returns the exit code.

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>

---------

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
2026-06-24 11:42:03 +02:00
Pascal 00139b660b ui: loading bar below the model picker (#24931)
* ui: show model load progress on the selector trigger

Mirror the in-dropdown stage progress as a thin bar on the selector
trigger, so the active model's load percent stays visible when the menu
is closed. Same status gating and composite fraction as the dropdown
row, so both bars track the selected model in sync.

Suggested-by: Julien Chaumond <@julien-c>

* ui: show model load progress bar on the in-conversation model selector

* ui: tune model load indicator to a pulsing highlight (suggested by @ngxson)

Also wire the indicator onto the mobile sheet trigger, which was missing
it since mobile uses the sheet instead of the dropdown.

* ui: thin (@allozaur) pulsating (@ngxson) model load bar
2026-06-24 10:50:44 +02:00
Aleksander Grygier ef9c13d4c2 ui: New Logo + Navigation cleanup & Mobile UI/UX improvements (#24897)
* chore: `npm audit fix --force`

* feat: Update sidebar toggle to use Logo

* refactor: Clean up favicon SVG

* feat: Refactor logo component and implement theme-aware favicon generation

* feat: Add configurable padding to generated PWA assets

* test: Add unit tests for writeThemeFavicons

* refactor: Componentization

* feat: WIP

* feat: WIP

* feat: WIP

* feat: Mobile UI

* feat: add SEARCH route constant

* feat: create SidebarNavigationSearchResults component

* refactor: use SidebarNavigationSearchResults in conversation list

* feat: enable mobile search navigation in sidebar actions

* feat: add mobile search route and page

* fix: prevent sidebar overflow on mobile viewports

* fix: Mobile sidebar

* feat: Mobile Search WIP

* feat: Mobile WIP

* feat: Add PWA standalone detection and refine mobile UI

* feat: Improve mobile layout, sidebar handling, and chat scrolling

* feat: Improve mobile sidebar visibility and iOS Safari chat spacing

* fix: Disable auto-scroll on mobile

* chore: Linting

* fix: Wrong condition

* feat: Mobile chat scroll

* refactor: WIP

* fix: Desktop initial scroll always working again

* fix: Partial fix for mobile auto-scroll / initial scroll

* fix: Desktop auto-scroll on initial load and during streaming

* fix: Mobile scrolling logic

* refactor: Clean up

* feat: Improve start UI

* feat: Add `delay` to `fadeInView`

* feat: Auto-scroll button

* refactor: Cleanup

* refactor: Extract chat dialogs and alerts into dedicated component

* refactor: Reorganize ChatScreen component structure and initialization

* feat: Improve auto-scroll after sending message

* feat: UI improvements

* fix: Settings link

* feat: UI improvements

* fix: better UI spacing

* fix: Remove unneeded logic

* fix: Chat Processing Info UI rendering

* feat: Improve mobile UI

* feat: UI improvement

* fix: Conditional transition delay for Chat Messages based on route from

* fix: Delay mobile sidebar collapse for smoother transitions

* fix: Mobile scroll down button + sidebar pointer events

* fix: Mobile UI

* fix: Auto scrolling

* fix: Implement dynamic height calculations for chat auto-scroll positioning and UI elements

* fix: Retrieve `autofocus` for Chat Form textarea

* fix: Use proper class

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* refactor: extract scroll-to-bottom logic and fix message send flow

* fix: update viewport store usage and remove conflicting autofocus

* feat: add accessibility labels to scroll down button

* fix: correct HTML structure in sidebar empty states

* fix: dynamically toggle processing info visibility

* chore: remove commented exports and fix formatting

* fix

* fix: Mobile Chat Form Add Action Sheet interactions

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-24 10:21:33 +02:00
Tarek Dakhran 88636e178f model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M (#24913)
* model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M

* Restore LFM2 models in README.md
2026-06-24 09:49:46 +03:00
Jeff Bolz ac4105d68b vulkan: Apply bias before softmax in FA, to avoid overflow (#24909) 2026-06-23 22:34:00 -05:00
140 changed files with 10894 additions and 10674 deletions
+3 -1
View File
@@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
-2
View File
@@ -80,8 +80,6 @@ add_library(${TARGET}
http.h
imatrix-loader.cpp
imatrix-loader.h
json-partial.cpp
json-partial.h
json-schema-to-grammar.cpp
llguidance.cpp
log.cpp
+4
View File
@@ -2758,5 +2758,9 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
GGML_ASSERT(chat_templates != nullptr);
GGML_ASSERT(chat_templates->template_default != nullptr);
if (chat_templates->template_tool_use != nullptr) {
// take the more expressive template when available
return chat_templates->template_tool_use->caps.to_map();
}
return chat_templates->template_default->caps.to_map();
}
-324
View File
@@ -1,324 +0,0 @@
#include "json-partial.h"
#include "log.h"
#include <nlohmann/json.hpp>
#include <string>
#include <regex>
using json = nlohmann::ordered_json;
enum common_json_stack_element_type {
COMMON_JSON_STACK_ELEMENT_OBJECT,
COMMON_JSON_STACK_ELEMENT_KEY,
COMMON_JSON_STACK_ELEMENT_ARRAY,
};
struct common_json_stack_element {
common_json_stack_element_type type;
std::string key;
};
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out)
{
std::string::const_iterator it = input.begin();
const auto end = input.end();
return common_json_parse(it, end, healing_marker, out);
}
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out)
{
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
std::size_t position;
bool found_error;
std::string last_token;
std::string exception_message;
std::vector<common_json_stack_element> stack;
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
this->position = position - 1;
this->found_error = true;
this->last_token = last_token;
this->exception_message = ex.what();
return false;
}
void close_value() {
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
stack.pop_back();
}
}
bool null() override { // NOLINT
close_value();
return true;
}
bool boolean(bool) override { // NOLINT
close_value();
return true;
}
bool number_integer(number_integer_t) override { // NOLINT
close_value();
return true;
}
bool number_unsigned(number_unsigned_t) override { // NOLINT
close_value();
return true;
}
bool number_float(number_float_t, const string_t &) override { // NOLINT
close_value();
return true;
}
bool string(string_t &) override { // NOLINT
close_value();
return true;
}
bool binary(binary_t &) override { // NOLINT
close_value();
return true;
}
bool start_object(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
return true;
}
bool end_object() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
stack.pop_back();
close_value();
return true;
}
bool key(string_t & key) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
return true;
}
bool start_array(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
return true;
}
bool end_array() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
stack.pop_back();
close_value();
return true;
}
};
json_error_locator err_loc;
auto start = it;
json::sax_parse(it, end, &err_loc);
if (err_loc.found_error) {
it = start;
auto temptative_end = it + err_loc.position;
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
auto input = std::string(it, temptative_end);
try {
out.json = json::parse(input);
// out.json = json::parse(it, temptative_end);
it = temptative_end;
return true;
} catch (const std::exception & ex) {
// No, needs healing.
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
}
auto can_parse = [](const std::string & str) {
try {
auto _ = json::parse(str); // NOLINT
return true;
} catch (const std::exception &) {
return false;
}
};
if (!healing_marker.empty() && !err_loc.stack.empty()) {
std::string str(it, temptative_end);
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
if (last_non_sp_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
auto last_non_sp_char = str[last_non_sp_pos];
// Used to detect stops on a number, which may not be complete.
auto was_maybe_number = [&]() {
if (!str.empty() && std::isspace(str.back())) {
return false;
}
return std::isdigit(last_non_sp_char) ||
last_non_sp_char == '.' ||
last_non_sp_char == 'e' ||
last_non_sp_char == 'E' ||
last_non_sp_char == '-';
};
std::string closing;
for (size_t i = err_loc.stack.size(); i > 0; i--) {
auto & el = err_loc.stack[i - 1];
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
closing += "}";
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
closing += "]";
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
throw std::runtime_error("Unexpected stack element type");
}
}
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
auto is_high_surrogate = [&](const std::string & s) {
// Check if a partial of a high surrogate (U+D800-U+DBFF)
return s.length() >= 4 &&
s[0] == '\\' && s[1] == 'u' &&
std::tolower(s[2]) == 'd' &&
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
};
// Initialize the unicode marker to a low surrogate to handle the edge case
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
// backslash (\)
std::string unicode_marker_padding = "udc00";
std::smatch last_unicode_seq;
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
std::smatch second_last_seq;
std::string prelude = str.substr(0, last_unicode_seq.position());
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
if (is_high_surrogate(last_unicode_seq.str())) {
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
unicode_marker_padding += "\\udc00";
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
if (is_high_surrogate(second_last_seq.str())) {
// If this follows a high surrogate, pad it to be a low surrogate
if (last_unicode_seq.length() == 2) {
unicode_marker_padding = "dc00";
} else if (last_unicode_seq.length() == 3) {
unicode_marker_padding = "c00";
} else {
// The original unicode_marker_padding is already padded with 0s
}
}
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
// We're inside an object value
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
// Was about to create an object value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + ": 1" + closing)) {
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
// Was about to create an object
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an object value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an object value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
// Cutting back to opening : for object value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
// Was about to create an array value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an array value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an array value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
} else {
auto last_pos = str.find_last_of("[,");
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
}
// Cutting back to last [ or , for array value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\": 1" + closing)) {
// Was inside an object key string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
// Was inside an object key string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "Cutting back to last : for object key+value\n");
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
out.json = json::parse(str);
it = temptative_end;
return true;
}
// handle unclosed top-level primitive
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
std::string str(it, temptative_end);
const auto & magic_seed = out.healing_marker.marker = healing_marker;
if (can_parse(str + "\"")) {
// Was inside an string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
// Was inside an string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
} else {
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
// fprintf(stderr, "Closing: TODO\n");
return false;
}
out.json = json::parse(str);
it = temptative_end;
return true;
}
return false;
}
out.json = json::parse(it, end);
it = end;
return true;
}
-39
View File
@@ -1,39 +0,0 @@
#pragma once
// TODO: use json_fwd.hpp when possible
#include <nlohmann/json.hpp>
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
struct common_healing_marker {
// Raw marker.
std::string marker;
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
std::string json_dump_marker;
};
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
struct common_json {
nlohmann::ordered_json json;
common_healing_marker healing_marker;
};
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
//
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
//
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out);
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out);
+4
View File
@@ -46,6 +46,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"DbrxForCausalLM": "dbrx",
"DeciLMForCausalLM": "deci",
"DeepseekForCausalLM": "deepseek",
"DeepseekOCRForCausalLM": "deepseek",
"DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
@@ -124,6 +125,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"LLaDAModelLM": "llada",
"LLaMAForCausalLM": "llama",
"Lfm25AudioTokenizer": "lfm2",
"Lfm2BidirectionalModel": "lfm2",
"Lfm2ForCausalLM": "lfm2",
"Lfm2Model": "lfm2",
"Lfm2MoeForCausalLM": "lfm2",
@@ -232,6 +234,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"UMT5ForConditionalGeneration": "t5",
"UMT5Model": "t5",
"UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VLlama3ForCausalLM": "llama",
"VoxtralForConditionalGeneration": "llama",
"WavTokenizerDec": "wavtokenizer",
@@ -298,6 +301,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"StepVLForConditionalGeneration": "step3",
"Step3p7ForConditionalGeneration": "step3",
"UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VoxtralForConditionalGeneration": "ultravox",
"YoutuVLForConditionalGeneration": "youtuvl",
}
+10 -2
View File
@@ -14,7 +14,7 @@ from .base import MmprojModel, ModelBase, TextModel, gguf, logger
from .qwen import QwenModel
@ModelBase.register("DeepseekOCRForCausalLM")
@ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM")
class DeepseekOCRVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -205,6 +205,8 @@ class DeepseekModel(TextModel):
@ModelBase.register(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekOCRForCausalLM",
"UnlimitedOCRForCausalLM",
"KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration",
"YoutuForCausalLM",
@@ -224,7 +226,7 @@ class DeepseekV2Model(TextModel):
self.origin_hf_arch = hparams.get('architectures', [None])[0]
# special handling for Deepseek OCR
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"):
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"):
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
self.gguf_writer.add_architecture()
@@ -350,6 +352,12 @@ class DeepseekV2Model(TextModel):
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
# Unlimited-OCR sliding window; written for metadata, the decoder ignores it (full MHA)
if is_ocr:
sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window")
if sliding_window:
self.gguf_writer.add_sliding_window(sliding_window)
if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
+10 -3
View File
@@ -64,11 +64,17 @@ class LFM2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Lfm2Model")
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
class LFM2ColBertModel(LFM2Model):
model_arch = gguf.MODEL_ARCH.LFM2
dense_tensor_name = "dense_2"
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hf_arch == "Lfm2BidirectionalModel":
self.gguf_writer.add_causal_attention(False)
self._try_set_pooling_type()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith(self.dense_tensor_name):
name = "model." + name
@@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model):
yield from super().modify_tensors(data_torch, name, bid)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# dense tensor is stored in a separate safetensors file
# optional dense tensor is stored in a separate safetensors file
from safetensors.torch import load_file
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
assert tensors_file.is_file()
if not tensors_file.is_file():
return
tensor = load_file(tensors_file)["linear.weight"]
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
yield f"{self.dense_tensor_name}.weight", tensor.clone()
@@ -24,7 +24,6 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_OPENSSL": "OFF"
}
},
@@ -47,7 +46,6 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_OPENSSL": "OFF"
}
},
@@ -73,7 +71,6 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "OFF",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_OPENSSL": "OFF"
}
},
-1
View File
@@ -266,7 +266,6 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
"ggml: OpenCL API version to target")
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
# toolchain for vulkan-shaders-gen
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
-4
View File
@@ -25,7 +25,6 @@ include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
add_library(htp_iface OBJECT
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
@@ -72,15 +71,12 @@ function(build_htp_skel V)
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
-DDSP_VERSION=${V}
-DPREBUILT_LIB_DIR="toolv19_${V}")
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
endfunction()
build_htp_skel(v68)
build_htp_skel(v69)
build_htp_skel(v73)
build_htp_skel(v75)
build_htp_skel(v79)
File diff suppressed because it is too large Load Diff
+162 -56
View File
@@ -5,10 +5,12 @@
#include "ggml-backend-impl.h"
#include "ggml-common.h"
#include <algorithm>
#include <string>
#include <vector>
#include <stdio.h>
#include "htp-ops.h"
#include "htp/matmul-ops.h"
struct htp_opnode {
ggml_tensor * node = nullptr;
@@ -17,6 +19,13 @@ struct htp_opnode {
htp_op_code opcode = HTP_OP_INVALID;
std::vector<ggml_tensor *> extra_dsts;
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS] = {0};
htp_opnode(ggml_tensor * node = nullptr, std::vector<ggml_tensor *> fused = {}, htp_op_code opcode = HTP_OP_INVALID, std::vector<ggml_tensor *> extra_dsts = {})
: node(node), fused(std::move(fused)), opcode(opcode), extra_dsts(std::move(extra_dsts)) {}
ggml_op op() const {
return node->op;
}
@@ -25,6 +34,26 @@ struct htp_opnode {
return fused.empty() ? node : fused.back();
}
void add_fused(ggml_tensor * t, bool extra_dst = false) {
fused.push_back(t);
if (extra_dst) {
extra_dsts.push_back(t);
}
}
std::vector<const ggml_tensor *> get_outputs() const {
std::vector<const ggml_tensor *> res;
if (extra_dsts.empty()) {
res.push_back(dst());
} else {
res.push_back(node);
for (const auto * x : extra_dsts) {
res.push_back(x);
}
}
return res;
}
const ggml_tensor * src0() const {
return node->src[0];
}
@@ -37,10 +66,6 @@ struct htp_opnode {
return ggml_op_is_empty(node->op);
}
void add_fused(ggml_tensor * t) {
fused.push_back(t);
}
bool stackable() const {
switch (this->op()) {
case GGML_OP_MUL_MAT:
@@ -131,87 +156,117 @@ struct htp_opformat {
char types[16 * GGML_MAX_SRC];
char buffs[64 * GGML_MAX_SRC];
char names[64 * GGML_MAX_SRC];
char kparams[128];
int format_tensor_dims(char * str, const struct ggml_tensor * t) {
int format_tensor_dims(char * str, size_t max_size, const struct ggml_tensor * t) {
if (!t) {
return sprintf(str, "NONE");
return snprintf(str, max_size, "NONE");
}
if (t->ne[2] == 1 && t->ne[3] == 1) {
return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
return snprintf(str, max_size, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
} else {
return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
return snprintf(str, max_size, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
}
}
void format_op_dims(char * str, const htp_opnode & node) {
void format_op_dims(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += format_tensor_dims(p, inputs[0]);
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[0]), (size_t)(p_end - p));
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += format_tensor_dims(p, inputs[i]);
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[i]), (size_t)(p_end - p));
}
}
p += sprintf(p, " -> ");
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
char self[64];
format_tensor_dims(self, node.dst());
p += sprintf(p, "%s", self);
format_tensor_dims(self, sizeof(self), node.dst());
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
}
}
int format_tensor_strides(char * str, const struct ggml_tensor * t) {
int format_tensor_strides(char * str, size_t max_size, const struct ggml_tensor * t) {
if (!t) {
return sprintf(str, "NONE");
return snprintf(str, max_size, "NONE");
}
const char * c = ggml_is_contiguous(t) ? "" : "!";
if (t->ne[2] == 1 && t->ne[3] == 1) {
return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
return snprintf(str, max_size, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
} else {
return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
return snprintf(str, max_size, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
}
}
void format_op_strides(char * str, const htp_opnode & node) {
void format_op_strides(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += format_tensor_strides(p, inputs[0]);
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[0]), (size_t)(p_end - p));
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += format_tensor_strides(p, inputs[i]);
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[i]), (size_t)(p_end - p));
}
}
p += sprintf(p, " -> ");
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
char self[64];
format_tensor_strides(self, node.dst());
p += sprintf(p, "%s", self);
format_tensor_strides(self, sizeof(self), node.dst());
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
}
}
void format_op_types(char * str, const htp_opnode & node) {
void format_op_types(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE");
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE");
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"), (size_t)(p_end - p));
}
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
p += sprintf(p, "%s", ggml_type_name(node.dst()->type));
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", ggml_type_name(node.dst()->type)), (size_t)(p_end - p));
}
}
const char * tensor_buff_name(const struct ggml_tensor * t) {
@@ -221,51 +276,102 @@ struct htp_opformat {
return "NONE";
}
void format_op_buffs(char * str, const htp_opnode & node) {
void format_op_buffs(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += sprintf(p, "%s", tensor_buff_name(inputs[0]));
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", tensor_buff_name(inputs[i]));
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[0])), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[i])), (size_t)(p_end - p));
}
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
p += sprintf(p, "%s", tensor_buff_name(node.dst()));
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(node.dst())), (size_t)(p_end - p));
}
}
void format_op_names(char * str, const htp_opnode & node) {
void format_op_names(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE");
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE");
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? inputs[0]->name : "NONE"), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? inputs[i]->name : "NONE"), (size_t)(p_end - p));
}
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
p += sprintf(p, "%s", node.dst()->name);
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", node.dst()->name), (size_t)(p_end - p));
}
}
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
if (type == HTP_MM_KERNEL_HMX_2D || type == HTP_MM_KERNEL_HMX_F16_BATCHED) {
path = "hmx-tiled";
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || type == HTP_MM_KERNEL_HVX_F32_F32_VTCM ||
type == HTP_MM_KERNEL_HVX_QUANT_ROW || type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) {
path = "hvx-tiled";
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_DDR || type == HTP_MM_KERNEL_HVX_F16_F32_DDR ||
type == HTP_MM_KERNEL_HVX_F32_F32_DDR || type == HTP_MM_KERNEL_HVX_F32_F16_DDR ||
type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) {
path = "hvx-flat";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else {
snprintf(str, max_size, "----");
}
}
void format(const htp_opnode & node) {
format_op_dims(dims, node);
format_op_strides(strides, node);
format_op_types(types, node);
format_op_buffs(buffs, node);
format_op_names(names, node);
format_op_dims(dims, sizeof(dims), node);
format_op_strides(strides, sizeof(strides), node);
format_op_types(types, sizeof(types), node);
format_op_buffs(buffs, sizeof(buffs), node);
format_op_names(names, sizeof(names), node);
format_kernel_params(kparams, sizeof(kparams), node);
}
htp_opformat() {}
htp_opformat() {
strides[0] = '\0';
dims[0] = '\0';
types[0] = '\0';
buffs[0] = '\0';
names[0] = '\0';
kparams[0] = '\0';
}
htp_opformat(const htp_opnode & node) { format(node); }
};
+14 -38
View File
@@ -19,43 +19,9 @@ add_library(${HTP_LIB} SHARED
htp_iface_skel.c
worker-pool.c
hex-dma.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
target_sources(${HTP_LIB} PRIVATE
hmx-queue.c
flash-attn-ops.c
hmx-flash-attn-ops.c
matmul-ops.c
binary-ops.c
unary-ops.c
@@ -63,7 +29,6 @@ target_sources(${HTP_LIB} PRIVATE
softmax-ops.c
act-ops.c
rope-ops.c
flash-attn-ops.c
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
@@ -79,6 +44,17 @@ target_sources(${HTP_LIB} PRIVATE
pad-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
install(TARGETS ${HTP_LIB})
+13 -15
View File
@@ -3,7 +3,7 @@ if (HEXAGON_TOOLCHAIN_INCLUDED)
endif()
set(HEXAGON_TOOLCHAIN_INCLUDED true)
#Cross Compiling for Hexagon
# Cross Compiling for Hexagon
set(HEXAGON TRUE)
set(CMAKE_SYSTEM_NAME QURT)
set(CMAKE_SYSTEM_PROCESSOR Hexagon)
@@ -14,7 +14,6 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
set(CUSTOM_RUNELF_PATH "")
#To fix backward compatibility with EAI addon.
if (NOT HEXAGON_SDK_ROOT)
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
endif()
@@ -31,7 +30,6 @@ endif()
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
#Get the Binary extension of the Hexagon Toolchain
if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
endif()
@@ -48,12 +46,12 @@ set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
HEXAGON_TOOLS_ROOT
)
#QURT Related includes and linker flags
# QURT Related includes and linker flags
set(V_ARCH ${HEXAGON_ARCH})
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
if( ${TREE} MATCHES PAKMAN )
if (${TREE} MATCHES PAKMAN)
set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
endif()
message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
@@ -83,11 +81,9 @@ set(QURT_START_LINK_LIBS
)
STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
set(QURT_END_LINK_LIBS
${TARGET_DIR}/fini.o
)
set(QURT_END_LINK_LIBS ${TARGET_DIR}/fini.o)
#Non QURT related includes and linker flags
# Non QURT related includes and linker flags
set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
@@ -99,8 +95,10 @@ if (NOT NO_WRAP_MEM_API)
set(WRAP_MEMALIGN -Wl,--wrap=memalign)
endif()
set(ARCH_FLAGS "-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -mhmx")
set(PIC_SHARED_LD_FLAGS
-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
${ARCH_FLAGS}
-G0
-fpic
-Wl,-Bsymbolic
@@ -120,13 +118,13 @@ STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
#System include paths
# System include paths
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
#LLVM toolchain setup
#Compiler paths, options and architecture
# LLVM toolchain setup
# Compiler paths, options and architecture
set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
@@ -137,8 +135,8 @@ set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
#Compiler Options
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
# Compiler Options
set(COMMON_FLAGS "${ARCH_FLAGS} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
+2 -3
View File
@@ -18,7 +18,8 @@
#include "htp-ctx.h"
#include "htp-ops.h"
#include "htp-ops.h"
#include "hmx-ops.h"
int hmx_flash_attn_ext(struct htp_ops_context * octx);
// Must be multiple of 32
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
@@ -633,7 +634,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
#ifdef HTP_HAS_HMX
// HMX path: head_dim multiple of 64, F16 KV, and no sinks
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) {
int ret = hmx_flash_attn_ext(octx);
@@ -642,7 +642,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
}
// VTCM too small or other failure -> fall through to HVX path
}
#endif
struct htp_fa_context factx;
factx.octx = octx;
+80
View File
@@ -0,0 +1,80 @@
#ifndef HEX_COMMON_H
#define HEX_COMMON_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#ifndef SIZE_MAX
#define SIZE_MAX ((size_t)-1)
#endif
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
static inline uint32_t hex_ceil_pow2(uint32_t x) {
if (x <= 1) { return 1; }
int p = 2;
x--;
while (x >>= 1) { p <<= 1; }
return p;
}
static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
return right_off <= chunk_size;
}
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_swap_ptr(void ** p1, void ** p2) {
void * t = *p1;
*p1 = *p2;
*p2 = t;
}
static inline bool hex_mul_overflow(size_t a, size_t b, size_t *out) {
if (a != 0 && b > SIZE_MAX / a) return true;
*out = a * b;
return false;
}
static inline bool hex_add_overflow(size_t a, size_t b, size_t *out) {
if (a > SIZE_MAX - b) return true;
*out = a + b;
return false;
}
#endif // HEX_COMMON_H
+1 -5
View File
@@ -5,6 +5,7 @@
#include <hexagon_types.h>
#include <stdbool.h>
#include <stdint.h>
#include "hex-utils.h"
#include "hex-profile.h"
@@ -127,13 +128,8 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src)
return p;
}
#if __HVX_ARCH__ < 73
static const uint32_t dma_src_l2_bypass_on = 1;
static const uint32_t dma_dst_l2_bypass_on = 0;
#else
static const uint32_t dma_src_l2_bypass_on = 1;
static const uint32_t dma_dst_l2_bypass_on = 1;
#endif
static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) {
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
+1 -56
View File
@@ -11,14 +11,7 @@
#include "hex-fastdiv.h"
#include "hex-dump.h"
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
#include "hex-common.h"
static inline uint64_t hex_get_cycles() {
uint64_t cycles = 0;
@@ -32,54 +25,6 @@ static inline uint64_t hex_get_pktcnt() {
return pktcnt;
}
static inline uint32_t hex_ceil_pow2(uint32_t x) {
if (x <= 1) { return 1; }
int p = 2;
x--;
while (x >>= 1) { p <<= 1; }
return p;
}
static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
return right_off <= chunk_size;
}
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_swap_ptr(void ** p1, void ** p2) {
void * t = *p1;
*p1 = *p2;
*p2 = t;
}
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
Q6_l2fetch_AP((void *) p, control);
+13 -13
View File
@@ -49,7 +49,7 @@
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) {
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
@@ -70,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
+ k_dma_size * 2 // K DMA x2
+ v_dma_size * 2 // V DMA x2
+ k_tile_size * 1 // K tiles
+ v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ s_tile_size * 2 // S + P
+ d_tile_size * 1 // D (diagonal matrix)
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
@@ -290,7 +290,7 @@ static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) =
struct hmx_fa_context {
const struct htp_ops_context * octx;
bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
bool pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
uint32_t n_threads;
// Op parameters
@@ -409,7 +409,7 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
return;
}
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
__fp16 * v_tiles_dest = factx->pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
@@ -1312,13 +1312,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS);
const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc;
const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
const bool pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
// Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1
const uint32_t n_threads = use_pipeline ? n_threads_init : 1;
const uint32_t n_threads = pipeline ? n_threads_init : 1;
FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu",
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget);
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, pipeline, vtcm_budget);
// ======== Build context ========
struct hmx_fa_context factx;
@@ -1339,7 +1339,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
factx.n_kv_blocks = n_kv_blocks;
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32);
factx.use_pipeline = use_pipeline;
factx.pipeline = pipeline;
factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1);
// Extract op parameters (mutable during softcap adjustment, then stored as const in factx)
@@ -1405,7 +1405,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes);
factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
if (use_pipeline) {
if (pipeline) {
factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
} else {
factx.vtcm_v_tiles[1] = NULL;
@@ -1456,7 +1456,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
// ======== HMX lock strategy ========
// Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend.
// Fallback: main thread holds the lock (original behavior).
if (!factx.use_pipeline) {
if (!factx.pipeline) {
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
}
@@ -1550,7 +1550,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t k_src_stride = size_k_row_padded / sizeof(__fp16);
const size_t v_src_stride = size_v_row_padded / sizeof(__fp16);
if (factx.use_pipeline) {
if (factx.pipeline) {
// ==================================================================
// Pipeline path: HVX phases ‖ HMX queue worker
// ==================================================================
@@ -1780,7 +1780,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
// HMX: O_final = diag(1/l) @ O_prev
if (factx.use_pipeline) {
if (factx.pipeline) {
on_job.o_curr = o_tile_curr;
on_job.o_prev = o_tile_prev;
on_job.d_tiles = factx.vtcm_d_tiles;
@@ -1826,7 +1826,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
} // end KV head loop
} // end batch loop
if (factx.use_pipeline) {
if (factx.pipeline) {
hmx_queue_suspend(ctx->hmx_queue);
} else {
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
-6
View File
@@ -1,6 +0,0 @@
// HMX operations compiled as a single translation unit.
// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO.
#include "hmx-queue.c"
#include "hmx-matmul-ops.c"
#include "hmx-flash-attn-ops.c"
-88
View File
@@ -1,88 +0,0 @@
// HMX operation entry-point declarations.
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_OPS_H
#define HMX_OPS_H
#include <stddef.h>
#include <stdint.h>
#include "htp-ops.h"
#ifdef __cplusplus
extern "C" {
#endif
typedef struct {
float *dst;
const float *activation;
const __fp16 *permuted_weight;
int m;
int k;
int n;
int act_stride;
int weight_stride;
int dst_stride;
int ne02;
int ne03;
int ne12;
int ne13;
size_t src0_nb2;
size_t src0_nb3;
size_t src1_nb2;
size_t src1_nb3;
size_t dst_nb2;
size_t dst_nb3;
} hmx_matmul_f16_f32_batched_params_t;
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
// act_stride: activation row stride in elements (= k for contiguous, or
// nb[1]/sizeof(float) for permuted tensors like attention Q).
// weight_stride: weight row stride in elements (= k for compact weights, or
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
int hmx_matmul_f16_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const __fp16 *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride);
// Batched F16 wrapper over hmx_mat_mul_f16_f32.
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);
// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4)
int hmx_matmul_2d_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride,
int weight_type);
struct mmid_row_mapping;
int hmx_matmul_id_2d_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int ne11,
size_t act_nb1, size_t act_nb2,
size_t dst_nb1, size_t dst_nb2,
int weight_stride,
int weight_type,
const struct mmid_row_mapping *matrix_rows,
int cur_a,
int mapping_stride);
// HMX flash attention
int hmx_flash_attn_ext(struct htp_ops_context * octx);
#ifdef __cplusplus
}
#endif
#endif // HMX_OPS_H
+9 -3
View File
@@ -13,7 +13,9 @@
#include <stdint.h>
#include <stdbool.h>
#ifndef HTP_MAX_NTHREADS
#define HTP_MAX_NTHREADS 10
#endif
#define HTP_MAX_MMAPS 16
// Memory mapping
@@ -42,9 +44,13 @@ struct htp_ops_context {
enum htp_op_code op; // FIXME: rename to opcode
int32_t op_params[HTP_OP_MAX_PARAMS];
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS];
const struct htp_tensor * src[HTP_OP_MAX_INPUTS];
const struct htp_tensor * dst;
union {
const struct htp_tensor * dst;
const struct htp_tensor * dsts[HTP_OP_MAX_OUTPUTS];
};
// TODO convert these to an array
struct htp_spad src0_spad;
@@ -87,13 +93,13 @@ struct htp_context {
struct htp_ops_context octx;
#ifdef HTP_HAS_HMX
struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap
#endif
};
int op_matmul(struct htp_ops_context * octx);
int op_matmul_id(struct htp_ops_context * octx);
int op_matmul_qkv(struct htp_ops_context * octx);
int op_matmul_ffn(struct htp_ops_context * octx);
int op_binary(struct htp_ops_context * octx);
int op_unary(struct htp_ops_context * octx);
int op_sum_rows(struct htp_ops_context * octx);
+15 -8
View File
@@ -28,18 +28,19 @@ enum htp_data_type {
HTP_TYPE_MXFP4 = 39,
// types used internally for repack, dyn.quant, etc
HTP_TYPE_Q4_0x4x2 = 200,
HTP_TYPE_Q4_1x4x2,
HTP_TYPE_Q8_0x4x2,
HTP_TYPE_MXFP4x4x2,
HTP_TYPE_Q4_0_TILED = 200,
HTP_TYPE_Q4_1_TILED,
HTP_TYPE_Q8_0_TILED,
HTP_TYPE_MXFP4_TILED,
HTP_TYPE_INVALID
};
// Constats for internal types
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
#define QK_Q4_0_TILED 256 // 32x32 Q4_0 tiled layout
#define QK_Q8_0_TILED 128 // 32x32 Q8_0 tiled layout
#define QK_MXFP4_TILED 256 // 32x32 MXFP4 tiled layout
// Mask to enable various stages of the Ops.
@@ -57,6 +58,8 @@ enum htp_op_code {
HTP_OP_DIV = 3,
HTP_OP_MUL_MAT,
HTP_OP_MUL_MAT_ID,
HTP_OP_MUL_MAT_QKV,
HTP_OP_MUL_MAT_FFN,
HTP_OP_RMS_NORM,
HTP_OP_RMS_NORM_MUL,
HTP_OP_UNARY_SILU,
@@ -99,7 +102,9 @@ enum htp_op_code {
#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
#define HTP_OP_MAX_OUTPUTS 4
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS
#define HTP_OP_MAX_KERN_PARAMS 32
#define HTP_OP_MAX_BUFS 16
#define HTP_OP_MAX_REQS 256
@@ -142,8 +147,10 @@ struct htp_op_desc {
uint32_t opcode; // GGML/HTP Op
uint32_t flags; // Op flags
int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; // generic blob for host-precomputed parameters
uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices
uint16_t dst; // Output tensor index
uint16_t dst[HTP_OP_MAX_OUTPUTS]; // Output tensor indices
uint16_t pad[2]; // padding to align to 64 bits
};
#ifndef HTP_MAX_NTHREADS
+2 -1
View File
@@ -11,12 +11,13 @@ struct htp_iface_pmu_conf {
};
interface htp_iface : remote_handle64 {
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem);
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 n_hmx, in uint64 max_vmem);
AEEResult stop();
AEEResult mmap(in uint32 fd, in uint32 size);
AEEResult munmap(in uint32 fd);
AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu);
AEEResult etm(in uint32 enable);
AEEResult hwinfo(rout uint32 n_threads, rout uint32 n_hvx, rout uint32 n_hmx, rout uint64 vtcm_size);
};
#endif /* HTP_IDL */
+13 -18
View File
@@ -170,25 +170,7 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
}
#endif
/* Q6_Vsf_equals_Vw is only available on v73+.*/
#if __HVX_ARCH__ < 73
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
{
HVX_Vector const vzero = Q6_V_vzero();
HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
return ret;
}
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
{
return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
}
#endif
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
// This looks complicated.
@@ -305,4 +287,17 @@ static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
#endif // __HVX_ARCH__ < 79
static inline HVX_Vector hvx_vec_load_act_tile(const uint8_t * y_q, uint32_t kt, HVX_Vector * v_act_all) {
if (kt % 4 == 0) {
*v_act_all = hvx_vmem(y_q + kt * 32);
return *v_act_all;
} else if (kt % 4 == 1) {
return Q6_V_vror_VR(*v_act_all, 32);
} else if (kt % 4 == 2) {
return Q6_V_vror_VR(*v_act_all, 64);
} else {
return Q6_V_vror_VR(*v_act_all, 96);
}
}
#endif /* HVX_BASE_H */
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+81 -23
View File
@@ -361,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) {
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
static void htp_error_callback(dspqueue_t queue, int error, void * context);
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) {
AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp_queue_id, uint32_t n_hvx, uint32_t n_hmx, uint64_t max_vmem) {
struct htp_context * ctx = (struct htp_context *) handle;
if (!ctx) {
@@ -395,10 +395,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
return AEE_ENOMEMORY;
}
#ifdef HTP_HAS_HMX
ctx->hmx_enabled = use_hmx;
ctx->hmx_enabled = n_hmx;
ctx->hmx_queue = NULL;
if (use_hmx) {
if (n_hmx) {
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
if (ctx->hmx_queue) {
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
@@ -407,8 +406,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->hmx_enabled = false;
}
}
FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx);
#endif
FARF(HIGH, "HMX %s (n_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", n_hmx);
qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads);
@@ -481,13 +479,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
dma_queue_delete(ctx->dma[i]);
}
#ifdef HTP_HAS_HMX
if (ctx->hmx_queue) {
hmx_queue_delete(ctx->hmx_queue);
ctx->hmx_queue = NULL;
}
ctx->hmx_enabled = false;
#endif
vtcm_free(ctx);
@@ -500,6 +496,36 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
return AEE_SUCCESS;
}
AEEResult htp_iface_hwinfo(remote_handle64 handle, uint32_t * n_threads, uint32_t * n_hvx, uint32_t * n_hmx, uint64_t * vtcm_size) {
(void)handle;
if (!n_threads || !n_hvx || !n_hmx || !vtcm_size) {
return AEE_EBADPARM;
}
qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads);
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
uint32_t n_hvx_val = hw_nhvx;
if (n_hvx_val > hw_threads.max_hthreads) {
n_hvx_val = hw_threads.max_hthreads;
}
if (n_hvx_val > HTP_MAX_NTHREADS) {
n_hvx_val = HTP_MAX_NTHREADS;
}
// for now we force n_threads == n_hvx
*n_threads = n_hvx_val;
*n_hvx = n_hvx_val;
*n_hmx = 1;
uint32_t vtcm_sz = 8 * 1024 * 1024; // 8MB default fallback
HAP_compute_res_query_VTCM(0, (unsigned int *)&vtcm_sz, NULL, NULL, NULL);
*vtcm_size = vtcm_sz;
return AEE_SUCCESS;
}
static void htp_error_callback(dspqueue_t queue, int error, void * context) {
// No errors expected on the DSP.
FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
@@ -554,6 +580,12 @@ static int execute_op(struct htp_ops_context * octx) {
case HTP_OP_MUL_MAT_ID:
return op_matmul_id(octx);
case HTP_OP_MUL_MAT_QKV:
return op_matmul_qkv(octx);
case HTP_OP_MUL_MAT_FFN:
return op_matmul_ffn(octx);
case HTP_OP_MUL:
case HTP_OP_ADD:
case HTP_OP_SUB:
@@ -762,8 +794,9 @@ static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, str
}
}
static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
memcpy(octx->op_params, op->params, sizeof(octx->op_params));
memcpy(octx->kernel_params, op->kernel_params, sizeof(octx->kernel_params));
octx->flags = op->flags;
octx->op = op->opcode;
@@ -785,22 +818,41 @@ static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens,
src->ne[0], src->ne[1], src->ne[3], src->ne[3]);
}
// Prep output tensor
struct htp_tensor *dst = tens + op->dst;
// Prep output tensors
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
uint16_t dst_idx = op->dst[i];
if (dst_idx == 0xffff) {
octx->dsts[i] = NULL;
continue;
}
struct htp_tensor *dst = tens + dst_idx;
octx->dsts[i] = dst;
octx->dst = dst;
FARF(HIGH, "prep-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, dst_idx, (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
}
FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
int status = execute_op(octx);
(void) execute_op(octx);
octx->src0_spad.src = NULL;
octx->src1_spad.src = NULL;
octx->src2_spad.src = NULL;
octx->src3_spad.src = NULL;
octx->dst_spad.src = NULL;
// flush buffers on output
hex_l2flush((void *) dst->data, dst->size);
dst->flags |= HTP_TENSOR_FLUSHED;
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
if (octx->dsts[i]) {
struct htp_tensor *dst = (struct htp_tensor *)octx->dsts[i];
hex_l2flush((void *) dst->data, dst->size);
dst->flags |= HTP_TENSOR_FLUSHED;
FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
FARF(HIGH, "post-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, op->dst[i], (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
}
}
return status;
}
#define DSPQUEUE_POLL_TIMEOUT_USEC 100
@@ -892,20 +944,26 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
}
}
int op_status = HTP_STATUS_OK;
uint32_t op_wakeup = n_ops / 2; // half-way throgh the batch
for (uint32_t i=0; i < n_ops; i++) {
struct profile_data prof;
if (i == (n_ops-1)) {
// wake up the host before starting the last op
if (i == op_wakeup) {
dspqueue_write_early_wakeup_noblock(queue, 0, 0);
}
profile_start(ctx->profiler, &prof);
proc_op_req(octx, tens, i, &ops[i]);
op_status = proc_op_req(octx, tens, i, &ops[i]);
profile_stop(ctx->profiler, &prof);
if (op_status != HTP_STATUS_OK) {
break;
}
if (ctx->profiler) {
pds[i].opcode = ops[i].opcode;
pds[i].usecs = prof.usecs;
@@ -919,7 +977,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_opbatch_rsp rsp;
rsp.id = req.id;
rsp.status = HTP_STATUS_OK;
rsp.status = op_status;
rsp.n_bufs = n_bufs;
rsp.n_tensors = n_tens;
rsp.n_ops = n_ops;
File diff suppressed because it is too large Load Diff
+508
View File
@@ -0,0 +1,508 @@
#ifndef HTP_MATMUL_OPS_H
#define HTP_MATMUL_OPS_H
#include <stdint.h>
#include <stddef.h>
#include "htp-ops.h"
#include "hex-fastdiv.h"
#include "hex-common.h"
#ifdef __cplusplus
extern "C" {
#endif
// --- HMX Tile Constraints ---
#define HTP_MM_HMX_TILE_N_COLS 32
#define HTP_MM_HMX_TILE_N_ROWS 32
#define HTP_MM_HMX_TILE_SIZE (32 * 32 * sizeof(__fp16)) // 2048 bytes
#define HTP_MM_HMX_TILE_N_ELMS 1024
#define HTP_MM_HMX_MIN_NROWS 4
// --- Weight Repacked Tile Sizes ---
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_0 576
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_1 640
#define HTP_MM_WEIGHT_TILE_SIZE_Q8_0 1088
#define HTP_MM_WEIGHT_TILE_SIZE_IQ4_NL 576
#define HTP_MM_WEIGHT_TILE_SIZE_MXFP4 544
// --- Weight Repacked Aligned Tile Sizes ---
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0 640
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1 640
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0 1152
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_IQ4_NL 640
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4 640
// --- Activation Tiled Block Sizes (including padding) ---
#define HTP_MM_ACT_TILE_SIZE_Q8_0 1152
#define HTP_MM_ACT_TILE_SIZE_Q8_1 1280
#define HTP_MM_MAX_PREFETCH 16
// --- Solver Cost Model Penalty Weights (HMX-specific) ---
#define HTP_MM_HMX_COST_W_DEQUANT 3 // cost penalty for quantized weight loading/dequantization
#define HTP_MM_HMX_COST_A_CONVERT 2 // cost penalty for activation loading/conversion
// --- DMA Activation Transfer Configuration ---
#define HTP_MM_DMA_ACT_ROWS_PER_STEP 2
#define HTP_MM_DMA_ACT_MULTIPLIER 4
enum htp_mm_kernel_type {
HTP_MM_KERNEL_UNSUPPORTED = 0,
// HMX paths
HTP_MM_KERNEL_HMX_2D,
HTP_MM_KERNEL_HMX_F16_BATCHED,
// HVX floating-point paths
HTP_MM_KERNEL_HVX_F16_F16_VTCM,
HTP_MM_KERNEL_HVX_F16_F16_DDR,
HTP_MM_KERNEL_HVX_F16_F32_DDR,
HTP_MM_KERNEL_HVX_F32_F32_VTCM,
HTP_MM_KERNEL_HVX_F32_F32_DDR,
HTP_MM_KERNEL_HVX_F32_F16_DDR,
// HVX quantized paths
HTP_MM_KERNEL_HVX_QUANT_ROW, // standard row-wise parallel quantization
HTP_MM_KERNEL_HVX_QUANT_BLOCK, // parallel block-wise quantization
HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT, // row-wise fallback flat quantization
};
// Op-specific struct for precomputed matmul params
struct htp_mm_kernel_params {
int32_t kernel_type; // enum htp_mm_kernel_type
int32_t pipeline; // 1 = pipelined execution, 0 = standard
int32_t m_chunk; // Row chunk size (M chunk)
int32_t n_chunk; // Col chunk size (N chunk)
int32_t n_threads; // Number of threads to spawn
int32_t n_act_threads; // Number of threads for activation preparation
int32_t n_hmx; // 1 = use HMX, 0 = use HVX
int32_t n_prefetch; // Prefetch lookahead buffers/rows in VTCM
int32_t tile_size; // Weight tile size
int32_t aligned_tile_size; // Aligned weight tile size (padded to 128)
int32_t src1_row_size; // Row size for quantized activation
int32_t vtcm_size; // Total required scratchpad size in VTCM
int32_t vtcm_src0_size; // src0 scratchpad size in VTCM
int32_t vtcm_src1_size; // src1 scratchpad size in VTCM
int32_t vtcm_src2_size; // src2 scratchpad size in VTCM (fused only)
int32_t vtcm_src3_size; // src3 scratchpad size in VTCM (fused only)
int32_t vtcm_dst_size; // dst scratchpad size in VTCM
// Precomputed division values
struct fastdiv_values div_ne12_ne1;
struct fastdiv_values div_ne1;
struct fastdiv_values div_r2;
struct fastdiv_values div_r3;
struct fastdiv_values div_ne11;
};
#if defined(__cplusplus)
static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
#else
_Static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
#endif
struct mmid_row_mapping {
uint32_t i1;
uint32_t i2;
};
// Search for optimal (mc, nc) chunk sizes within VTCM budget.
static inline int htp_mm_hmx_compute_chunks(size_t vtcm_total,
size_t overhead,
size_t per_n_cost,
size_t per_m_cost,
size_t per_mn_cost,
size_t m,
size_t n,
size_t m_block_cost,
size_t n_block_cost,
size_t * m_chunk_out,
size_t * n_chunk_out,
size_t * total_out) {
if (m == 0 || n == 0) return -1;
if (vtcm_total <= overhead) return -1;
if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;
const size_t usable = vtcm_total - overhead;
size_t best_cost = SIZE_MAX;
size_t best_mn = 0;
size_t best_m = 0, best_n = 0;
const size_t n_max = hex_align_down((size_t)n, HTP_MM_HMX_TILE_N_COLS);
for (size_t nc = n_max; nc >= HTP_MM_HMX_TILE_N_COLS; nc -= HTP_MM_HMX_TILE_N_COLS) {
size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
if (hex_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
if (n_fixed >= usable) goto next_nc;
if (hex_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc;
if (hex_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc;
{
size_t remain = usable - n_fixed;
size_t mc = remain / mc_denom;
mc = hex_align_down(mc, HTP_MM_HMX_TILE_N_ROWS);
mc = hex_smin(mc, m);
if (mc == 0) {
goto next_nc;
}
size_t mblocks = ((size_t) m + mc - 1) / mc;
size_t nblocks = ((size_t) n + nc - 1) / nc;
size_t cost = mblocks * m_block_cost + nblocks * n_block_cost;
size_t mn = mc * nc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
best_cost = cost;
best_mn = mn;
best_m = mc;
best_n = nc;
}
}
next_nc:
if (nc == HTP_MM_HMX_TILE_N_COLS) break; // avoid size_t underflow
}
if (best_m == 0 || best_n == 0) return -1;
// Compute exact total (with overflow checks)
size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0;
if (hex_mul_overflow(best_n, per_n_cost, &t0)) return -1;
if (hex_mul_overflow(best_m, per_m_cost, &t1)) return -1;
if (hex_mul_overflow(best_m, best_n, &mn)) return -1;
if (hex_mul_overflow(mn, per_mn_cost, &t2)) return -1;
if (hex_add_overflow(t0, t1, &total)) return -1;
if (hex_add_overflow(total, t2, &total)) return -1;
if (hex_add_overflow(total, overhead, &total)) return -1;
*m_chunk_out = best_m;
*n_chunk_out = best_n;
*total_out = total;
return 0;
}
// --- Tile Size Helpers ---
static inline uint32_t htp_mm_get_weight_tile_size(int weight_type) {
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
return HTP_MM_WEIGHT_TILE_SIZE_Q4_0;
case HTP_TYPE_Q4_1:
return HTP_MM_WEIGHT_TILE_SIZE_Q4_1;
case HTP_TYPE_Q8_0:
return HTP_MM_WEIGHT_TILE_SIZE_Q8_0;
case HTP_TYPE_MXFP4:
return HTP_MM_WEIGHT_TILE_SIZE_MXFP4;
default:
return 0;
}
}
static inline uint32_t htp_mm_get_weight_aligned_tile_size(int weight_type) {
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0;
case HTP_TYPE_Q4_1:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1;
case HTP_TYPE_Q8_0:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0;
case HTP_TYPE_MXFP4:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4;
default:
return 0;
}
}
// --- Activation/Row Size Helpers ---
static inline size_t htp_mm_q8_0_tiled_row_size(uint32_t ne) {
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
const uint32_t nb_32 = ne_padded / 32;
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_0;
}
static inline size_t htp_mm_q8_1_tiled_row_size(uint32_t ne) {
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
const uint32_t nb_32 = ne_padded / 32;
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_1;
}
static inline size_t htp_mm_q8_0_flat_row_size(uint32_t ne) {
const uint32_t quants_size = hex_align_up(ne, 128);
const uint32_t num_scales = (ne + 31) / 32;
const uint32_t scales_size = hex_align_up(num_scales * 2, 128);
return quants_size + scales_size;
}
static inline size_t htp_mm_q8_1_flat_row_size(uint32_t ne) {
const uint32_t quants_size = hex_align_up(ne, 128);
const uint32_t num_scales = (ne + 31) / 32;
const uint32_t scales_size = hex_align_up(num_scales * 4, 128);
return quants_size + scales_size;
}
static inline size_t htp_mm_get_tiled_row_stride(int weight_type, uint32_t k) {
uint32_t nb = (k + QK_Q4_0_TILED - 1) / QK_Q4_0_TILED;
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
case HTP_TYPE_Q4_1:
case HTP_TYPE_Q8_0:
case HTP_TYPE_MXFP4:
return (size_t) nb * htp_mm_get_weight_tile_size(weight_type);
case HTP_TYPE_F16:
return (size_t) k * sizeof(__fp16);
case HTP_TYPE_F32:
return (size_t) k * sizeof(float);
default:
return 0;
}
}
static inline size_t htp_mm_round_up(size_t n, size_t m) {
return ((n + m - 1) / m) * m;
}
static inline bool htp_mm_hmx_pipeline(uint32_t m) {
return m > 32;
}
static inline void htp_mm_hmx_get_2d_chunk_costs(
int wtype, uint32_t k, bool pipeline, uint32_t aligned_tile_size,
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
) {
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
const size_t vec_dot_size = k * sizeof(uint16_t);
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0;
*size_per_n_out = (pipeline ? 2 : 1) * (is_quant ? qweight_row_stride : row_stride) +
(pipeline ? 2 * vec_dot_size : vec_dot_size);
*size_per_m_out = vec_dot_size;
*size_per_mn_out = (pipeline ? 2 : 1) * sizeof(uint16_t);
}
static inline void htp_mm_hmx_get_batched_chunk_costs(
uint32_t k, uint32_t group_size,
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
) {
const size_t vec_dot_size = k * sizeof(uint16_t);
*size_per_n_out = 3 * vec_dot_size;
*size_per_m_out = group_size * vec_dot_size;
*size_per_mn_out = sizeof(uint16_t);
}
static inline size_t htp_mm_hmx_get_2d_vtcm_size(
int wtype, uint32_t k, size_t mc, size_t nc, bool pipeline, uint32_t act_threads, uint32_t aligned_tile_size
) {
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
const size_t vec_dot_size = k * sizeof(uint16_t);
const size_t act_f32_size = htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE);
size_t weight_area_size = is_quant
? htp_mm_round_up((nc / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE)
: htp_mm_round_up(nc * row_stride, HTP_MM_HMX_TILE_SIZE);
if (pipeline) {
weight_area_size *= 2;
}
const size_t act_area_size = htp_mm_round_up(mc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
const size_t output_area_size = htp_mm_round_up(mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
size_t scratch0_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
size_t scratch1_size = pipeline ? scratch0_size : 0;
size_t scratch2_size = pipeline ? output_area_size : 0;
return weight_area_size + act_area_size + act_f32_size + output_area_size +
scratch0_size + scratch1_size + scratch2_size + 256;
}
static inline size_t htp_mm_hmx_get_batched_vtcm_size(
int wtype, uint32_t k, size_t mc, size_t nc, uint32_t group_size, bool use_dma_activation, bool pipeline, uint32_t act_threads) {
(void)wtype;
(void)pipeline;
const size_t vec_dot_size = k * sizeof(uint16_t);
const size_t f32_scratch_size = use_dma_activation
? htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0;
const size_t act_head_stride = mc * k;
const size_t weight_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
const size_t act_area_size = htp_mm_round_up(group_size * act_head_stride * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
const size_t output_area_size = htp_mm_round_up(group_size * mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
const size_t scratch_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
return weight_area_size + act_area_size + output_area_size +
2 * scratch_area_size + 256 + f32_scratch_size;
}
static inline size_t htp_mm_hvx_get_vtcm_sizes(
int kernel_type,
int wtype,
uint32_t ne10, // k
uint32_t src1_nrows, // m_total (or act_nrows)
uint32_t n_threads,
size_t dst_row_size,
size_t src0_row_size,
size_t src1_row_size,
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out,
size_t * vtcm_dst_size_out
) {
size_t vtcm_src0_size = 0;
size_t vtcm_src1_size = 0;
size_t vtcm_dst_size = 0;
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
wtype == HTP_TYPE_MXFP4);
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
const size_t dst_nrows = (src1_nrows > 1) ? 0 : 1;
switch (kernel_type) {
case HTP_MM_KERNEL_HVX_F16_F16_VTCM: {
size_t f16_src1_row_size = htp_mm_round_up(ne10 * 2, 128);
vtcm_src1_size = htp_mm_round_up(f16_src1_row_size * src1_nrows, 256);
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
break;
}
case HTP_MM_KERNEL_HVX_F16_F32_DDR:
case HTP_MM_KERNEL_HVX_F16_F16_DDR:
case HTP_MM_KERNEL_HVX_F32_F32_DDR:
case HTP_MM_KERNEL_HVX_F32_F16_DDR: {
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size, 256) * n_threads;
vtcm_src1_size = htp_mm_round_up(n_prefetch * src1_row_size, 256) * n_threads;
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
break;
}
case HTP_MM_KERNEL_HVX_F32_F32_VTCM: {
size_t f32_src1_row_size = htp_mm_round_up(ne10 * 4, 128);
vtcm_src1_size = htp_mm_round_up(f32_src1_row_size * src1_nrows, 256);
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
break;
}
case HTP_MM_KERNEL_HVX_QUANT_BLOCK:
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
break;
}
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
break;
}
default:
break;
}
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = vtcm_src1_size;
*vtcm_dst_size_out = vtcm_dst_size;
return vtcm_src0_size + vtcm_src1_size + vtcm_dst_size;
}
static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
int wtype,
uint32_t ne10, // k
uint32_t src1_nrows,
uint32_t n_threads,
size_t src0_row_size, // nb01
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out
) {
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
wtype == HTP_TYPE_MXFP4);
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
const size_t src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10)
: htp_mm_q8_0_tiled_row_size(ne10);
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (src0_sz_per_thread < src1_row_size_padded) {
src0_sz_per_thread = src1_row_size_padded;
}
if (is_repack) {
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
const uint32_t n_k_tiles = ne10 / 32;
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
}
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = src1_sz;
return vtcm_src0_size + src1_sz;
}
#ifdef __cplusplus
}
#endif
#endif // HTP_MATMUL_OPS_H
-4
View File
@@ -14,8 +14,6 @@ Drivers_Dir = 13
1 = %DiskId%
[SourceDisksFiles]
libggml-htp-v68.so = 1
libggml-htp-v69.so = 1
libggml-htp-v73.so = 1
libggml-htp-v75.so = 1
libggml-htp-v79.so = 1
@@ -28,8 +26,6 @@ ExcludeFromSelect = *
CopyFiles=Drivers_Dir
[Drivers_Dir]
libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+10 -3
View File
@@ -699,6 +699,7 @@ struct vk_device_struct {
bool add_rms_fusion;
uint32_t partials_binding_alignment;
uint32_t max_nodes_per_submit;
bool shader_64b_indexing;
@@ -5878,6 +5879,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
// Submit at least every 100 nodes, in case there are workloads without as much matmul.
device->max_nodes_per_submit = 100;
const char* GGML_VK_MAX_NODES_PER_SUBMIT = getenv("GGML_VK_MAX_NODES_PER_SUBMIT");
if (GGML_VK_MAX_NODES_PER_SUBMIT != nullptr) {
uint32_t max_nodes_per_submit = std::stoul(GGML_VK_MAX_NODES_PER_SUBMIT);
device->max_nodes_per_submit = std::max(max_nodes_per_submit, 1u);
}
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -16173,8 +16182,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
// (and scaled down based on model size, so smaller models submit earlier).
// Also submit at least every 100 nodes, in case there are workloads without as much matmul.
int nodes_per_submit = 100;
int submitted_nodes = 0;
int submit_count = 0;
uint64_t mul_mat_bytes = 0;
@@ -16400,7 +16407,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = (submitted_nodes >= nodes_per_submit) ||
bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) ||
(mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
(i + ctx->num_additional_fused_ops >= last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending);
@@ -463,6 +463,7 @@ void main() {
}
rowmaxf = max(rowmaxf, float(Sf[r][c]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// M = max(rowmax, Mold)
@@ -352,6 +352,7 @@ void main() {
}
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// Compute max across the row
@@ -11,6 +11,7 @@
#include <future>
#include <queue>
#include <condition_variable>
#include <atomic>
#include <cstdio>
#include <cstring>
#include <cstdlib>
@@ -34,6 +35,9 @@
std::mutex lock;
std::vector<std::pair<std::string, std::string>> shader_fnames;
// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the
// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393)
static std::atomic<bool> compile_failed{false};
std::locale c_locale("C");
std::string GLSLC = "glslc";
@@ -78,7 +82,7 @@ enum MatMulIdType {
namespace {
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
int execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
HANDLE stderr_read, stderr_write;
@@ -127,8 +131,11 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
CloseHandle(stdout_read);
CloseHandle(stderr_read);
WaitForSingleObject(pi.hProcess, INFINITE);
DWORD exit_code = 1;
GetExitCodeProcess(pi.hProcess, &exit_code);
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);
return (int)exit_code;
#else
int stdout_pipe[2];
int stderr_pipe[2];
@@ -175,7 +182,9 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
close(stdout_pipe[0]);
close(stderr_pipe[0]);
waitpid(pid, nullptr, 0);
int status = 0;
waitpid(pid, &status, 0);
return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
}
#endif
}
@@ -372,13 +381,14 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// }
// std::cout << std::endl;
execute_command(cmd, stdout_str, stderr_str);
if (!stderr_str.empty()) {
std::cerr << "cannot compile " << name << "\n\n";
int exit_code = execute_command(cmd, stdout_str, stderr_str);
if (exit_code != 0 || !stderr_str.empty()) {
std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n";
for (const auto& part : cmd) {
std::cerr << part << " ";
}
std::cerr << "\n\n" << stderr_str << std::endl;
compile_failed = true;
return;
}
@@ -398,6 +408,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
shader_fnames.push_back(std::make_pair(name, out_path));
} catch (const std::exception& e) {
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
compile_failed = true;
}
}
@@ -1271,6 +1282,11 @@ int main(int argc, char** argv) {
process_shaders();
if (compile_failed) {
std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl;
return EXIT_FAILURE;
}
write_output_files();
return EXIT_SUCCESS;
+7 -1
View File
@@ -57,19 +57,25 @@ oppoll=
opflt=
[ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF"
opfuse=
[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC"
vmem=
[ "$VM" != "" ] && vmem="GGML_HEXAGON_VMEM=$VM"
mbuf=
[ "$MB" != "" ] && mbuf="GGML_HEXAGON_MBUF=$MB"
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
set -x
adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $vmem $mbuf \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 1024 -fa on \
+7 -1
View File
@@ -51,6 +51,12 @@ opqueue=
oppoll=
[ "$OP" != "" ] && oppoll="GGML_HEXAGON_OPPOLL=$OP"
opfuse=
[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC"
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
set -x
tool=$1; shift
@@ -59,5 +65,5 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll ./$branch/bin/$tool $@ \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \
"
+38 -7
View File
@@ -26,7 +26,7 @@ COL_MAP = {
}
op_pattern = re.compile(
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+:\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
)
trace_pattern = re.compile(
@@ -93,9 +93,40 @@ def parse_log(file_path, pmu_index=None):
+ int(ts_match.group('us'))
)
op_match = op_pattern.search(line)
if "|" in line and "profile-op" in line:
parts = [p.strip() for p in line.split("|")]
prefix = parts[0]
prefix_match = re.search(r"profile-op\s+(?P<op_name>[A-Z_0-9+]+)", prefix)
if not prefix_match:
continue
if len(parts) == 7:
dims, types, timings = parts[2], parts[3], parts[6]
elif len(parts) == 6:
dims, types, timings = parts[2], parts[3], parts[5]
else:
continue
timing_match = re.search(
r"(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?",
timings
)
if not timing_match:
continue
op_match = timing_match
op_name = prefix_match.group("op_name")
else:
op_match = op_pattern.search(line)
if op_match:
op_name = op_match.group('op_name')
dims = op_match.group('dims').strip()
types = op_match.group('types').strip()
else:
op_match = None
if op_match:
pmu_raw = op_match.group('pmu')
pmu_raw = op_match.group('pmu') if 'pmu' in op_match.groupdict() else None
pmu_val = None
if pmu_raw and pmu_index is not None:
try:
@@ -105,7 +136,7 @@ def parse_log(file_path, pmu_index=None):
except (ValueError, IndexError):
pmu_val = None
evt_raw = op_match.group('evt')
evt_raw = op_match.group('evt') if 'evt' in op_match.groupdict() else None
evt_val = None
if evt_raw:
try:
@@ -122,9 +153,9 @@ def parse_log(file_path, pmu_index=None):
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
current_op = {
'name': op_match.group('op_name'),
'dims': op_match.group('dims').strip(),
'types': op_match.group('types').strip(),
'name': op_name,
'dims': dims,
'types': types,
'op_text': op_text,
'usec': int(op_match.group('usec')),
'cycles': int(op_match.group('cycles')),
+42 -6
View File
@@ -12,7 +12,7 @@ from collections import defaultdict
logger = logging.getLogger("ggml-hexagon-trace")
op_pattern = re.compile(
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+(?P<strides>[\d:x\s\->!]+)\s+:\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+(?P<strides>[\d:x\s\->!]+?)\s+:\s+(?:(?P<params>.*?)\s+:\s+)?(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
)
trace_pattern = re.compile(
@@ -66,7 +66,40 @@ def parse_log(file_path):
for line in f:
line_idx += 1
op_match = op_pattern.search(line)
if "|" in line and "profile-op" in line:
parts = [p.strip() for p in line.split("|")]
prefix = parts[0]
prefix_match = re.search(r"profile-op\s+(?P<op_name>[A-Z_0-9+]+)", prefix)
if not prefix_match:
continue
if len(parts) == 7:
dims, types, strides, params, timings = parts[2], parts[3], parts[4], parts[5], parts[6]
elif len(parts) == 6:
dims, types, strides, params, timings = parts[2], parts[3], parts[4], "", parts[5]
else:
continue
timing_match = re.search(
r"(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?",
timings
)
if not timing_match:
continue
op_match = timing_match
op_name = prefix_match.group("op_name")
else:
op_match = op_pattern.search(line)
if op_match:
op_name = op_match.group('op_name')
dims = op_match.group('dims').strip() if op_match.group('dims') else ''
types = op_match.group('types').strip() if op_match.group('types') else ''
strides = op_match.group('strides').strip() if op_match.group('strides') else ''
params = op_match.group('params').strip() if ('params' in op_match.groupdict() and op_match.group('params')) else ''
else:
op_match = None
if op_match:
cycles_start_raw = op_match.group('start')
unwrapped_cycles_start = None
@@ -77,10 +110,11 @@ def parse_log(file_path):
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
current_op = {
'name': op_match.group('op_name'),
'dims': op_match.group('dims').strip() if op_match.group('dims') else '',
'types': op_match.group('types').strip() if op_match.group('types') else '',
'strides': op_match.group('strides').strip() if op_match.group('strides') else '',
'name': op_name,
'dims': dims,
'types': types,
'strides': strides,
'params': params,
'op_text': op_text,
'usec': int(op_match.group('usec')),
'cycles': int(op_match.group('cycles')),
@@ -397,6 +431,8 @@ def generate_perfetto_trace(filtered_ops, output_path):
debug_annots.append(make_debug_annotation("line", int_val=op['line_num']))
if 'strides' in op and op['strides']:
debug_annots.append(make_debug_annotation("strides", string_val=op['strides']))
if 'params' in op and op['params'] and op['params'] != '----':
debug_annots.append(make_debug_annotation("params", string_val=op['params']))
# Slice Begin
evt_begin = make_track_event(1, 2, name=f"{op['name']} ({op['dims']})", category="operator", debug_annotations=debug_annots)
+14 -4
View File
@@ -190,7 +190,15 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
bx = ggml_concat(ctx0, conv, bx, 0);
// causal prepends the state, non-causal pads symmetrically for a centered window
if (hparams.causal_attn) {
bx = ggml_concat(ctx0, conv, bx, 0);
} else {
const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2;
auto * left = ggml_cont(ctx0,
ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0]));
bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0);
}
GGML_ASSERT(bx->ne[0] > conv->ne[0]);
// last d_conv columns is a new conv state
@@ -266,10 +274,12 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
if (!cparams.embeddings) {
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;
res->t_logits = cur;
}
ggml_build_forward_expand(gf, cur);
}
-1
View File
@@ -199,7 +199,6 @@ llama_build_and_test(test-jinja.cpp)
llama_test(test-jinja NAME test-jinja-py ARGS -py LABEL python)
llama_build_and_test(test-chat-auto-parser.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(
test-peg-parser.cpp
+6
View File
@@ -8420,6 +8420,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_MXFP4, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1}));
#if 0
{
// Test paths in OpenCL
@@ -8594,6 +8599,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// gpt-oss issue with Vulkan mmq_id
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q4_0, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
for (ggml_type type_a : all_types) {
test_cases.emplace_back(new test_mul_mat_id(type_a, GGML_TYPE_F32, 4, 2, false, 64, 16, 3*ggml_blck_size(type_a)));
-287
View File
@@ -1,287 +0,0 @@
#include "common.h"
#include "json-partial.h"
#include <exception>
#include <iostream>
#include <stdexcept>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
static void test_json_healing() {
auto parse = [](const std::string & str) {
std::cerr << "# Parsing: " << str << '\n';
std::string::const_iterator it = str.begin();
const auto end = str.end();
common_json out;
std::string healing_marker = "$llama.cpp.json$";
if (common_json_parse(it, end, healing_marker, out)) {
auto dump = out.json.dump();
std::cerr << "Parsed: " << dump << '\n';
std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
std::string result;
if (!out.healing_marker.json_dump_marker.empty()) {
auto i = dump.find(out.healing_marker.json_dump_marker);
if (i == std::string::npos) {
throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
}
result = dump.substr(0, i);
} else {
result = dump;
}
std::cerr << "Result: " << result << '\n';
if (string_starts_with(str, result)) {
std::cerr << "Failure!\n";
}
// return dump;
} else {
throw std::runtime_error("Failed to parse: " + str);
}
};
auto parse_all = [&](const std::string & str) {
for (size_t i = 1; i < str.size(); i++) {
parse(str.substr(0, i));
}
};
parse_all("{\"a\": \"b\"}");
parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
parse_all("[{\"a\": \"b\"}]");
auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
for (const auto & input : inputs) {
common_json out;
assert_equals(true, common_json_parse(input, "$foo", out));
assert_equals<std::string>(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true));
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
}
};
// No healing needed:
test(
{
R"([{"a":"b"}, "y"])",
},
R"([{"a":"b"},"y"])",
""
);
// Partial literals can't be healed:
test(
{
R"([1)",
R"([tru)",
R"([n)",
R"([nul)",
R"([23.2)",
},
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"({"a": 1)",
R"({"a": tru)",
R"({"a": n)",
R"({"a": nul)",
R"({"a": 23.2)",
},
R"({"a":"$foo"})",
R"("$foo)"
);
test(
{
R"({)",
},
R"({"$foo":1})",
R"("$foo)"
);
test(
{
R"([)",
},
R"(["$foo"])",
R"("$foo)"
);
// Healing right after a full literal
test(
{
R"(1 )",
},
R"(1)",
""
);
test(
{
R"(true)",
R"(true )",
},
R"(true)",
""
);
test(
{
R"(null)",
R"(null )",
},
R"(null)",
""
);
test(
{
R"([1 )",
},
R"([1,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{})",
R"([{} )",
},
R"([{},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true)",
},
// TODO: detect the true/false/null literal was complete
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"([true )",
},
R"([true,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true,)",
},
R"([true,"$foo"])",
R"("$foo)"
);
// Test nesting
test(
{
R"([{"a": [{"b": [{)",
},
R"([{"a":[{"b":[{"$foo":1}]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": [{"b": [)",
},
R"([{"a":[{"b":["$foo"]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": "b"})",
R"([{"a": "b"} )",
},
R"([{"a":"b"},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{"a": "b"},)",
R"([{"a": "b"}, )",
},
R"([{"a":"b"},"$foo"])",
R"("$foo)"
);
test(
{
R"({ "code)",
},
R"({"code$foo":1})",
R"($foo)"
);
test(
{
R"({ "code\)",
},
R"({"code\\$foo":1})",
R"(\$foo)"
);
test(
{
R"({ "code")",
},
R"({"code":"$foo"})",
R"(:"$foo)"
);
test(
{
R"({ "key")",
},
R"({"key":"$foo"})",
R"(:"$foo)"
);
// Test unicode escape sequences
test(
{
R"({"a":"\u)",
},
R"({"a":"\u0000$foo"})",
R"(0000$foo)"
);
test(
{
R"({"a":"\u00)",
},
R"({"a":"\u0000$foo"})",
R"(00$foo)"
);
test(
{
R"({"a":"\ud300)",
},
R"({"a":"\ud300$foo"})",
R"($foo)"
);
test(
{
R"({"a":"\ud800)",
},
R"({"a":"\ud800\udc00$foo"})",
R"(\udc00$foo)"
);
test(
{
R"({"a":"\ud800\)",
},
R"({"a":"\ud800\udc00$foo"})",
R"(udc00$foo)"
);
test(
{
R"({"a":"\ud800\u)",
},
R"({"a":"\ud800\udc00$foo"})",
R"(dc00$foo)"
);
test(
{
R"({"a":"\ud800\udc00)",
},
R"({"a":"\ud800\udc00$foo"})",
R"($foo)"
);
}
int main() {
test_json_healing();
std::cerr << "All tests passed.\n";
return 0;
}
+48 -6
View File
@@ -9,6 +9,7 @@ its output, and holds them against the HF model's scores.
import argparse
import logging
import re
import subprocess
import sys
import unicodedata
@@ -28,6 +29,12 @@ class ModelSpec:
mmproj_arg: str
model_default: str
mmproj_default: str
prompt: str = "Free OCR. "
n_predict: int = 512
n_ctx: int | None = None
# Unlimited-OCR's "document parsing" prompt emits <|det|> grounding markup that
# the HF reference strips in result.md; drop it before scoring to match.
strip_grounding: bool = False
@dataclass
@@ -63,6 +70,20 @@ MODELS = {
model_default="gguf_models/deepseek-ai/deepseek-ocr-2-bf16.gguf",
mmproj_default="gguf_models/deepseek-ai/mmproj-deepseek-ocr-2-bf16.gguf",
),
"unlimited": ModelSpec(
key="unlimited", label="Unlimited-OCR",
model_arg="--llama-model-unlimited", mmproj_arg="--mmproj-unlimited",
model_default="gguf_models/baidu/unlimited-ocr-bf16.gguf",
mmproj_default="gguf_models/baidu/mmproj-unlimited-ocr-bf16.gguf",
# "Free OCR." immediately emits EOS on this checkpoint; the HF reference
# (demo/unlimited_ocr_scores.py) uses "document parsing.", which grounds.
prompt="document parsing.",
# Grounding emits ~3x the tokens of plain OCR, so it needs a larger budget
# and context to reach the article body the ground truth covers.
n_predict=4096,
n_ctx=16384,
strip_grounding=True,
),
}
CASES = [
@@ -82,9 +103,26 @@ CASES = [
# is one pixel off and lands at ~0.69 instead.
hf_cer=0.7761, hf_chrf=28.70, cer_tol=0.12, chrf_tol=8.0,
),
TestCase(
model_key="unlimited", label="single-view scan",
image="tools/mtmd/test-1.jpeg",
ground_truth="tools/mtmd/tests/test-1-ground-truth.txt",
# HF reference: Unlimited-OCR scoring (gundam, bf16) on this image/ground-truth.
# Decoder runs full MHA, not R-SWA; the band absorbs that gap + bf16 variance.
hf_cer=0.1869, hf_chrf=75.23, cer_tol=0.06, chrf_tol=6.0,
),
]
GROUNDING_TAG_RE = re.compile(r"<\|(ref|det)\|>.*?<\|/\1\|>", re.DOTALL)
def strip_grounding(text: str) -> str:
"""Drop <|ref|>..<|/ref|> / <|det|>..<|/det|> grounding markup, matching the
cleaned result.md the HF reference scores against."""
return GROUNDING_TAG_RE.sub("", text)
def arg_dest(flag: str) -> str:
return flag.lstrip("-").replace("-", "_")
@@ -129,19 +167,19 @@ def compute_chrf(expected: str, ocr_out: str) -> float:
return CHRF().sentence_score(ocr_out, [expected]).score
def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
def run_mtmd_cli(spec: "ModelSpec", model_path, mmproj_path, image_path, bin_path) -> str:
"""Run mtmd-cli on the image and return its output."""
cmd = [
str(bin_path),
"-m", str(model_path),
"--mmproj", str(mmproj_path),
"--image", str(image_path),
"-p", "Free OCR. ",
"-p", spec.prompt,
"--chat-template", "deepseek-ocr",
"--temp", "0",
"--flash-attn", "off", # match the HF "eager" attention reference
"--no-warmup",
"-n", "512", # cap loops on hard images (KV would otherwise fill)
"-n", str(spec.n_predict), # cap loops on hard images (KV would otherwise fill)
# HF decodes with no_repeat_ngram_size; llama.cpp's analog is DRY.
# Default DRY breakers include "\n", so they are cleared below.
"--dry-multiplier", "0.8",
@@ -150,6 +188,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
"--dry-penalty-last-n", "-1",
"--dry-sequence-breaker", "none",
]
if spec.n_ctx is not None:
cmd += ["-c", str(spec.n_ctx)]
logger.debug(f" command: {' '.join(cmd)}")
try:
@@ -164,6 +204,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
raise RuntimeError(f"llama-mtmd-cli failed with code {result.returncode}")
output = result.stdout.decode("utf-8", errors="replace").strip()
if spec.strip_grounding:
output = strip_grounding(output)
if not output:
raise RuntimeError("llama-mtmd-cli produced no output on stdout")
logger.info(f" output: {len(output)} chars")
@@ -193,7 +235,7 @@ def evaluate(case: "TestCase", expected: str, ocr_out: str) -> bool:
logger.info("")
logger.info("=" * 60)
logger.info("Free OCR evaluation:")
logger.info("OCR evaluation:")
logger.info("=" * 60)
logger.info(f" CER {cer:>7.4f} (HF {case.hf_cer:.4f}, <= {case.cer_max:>7.4f} -> {verdict(cer_pass)})")
logger.info(f" chrF (0-100) {chrf:>7.2f} (HF {case.hf_chrf:.2f}, >= {case.chrf_min:>7.2f} -> {verdict(chrf_pass)})")
@@ -269,9 +311,9 @@ def main() -> int:
expected = read_expected_text(ground_truth)
logger.info(f" Image: {case.image}")
logger.info(f" Expected text: {len(expected)} chars")
logger.info(" Running llama.cpp 'Free OCR'")
logger.info(f" Running llama.cpp prompt {model_spec.prompt!r}")
try:
ocr_out = run_mtmd_cli(model, mmproj, image, binary)
ocr_out = run_mtmd_cli(model_spec, model, mmproj, image, binary)
except RuntimeError as e:
logger.error(f" Error: {e}")
results[title] = False
+5 -1
View File
@@ -40,6 +40,7 @@ struct debug_options {
bool enable_reasoning = true;
bool debug_jinja = false;
bool force_tool_call = false;
bool parallel_tool_calls = true;
output_mode mode = output_mode::BOTH;
input_message_type input_message = input_message_type::NONE;
};
@@ -87,6 +88,7 @@ static void print_usage(const char * program_name) {
LOG_ERR("\nOptions:\n");
LOG_ERR(" --no-tools Disable tool definitions\n");
LOG_ERR(" --force-tool-call Set tool calls to forced\n");
LOG_ERR(" --parallel-tool-calls=0|1 Set parallel_tool_calls (default: 1)\n");
LOG_ERR(" --generation-prompt=0|1 Set add_generation_prompt (default: 1)\n");
LOG_ERR(" --enable-reasoning=0|1 Enable reasoning parsing (default: 1)\n");
LOG_ERR(" --output=MODE Output mode: analysis, template, both (default: both)\n");
@@ -121,6 +123,8 @@ static bool parse_options(int argc, char ** argv, debug_options & opts) {
opts.debug_jinja = true;
} else if (arg == "--no-tools") {
opts.with_tools = false;
} else if (arg.rfind("--parallel-tool-calls=", 0) == 0) {
opts.parallel_tool_calls = parse_bool_option(arg.substr(22));
} else if (arg.rfind("--generation-prompt=", 0) == 0) {
opts.generation_prompt = parse_bool_option(arg.substr(20));
} else if (arg.rfind("--enable-reasoning=", 0) == 0) {
@@ -349,7 +353,7 @@ static autoparser::generation_params prepare_params(const debug_options & opts,
params.tools = json();
params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
}
params.parallel_tool_calls = false;
params.parallel_tool_calls = opts.parallel_tool_calls;
return params;
}
+1 -2
View File
@@ -28,10 +28,9 @@ vite.config.ts.timestamp-*
# PWA Artifacts
apple-splash-*.png
apple-touch-icon-*.png
favicon.ico
favicon-dark.ico
maskable-icon-*.png
pwa-*.png
static/favicon*
# Storybook
*storybook.log
+7 -7
View File
@@ -35,7 +35,7 @@
"bits-ui": "2.18.1",
"clsx": "2.1.1",
"dexie": "4.4.3",
"dompurify": "3.4.5",
"dompurify": "3.4.11",
"eslint": "9.39.4",
"eslint-config-prettier": "10.1.8",
"eslint-plugin-storybook": "10.4.2",
@@ -8653,9 +8653,9 @@
"peer": true
},
"node_modules/dompurify": {
"version": "3.4.5",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.5.tgz",
"integrity": "sha512-OrwIBKsdNSVEeubdJ1HBv/wNENRM9ytAVCv7YXt//A3vPdVMNuACRqK9mXCGCBW2ln7BT/A4X0jXHo2Gu89miA==",
"version": "3.4.11",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.11.tgz",
"integrity": "sha512-zhlUV12GsaRzMsf9q5M254YhA4+VuF0fG+QFqu6aYpoGlKtz+w8//jBcGVYBgQkR5GHjUomejY84AV+/uPbWdw==",
"dev": true,
"license": "(MPL-2.0 OR Apache-2.0)",
"optionalDependencies": {
@@ -10226,9 +10226,9 @@
}
},
"node_modules/hono": {
"version": "4.12.23",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.23.tgz",
"integrity": "sha512-eIaZ9qDgu7XV0pxOCrg7/WhnQ6Ivm22UcxhXx/A3dcbqbbYgBEkc6e/J/s7j2tS96zoB0S9VBdLwQNCWwUo4LA==",
"version": "4.12.26",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.26.tgz",
"integrity": "sha512-uyZtpnYxM9CmQ7QsQknM4zN8EftNqhON1qYeIKM0Se67CCEe2c44xyGURwB0axX2fBDu1dqHrHAc1hmNT8ITkw==",
"dev": true,
"license": "MIT",
"engines": {
+1 -1
View File
@@ -54,7 +54,7 @@
"bits-ui": "2.18.1",
"clsx": "2.1.1",
"dexie": "4.4.3",
"dompurify": "3.4.5",
"dompurify": "3.4.11",
"eslint": "9.39.4",
"eslint-config-prettier": "10.1.8",
"eslint-plugin-storybook": "10.4.2",
+8 -1
View File
@@ -1,4 +1,10 @@
import { defineConfig } from '@vite-pwa/assets-generator/config';
import { FAVICON_COLORS, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
import { writeThemeFavicons } from './scripts/favicon-colorize';
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
});
export default defineConfig({
headLinkOptions: {
@@ -7,7 +13,8 @@ export default defineConfig({
preset: {
transparent: {
sizes: [],
favicons: [[48, 'favicon-dark.ico']]
favicons: [[48, 'favicon-dark.ico']],
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
},
maskable: {
sizes: []
+19 -2
View File
@@ -5,15 +5,32 @@ import {
} from '@vite-pwa/assets-generator/config';
import { readFileSync } from 'node:fs';
import { resolve } from 'node:path';
import { THEME_COLORS, PWA_GENERATOR_DEVICES, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
import {
THEME_COLORS,
PWA_GENERATOR_DEVICES,
PWA_ASSET_GENERATOR,
FAVICON_COLORS
} from './src/lib/constants/pwa';
import { SplashOrientation } from './src/lib/enums/splash.enums';
import { writeThemeFavicons } from './scripts/favicon-colorize';
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
});
export default defineConfig({
headLinkOptions: {
preset: PWA_ASSET_GENERATOR.LINK_PRESET
},
preset: combinePresetAndAppleSplashScreens(
minimal2023Preset,
{
...minimal2023Preset,
// tiny margin so favicon.ico / pwa-*.png breathe inside the canvas
transparent: {
...minimal2023Preset.transparent,
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
}
},
{
padding: PWA_ASSET_GENERATOR.SPLASH_PADDING,
resizeOptions: {
+107
View File
@@ -0,0 +1,107 @@
import { mkdirSync, readFileSync, writeFileSync } from 'node:fs';
import { dirname, resolve } from 'node:path';
import { fileURLToPath } from 'node:url';
const HERE = dirname(fileURLToPath(import.meta.url));
const PROJECT_ROOT = resolve(HERE, '..');
const DEFAULT_LOGO = resolve(PROJECT_ROOT, 'src/lib/assets/logo.svg');
const DEFAULT_OUT_DIR = resolve(PROJECT_ROOT, 'static');
const DEFAULT_OUT_LIGHT = resolve(DEFAULT_OUT_DIR, 'favicon.svg');
const DEFAULT_OUT_DARK = resolve(DEFAULT_OUT_DIR, 'favicon-dark.svg');
const CURRENT_COLOR = 'currentColor';
export interface ColorizedFavicon {
light: string;
dark: string;
}
export interface WriteThemeFaviconsOptions {
sourcePath?: string;
lightOutPath?: string;
darkOutPath?: string;
/**
* Fraction of the icon (0..1) to leave as an even margin on each side.
* Applied by wrapping the inner content in a `<g transform="...">` so the
* source `src/lib/assets/logo.svg` is not modified. Pass 0 to disable.
*/
padding?: number;
}
/**
* Replace every `currentColor` occurrence in the SVG with the given color.
* Pure: no filesystem access, so it is straightforward to unit-test.
*/
export function colorizeFaviconSvg(
svg: string,
lightColor: string,
darkColor: string
): ColorizedFavicon {
return {
light: svg.replaceAll(CURRENT_COLOR, lightColor),
dark: svg.replaceAll(CURRENT_COLOR, darkColor)
};
}
/**
* Shrink the inner SVG content uniformly and re-center it so `padding` (a
* 0..1 fraction) is reserved as equal margin on each side. Returns the input
* unchanged for non-positive padding, missing/invalid `viewBox`, or unexpected
* markup so the caller always gets a renderable SVG.
*/
export function padFaviconSvg(svg: string, padding: number): string {
if (!(padding > 0) || padding >= 1) return svg;
const viewBoxMatch = svg.match(/viewBox\s*=\s*["']([^"']+)["']/i);
if (!viewBoxMatch) return svg;
const parts = viewBoxMatch[1]
.trim()
.split(/[\s,]+/)
.map(Number);
if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) return svg;
const [, , width, height] = parts;
if (width <= 0 || height <= 0) return svg;
const scale = 1 - padding;
const translateX = (padding * width) / 2;
const translateY = (padding * height) / 2;
const openTagStart = svg.search(/<svg\b/i);
if (openTagStart === -1) return svg;
const openTagEnd = svg.indexOf('>', openTagStart);
if (openTagEnd === -1) return svg;
const closeStart = svg.lastIndexOf('</svg');
if (closeStart === -1 || closeStart <= openTagEnd) return svg;
const openTag = svg.slice(0, openTagEnd + 1);
const inner = svg.slice(openTagEnd + 1, closeStart);
const closeTag = svg.slice(closeStart);
const group = `<g transform="translate(${translateX} ${translateY}) scale(${scale})">`;
return `${openTag}${group}${inner}</g>${closeTag}`;
}
/**
* Read `src/lib/assets/logo.svg`, colorize it for both themes, and write
* the results to the static directory so the PWA asset generator can consume
* them. Paths can be overridden for tests.
*/
export function writeThemeFavicons(
lightColor: string,
darkColor: string,
{
sourcePath = DEFAULT_LOGO,
lightOutPath = DEFAULT_OUT_LIGHT,
darkOutPath = DEFAULT_OUT_DARK,
padding = 0
}: WriteThemeFaviconsOptions = {}
): void {
const source = readFileSync(sourcePath, 'utf-8');
const { light, dark } = colorizeFaviconSvg(source, lightColor, darkColor);
mkdirSync(dirname(lightOutPath), { recursive: true });
writeFileSync(lightOutPath, padFaviconSvg(light, padding));
writeFileSync(darkOutPath, padFaviconSvg(dark, padding));
}
+6 -1
View File
@@ -48,6 +48,7 @@
--chat-form-area-height: 8rem;
--chat-form-area-offset: 2rem;
--chat-form-padding-top: 6rem;
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
}
@@ -55,6 +56,7 @@
:root {
--chat-form-area-height: 24rem;
--chat-form-area-offset: 12rem;
--chat-form-padding-top: 6rem;
}
}
@@ -141,7 +143,6 @@
@apply bg-background text-foreground;
scrollbar-width: thin;
scrollbar-gutter: stable;
overflow: hidden; /* Added due to Mermaid rendering somehow causing the double scrollbar */
}
/* Global scrollbar styling - visible only on hover */
@@ -193,3 +194,7 @@
scrollbar-width: none;
}
}
.mermaidTooltip {
display: none !important;
}
@@ -10,9 +10,9 @@ import { isElementInViewport } from '$lib/utils/viewport';
*/
export function fadeInView(
node: HTMLElement,
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
options: { duration?: number; y?: number; delay?: number; skipIfVisible?: boolean } = {}
) {
const { duration = 300, y = 0, skipIfVisible = false } = options;
const { duration = 300, y = 0, delay = 0, skipIfVisible = false } = options;
if (skipIfVisible && isElementInViewport(node)) {
return;
@@ -27,10 +27,12 @@ export function fadeInView(
(entries) => {
for (const entry of entries) {
if (entry.isIntersecting) {
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
setTimeout(() => {
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
}, delay);
observer.disconnect();
}
}
+7
View File
@@ -0,0 +1,7 @@
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M244.95 8C215.233 8 187.774 23.8591 172.923 49.5999L95.6009 183.625C60.2162 244.959 104.481 321.6 175.29 321.6H208L316.977 132.708C348.959 77.2719 308.95 8 244.95 8ZM208 321.6H351.947C415.982 321.6 456.013 390.91 424.013 446.377C409.155 472.132 381.681 488 351.947 488H271.29C200.481 488 156.216 411.359 191.601 350.026L208 321.6Z" fill="currentColor"/>
<path d="M208 321.6H16L106.462 164.8L208 321.6Z" fill="currentColor"/>
<path d="M388.923 8L208 321.6L253.6 8H388.923Z" fill="currentColor"/>
<path d="M304 488H112L202.462 331.2L304 488Z" fill="currentColor"/>
<path d="M496 321.6H208L419.399 454.4L496 321.6Z" fill="currentColor"/>
</svg>

After

Width:  |  Height:  |  Size: 771 B

@@ -8,12 +8,13 @@
ariaLabel?: string;
class?: string;
disabled?: boolean;
href?: string;
icon: Component;
iconSize?: string;
onclick: (e?: MouseEvent) => void;
onclick?: (e?: MouseEvent) => void;
size?: ButtonSize;
stopPropagationOnClick?: boolean;
tooltip: string;
tooltip?: string;
variant?: ButtonVariant;
tooltipSide?: TooltipSide;
}
@@ -22,6 +23,7 @@
icon,
tooltip,
variant = 'ghost',
href = '',
size = 'sm',
class: className = '',
disabled = false,
@@ -31,34 +33,49 @@
onclick,
ariaLabel
}: Props = $props();
let innerWidth = $state(0);
const showTooltip = $derived(!!tooltip && innerWidth > 768);
</script>
<Tooltip.Root>
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
<Button
{...props}
{variant}
{size}
{disabled}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
{#snippet button(props = {})}
<Button
{...props}
{href}
{variant}
{size}
{disabled}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{#if icon}
{@const IconComponent = icon}
<IconComponent class={iconSize} />
{/if}
</Button>
{/snippet}
</Tooltip.Trigger>
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{#if icon}
{@const IconComponent = icon}
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
<IconComponent class={iconSize} />
{/if}
</Button>
{/snippet}
{#if showTooltip}
<Tooltip.Root>
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
{@render button(props)}
{/snippet}
</Tooltip.Trigger>
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
{@render button({ href })}
{/if}
<svelte:window bind:innerWidth />
@@ -494,7 +494,7 @@
/>
<div
class="{INPUT_CLASSES} overflow-hidden rounded-3xl backdrop-blur-md {disabled
class="{INPUT_CLASSES} overflow-hidden rounded-4xl md:rounded-3xl backdrop-blur-md {disabled
? 'cursor-not-allowed opacity-60'
: ''}"
data-slot="input-area"
@@ -510,7 +510,7 @@
/>
<div
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
class="flex-column relative min-h-12 items-center rounded-4xl md:rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:py-3!"
onpaste={handlePaste}
>
<ChatFormTextarea
@@ -15,7 +15,7 @@
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
class="file-upload-button md:h-8 md:w-8 h-9 w-9 rounded-full p-0"
{disabled}
{onclick}
variant="secondary"
@@ -15,6 +15,7 @@
import { McpLogo } from '$lib/components/app';
import { PencilRuler, ChevronDown, ChevronRight } from '@lucide/svelte';
import { HealthCheckStatus } from '$lib/enums';
import { AttachmentAction } from '$lib/enums/attachment.enums';
interface Props {
class?: string;
@@ -270,14 +271,22 @@
</Collapsible.Root>
{/if}
<button type="button" class={sheetItemClass} onclick={onSystemPromptClick}>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.SYSTEM_PROMPT_CLICK]()}
>
<MessageSquare class="h-4 w-4 shrink-0" />
<span>System Message</span>
</button>
{#if hasMcpPromptsSupport}
<button type="button" class={sheetItemClass} onclick={onMcpPromptClick}>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_PROMPT_CLICK]()}
>
<Zap class="h-4 w-4 shrink-0" />
<span>MCP Prompt</span>
@@ -285,7 +294,11 @@
{/if}
{#if hasMcpResourcesSupport}
<button type="button" class={sheetItemClass} onclick={onMcpResourcesClick}>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_RESOURCES_CLICK]()}
>
<FolderOpen class="h-4 w-4 shrink-0" />
<span>MCP Resources</span>
@@ -42,6 +42,7 @@
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onSystemPromptClick}
{onMcpPromptClick}
{onMcpResourcesClick}
>
@@ -20,7 +20,7 @@
type="submit"
disabled={isDisabled}
class={[
'h-8 w-8 rounded-full p-0',
'md:h-8 md:w-8 h-9 w-9 rounded-full p-0',
showErrorState &&
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
]}
@@ -1,4 +1,5 @@
<script lang="ts">
import { isMobile } from '$lib/stores/viewport.svelte';
import { autoResizeTextarea } from '$lib/utils';
import { onMount } from 'svelte';
@@ -37,7 +38,9 @@
}
export function focus() {
textareaElement?.focus();
if (isMobile.current) return;
textareaElement?.focus({ preventScroll: true });
}
export function resetHeight() {
@@ -231,7 +231,7 @@
editedContent = message.content;
}
textareaElement?.focus();
textareaElement?.focus({ preventScroll: true });
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];
@@ -324,7 +324,7 @@
}
</script>
<div use:fadeInView>
<div use:fadeInView class="chat-message">
{#if message.role === MessageRole.SYSTEM}
<ChatMessageSystem
bind:textareaElement
@@ -180,6 +180,9 @@
let displayedModel = $derived(message.model ?? null);
// model being switched to while it loads, so the selector bar tracks it
let pendingModel = $state<string | null>(null);
let isCurrentlyLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
let hasNoContent = $derived(!message?.content?.trim());
@@ -207,6 +210,42 @@
isLastAssistantMessage
);
let assistantEl: HTMLDivElement | undefined = $state();
let lastUserMessageHeight = $state(0);
let assistantMarginTop = $state(0);
$effect(() => {
if (!assistantEl) return;
assistantMarginTop = Math.round(parseFloat(getComputedStyle(assistantEl).marginTop));
const chatMessageEl = assistantEl.closest('.chat-message');
const previousChatMessage = chatMessageEl?.previousElementSibling;
const userMessageEl = previousChatMessage?.querySelector(
'.chat-message-user'
) as HTMLElement | null;
if (!userMessageEl) {
lastUserMessageHeight = 0;
return;
}
const updateHeight = () => {
const rect = userMessageEl.getBoundingClientRect();
const marginTop = Math.round(parseFloat(getComputedStyle(userMessageEl).marginTop));
lastUserMessageHeight = Math.round(rect.height + marginTop);
};
updateHeight();
const resizeObserver = new ResizeObserver(updateHeight);
resizeObserver.observe(userMessageEl);
return () => {
resizeObserver.disconnect();
};
});
function handleCopyModel() {
void copyToClipboard(displayedModel ?? '');
}
@@ -219,12 +258,17 @@
</script>
<div
class="text-md group w-full leading-7.5 {className}"
bind:this={assistantEl}
class="chat-message-assistant text-md group w-full leading-7.5 {className}"
style:--last-user-message-height={lastUserMessageHeight > 0
? `${lastUserMessageHeight}px`
: undefined}
style:--assistant-margin-top={assistantMarginTop > 0 ? `${assistantMarginTop}px` : undefined}
role="group"
aria-label="Assistant message with actions"
>
{#if showProcessingInfoTop}
<div class="mt-6 w-full max-w-[48rem]" in:fade>
<div class="mt-6 w-full max-w-3xl" in:fade>
<div class="processing-container">
<span class="processing-text">
{modelLoadingText ??
@@ -257,7 +301,7 @@
{/if}
{#if showProcessingInfoBottom}
<div class="mt-4 w-full max-w-[48rem]" in:fade>
<div class="mt-4 w-full max-w-3xl" in:fade>
<div class="processing-container">
<span class="processing-text">
{modelLoadingText ??
@@ -277,13 +321,19 @@
>
{#if isRouter}
<ModelsSelectorDropdown
currentModel={displayedModel}
currentModel={pendingModel ?? displayedModel}
disabled={isLoading()}
onModelChange={async (modelId: string, modelName: string) => {
const status = modelsStore.getModelStatus(modelId);
if (status !== ServerModelStatus.LOADED) {
await modelsStore.loadModel(modelId);
pendingModel = modelId;
try {
await modelsStore.loadModel(modelId);
} finally {
pendingModel = null;
}
}
onRegenerate(modelName);
@@ -351,6 +401,23 @@
</div>
<style>
:global(.chat-message):last-child .chat-message-assistant {
--assistant-min-height-offset: calc(
var(--last-user-message-height, 19rem) + var(--chat-form-height, 6rem) +
var(--chat-form-bottom-position, 0.5rem) + var(--chat-form-padding-top, 6rem) +
var(--assistant-margin-top, 3rem)
);
min-height: calc(100dvh - var(--assistant-min-height-offset));
@media (width > 768px) {
--assistant-min-height-offset: calc(
var(--last-user-message-height, 18rem) + var(--chat-form-height, 6rem) +
var(--chat-form-bottom-position, 1rem) + var(--chat-form-padding-top, 6rem) +
var(--assistant-margin-top, 3rem)
);
}
}
.processing-container {
display: flex;
flex-direction: column;
@@ -48,7 +48,7 @@
<div
aria-label="User message with actions"
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
class="chat-message-user group flex flex-col items-end gap-3 md:gap-2 {className}"
role="group"
>
{#if editCtx.isEditing}
@@ -19,7 +19,7 @@
renderMarkdown = false,
textColorClass = 'text-foreground',
cardBgClass = 'dark:bg-primary/15',
maxHeightStyle = 'max-height: var(--max-message-height);'
maxHeightStyle = ''
}: Props = $props();
let isMultiline = $state(false);
@@ -59,7 +59,7 @@
{#if content.trim()}
<Card
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-[multiline]:py-2.5 {cardBgClass}"
class="chat-message-user-bubble max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-multiline:py-2.5 {cardBgClass}"
data-multiline={isMultiline ? '' : undefined}
style="{maxHeightStyle} overflow-wrap: anywhere; word-break: break-word;"
>
@@ -37,6 +37,7 @@
let allConversationMessages = $state<DatabaseMessage[]>([]);
let isVisible = $state(false);
let previousConversationId = $state<string | null>(null);
let previousRouteId = $state<string | null>(null);
const currentConfig = config();
@@ -157,8 +158,9 @@
});
});
beforeNavigate(() => {
beforeNavigate((navigation) => {
isVisible = false;
previousRouteId = navigation.from?.route.id ?? null;
});
afterNavigate(() => {
@@ -249,12 +251,13 @@
</script>
<div
class="transition-opacity delay-300 duration-500 ease-out
{isVisible ? 'opacity-100' : 'opacity-0'}"
class="transition-opacity duration-500 ease-out
{isVisible ? 'opacity-100' : 'opacity-0'}
{previousRouteId === '/(chat)/chat/[id]' ? '' : 'delay-300'}"
>
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
<ChatMessage
class="mx-auto mt-12 w-full max-w-[48rem]"
class="mx-auto mt-12 w-full max-w-3xl"
{message}
{toolMessages}
{isLastAssistantMessage}
@@ -1,31 +1,28 @@
<script lang="ts">
import { Trash2 } from '@lucide/svelte';
import { afterNavigate } from '$app/navigation';
import { page } from '$app/state';
import {
ChatScreenForm,
ChatMessages,
ChatScreenDragOverlay,
ChatScreenProcessingInfo,
ChatScreenActionScrollDown,
DialogEmptyFileAlert,
DialogFileUploadError,
DialogChatError,
ServerLoadingSplash,
DialogConfirmation,
ChatScreenServerError
} from '$lib/components/app';
import { setProcessingInfoContext } from '$lib/contexts';
import { ErrorDialogType } from '$lib/enums';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import { useChatScreenActiveModel } from '$lib/hooks/use-chat-screen-active-model.svelte';
import { useChatScreenDragAndDrop } from '$lib/hooks/use-chat-screen-drag-and-drop.svelte';
import { useChatScreenFileUpload } from '$lib/hooks/use-chat-screen-file-upload.svelte';
import { useChatScreenScroll } from '$lib/hooks/use-chat-screen-scroll.svelte';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { device } from '$lib/stores/device.svelte';
import { isMobile } from '$lib/stores/viewport.svelte';
import {
chatStore,
errorDialog,
isLoading,
isChatStreaming,
isEditing,
getAddFilesHandler,
activeProcessingState
} from '$lib/stores/chat.svelte';
import {
@@ -34,138 +31,81 @@
activeConversation
} from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { serverLoading, serverError, isRouterMode } from '$lib/stores/server.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
import { onMount } from 'svelte';
import { serverLoading, serverError } from '$lib/stores/server.svelte';
import { parseFilesToMessageExtras } from '$lib/utils/browser-only';
import { onDestroy, onMount } from 'svelte';
import ChatScreenGreeting from './ChatScreenGreeting.svelte';
import ChatScreenActionScrollDown from './ChatScreenActionScrollDown.svelte';
import ChatScreenDialogsAndAlerts from './ChatScreenDialogsAndAlerts.svelte';
import { ROUTES } from '$lib/constants';
let { showCenteredEmpty = false } = $props();
const autoScroll = createAutoScrollController();
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll));
let chatScrollContainer: HTMLDivElement | undefined = $state();
let dragCounter = $state(0);
let isDragOver = $state(false);
let showFileErrorDialog = $state(false);
let uploadedFiles = $state<ChatUploadedFile[]>([]);
let fileErrorData = $state<{
generallyUnsupported: File[];
modalityUnsupported: File[];
modalityReasons: Record<string, string>;
supportedTypes: string[];
}>({
generallyUnsupported: [],
modalityUnsupported: [],
modalityReasons: {},
supportedTypes: []
});
let showDeleteDialog = $state(false);
let showEmptyFileDialog = $state(false);
let processingInfoVisible = $state(false);
let emptyFileNames = $state<string[]>([]);
let initialMessage = $state('');
let isEmpty = $derived(
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
let hasPropsError = $derived(!!serverError());
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
let showProcessingInfo = $derived(
isCurrentConversationLoading ||
(config().keepStatsVisible && !!page.params.id) ||
activeProcessingState() !== null
);
let isRouter = $derived(isRouterMode());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let activeModelId = $derived.by(() => {
const options = modelOptions();
if (!isRouter) {
return options.length > 0 ? options[0].model : null;
}
const selectedId = selectedModelId();
if (selectedId) {
const model = options.find((m) => m.id === selectedId);
if (model) return model.model;
}
if (conversationModel) {
const model = options.find((m) => m.model === conversationModel);
if (model) return model.model;
}
return null;
});
let modelPropsVersion = $state(0);
setProcessingInfoContext({
get showProcessingInfo() {
return showProcessingInfo;
}
});
$effect(() => {
if (activeModelId) {
const cached = modelsStore.getModelProps(activeModelId);
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll) || isMobile.current);
let isMobileUserScrolledUp = $state(false);
let mobileScrollDownHint = $state(false);
let mobileScrollDownHintLockedUntil = $state(0);
let emptyFileNames = $state<string[]>([]);
let initialMessage = $state('');
let showDeleteDialog = $state(false);
let showEmptyFileDialog = $state(false);
let isEmpty = $derived(
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
let hasPropsError = $derived(!!serverError());
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
let showProcessingInfo = $derived(
isCurrentConversationLoading ||
(config().keepStatsVisible && !!page.params.id) ||
activeProcessingState() !== null
);
let chatFormBottomPosition = $derived.by(() => {
if (!isMobile.current) return '1rem';
if (device.isStandalone) return '1.5rem';
if (device.isIOSSafari) return '0.25rem';
return '0.5rem';
});
if (!cached) {
modelsStore.fetchModelProps(activeModelId).then(() => {
modelPropsVersion++;
});
const autoScroll = createAutoScrollController();
const scroll = useChatScreenScroll(autoScroll);
const activeModel = useChatScreenActiveModel();
const fileUpload = useChatScreenFileUpload({
capabilities: () => ({
hasVision: activeModel.hasVisionModality,
hasAudio: activeModel.hasAudioModality,
hasVideo: activeModel.hasVideoModality
}),
activeModelId: () => activeModel.activeModelId
});
const dragAndDrop = useChatScreenDragAndDrop({
onDrop: fileUpload.handleFileUpload
});
const { handleKeydown } = useKeyboardShortcuts({
deleteActiveConversation: () => {
if (activeConversation()) {
showDeleteDialog = true;
}
}
});
let hasAudioModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
function handleMobileScroll() {
if (!isMobile.current) return;
return modelsStore.modelSupportsAudio(activeModelId);
}
const container = scroll.chatScrollContainer;
if (!container) return;
return false;
});
let hasVideoModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVideo(activeModelId);
}
return false;
});
let hasVisionModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVision(activeModelId);
}
return false;
});
const distanceFromBottom =
container.scrollHeight - container.clientHeight - container.scrollTop;
isMobileUserScrolledUp = distanceFromBottom > 300;
}
async function handleDeleteConfirm() {
const conversation = activeConversation();
@@ -177,27 +117,69 @@
showDeleteDialog = false;
}
function handleProcessingInfoVisibility(visible: boolean) {
processingInfoVisible = visible;
}
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
const plainFiles = files ? $state.snapshot(files) : undefined;
const result = plainFiles
? await parseFilesToMessageExtras(plainFiles, activeModel.activeModelId ?? undefined)
: undefined;
function handleDragEnter(event: DragEvent) {
event.preventDefault();
dragCounter++;
if (event.dataTransfer?.types.includes('Files')) {
isDragOver = true;
if (result?.emptyFiles && result.emptyFiles.length > 0) {
emptyFileNames = result.emptyFiles;
showEmptyFileDialog = true;
if (files) {
const emptyFileNamesSet = new Set(result.emptyFiles);
fileUpload.uploadedFiles = fileUpload.uploadedFiles.filter(
(file) => !emptyFileNamesSet.has(file.name)
);
}
return false;
}
handleSendLikeScroll();
await chatStore.sendMessage(message, result?.extras);
return true;
}
function handleDragLeave(event: DragEvent) {
event.preventDefault();
function handleSendLikeScroll() {
if (!isMobile.current) {
autoScroll.enable();
}
dragCounter--;
setTimeout(() => {
const container = scroll.chatScrollContainer;
if (!container) return;
if (dragCounter === 0) {
isDragOver = false;
const lastUserBubble = container.querySelector(
'.chat-message:nth-last-child(2) .chat-message-user .chat-message-user-bubble'
) as HTMLElement | null;
if (isMobile.current) {
// Keep the last user message bubble just above the input on mobile
const bubbleHeight = lastUserBubble?.scrollHeight ?? 0;
const baseHeight = container.scrollHeight - innerHeight;
container.scrollTo({
top: bubbleHeight > 0 ? baseHeight - bubbleHeight : baseHeight,
behavior: 'smooth'
});
} else if (lastUserBubble) {
// On desktop, place the last user message near the top of the viewport
const topPadding = 24;
const bubbleRect = lastUserBubble.getBoundingClientRect();
container.scrollTo({
top: Math.max(0, container.scrollTop + bubbleRect.top - topPadding),
behavior: 'smooth'
});
} else {
autoScroll.scrollToBottom();
}
}, 100);
if (isMobile.current) {
autoScroll.setDisabled(disableAutoScroll);
mobileScrollDownHint = true;
mobileScrollDownHintLockedUntil = Date.now() + 500;
}
}
@@ -207,273 +189,138 @@
}
}
function handleDragOver(event: DragEvent) {
event.preventDefault();
}
function handleDrop(event: DragEvent) {
event.preventDefault();
isDragOver = false;
dragCounter = 0;
if (event.dataTransfer?.files) {
const files = Array.from(event.dataTransfer.files);
if (isEditing()) {
const handler = getAddFilesHandler();
if (handler) {
handler(files);
return;
}
}
processFiles(files);
}
}
function handleFileRemove(fileId: string) {
uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);
}
function handleFileUpload(files: File[]) {
processFiles(files);
}
const { handleKeydown } = useKeyboardShortcuts({
deleteActiveConversation: () => {
if (activeConversation()) {
showDeleteDialog = true;
}
}
});
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
if (draft.message || draft.files.length > 0) {
chatStore.savePendingDraft(draft.message, draft.files);
}
await chatStore.addSystemPrompt();
}
function handleScroll() {
autoScroll.handleScroll();
}
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
const plainFiles = files ? $state.snapshot(files) : undefined;
const result = plainFiles
? await parseFilesToMessageExtras(plainFiles, activeModelId ?? undefined)
: undefined;
if (result?.emptyFiles && result.emptyFiles.length > 0) {
emptyFileNames = result.emptyFiles;
showEmptyFileDialog = true;
if (files) {
const emptyFileNamesSet = new Set(result.emptyFiles);
uploadedFiles = uploadedFiles.filter((file) => !emptyFileNamesSet.has(file.name));
}
return false;
}
const extras = result?.extras;
// Enable autoscroll for user-initiated message sending
autoScroll.enable();
await chatStore.sendMessage(message, extras);
autoScroll.scrollToBottom();
return true;
}
async function processFiles(files: File[]) {
const generallySupported: File[] = [];
const generallyUnsupported: File[] = [];
for (const file of files) {
if (isFileTypeSupported(file.name, file.type)) {
generallySupported.push(file);
} else {
generallyUnsupported.push(file);
}
}
// Use model-specific capabilities for file validation
const capabilities = {
hasVision: hasVisionModality,
hasAudio: hasAudioModality,
hasVideo: hasVideoModality
};
const { supportedFiles, unsupportedFiles, modalityReasons } = filterFilesByModalities(
generallySupported,
capabilities
);
const allUnsupportedFiles = [...generallyUnsupported, ...unsupportedFiles];
if (allUnsupportedFiles.length > 0) {
const supportedTypes: string[] = ['text files', 'PDFs'];
if (hasVisionModality) supportedTypes.push('images');
if (hasAudioModality) supportedTypes.push('audio files');
if (hasVideoModality) supportedTypes.push('video files');
fileErrorData = {
generallyUnsupported,
modalityUnsupported: unsupportedFiles,
modalityReasons,
supportedTypes
};
showFileErrorDialog = true;
}
if (supportedFiles.length > 0) {
const processed = await processFilesToChatUploaded(
supportedFiles,
activeModelId ?? undefined
);
uploadedFiles = [...uploadedFiles, ...processed];
}
}
afterNavigate(() => {
if (!disableAutoScroll) {
$effect(() => {
const shouldDisableAutoScroll =
config().disableAutoScroll || (isMobile.current && isCurrentConversationLoading);
autoScroll.setDisabled(shouldDisableAutoScroll);
if (!shouldDisableAutoScroll) {
autoScroll.enable();
}
});
function handleMessagesReady() {
if (disableAutoScroll) return;
if (!autoScroll.userScrolledUp) {
requestAnimationFrame(() => {
autoScroll.scrollToBottom('instant');
});
}
}
onMount(() => {
const pendingDraft = chatStore.consumePendingDraft();
if (pendingDraft) {
initialMessage = pendingDraft.message;
fileUpload.uploadedFiles = pendingDraft.files;
}
autoScroll.startObserving();
if (!disableAutoScroll) {
autoScroll.enable();
}
const pendingDraft = chatStore.consumePendingDraft();
if (pendingDraft) {
initialMessage = pendingDraft.message;
uploadedFiles = pendingDraft.files;
if (isMobile.current && isCurrentConversationLoading) {
mobileScrollDownHint = true;
mobileScrollDownHintLockedUntil = Date.now() + 500;
}
handleMobileScroll();
});
$effect(() => {
autoScroll.setContainer(chatScrollContainer);
});
$effect(() => {
autoScroll.setDisabled(disableAutoScroll);
});
onDestroy(() => autoScroll.destroy());
</script>
{#if isDragOver}
{#if dragAndDrop.isDragOver}
<ChatScreenDragOverlay />
{/if}
<svelte:window onkeydown={handleKeydown} />
<svelte:window
onkeydown={handleKeydown}
onscroll={(e) => {
scroll.handleScroll(e);
handleMobileScroll();
if (e.isTrusted && Date.now() > mobileScrollDownHintLockedUntil) {
mobileScrollDownHint = false;
}
}}
/>
{#if isServerLoading}
<ServerLoadingSplash />
{:else}
<div
bind:this={chatScrollContainer}
aria-label="Chat interface with file drop zone"
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
ondragenter={handleDragEnter}
ondragleave={handleDragLeave}
ondragover={handleDragOver}
ondrop={handleDrop}
onscroll={handleScroll}
class="chat-screen flex grow flex-col min-h-[calc(100dvh-1rem)] md:min-h-full px-4 md:py-0 pt-12 pb-48 md:pb-4"
style:--chat-form-bottom-position={chatFormBottomPosition}
ondragenter={dragAndDrop.dragHandlers.dragenter}
ondragleave={dragAndDrop.dragHandlers.dragleave}
ondragover={dragAndDrop.dragHandlers.dragover}
ondrop={dragAndDrop.dragHandlers.drop}
role="main"
>
<div class="flex grow flex-col pt-14">
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onMessagesReady={handleMessagesReady}
onUserAction={() => {
autoScroll.enable();
if (!autoScroll.userScrolledUp) {
autoScroll.scrollToBottom();
}
}}
/>
{/if}
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onUserAction={() => {
handleSendLikeScroll();
}}
/>
{/if}
<div
class={[
'pointer-events-none sticky right-4 left-4 mt-auto transition-all duration-200',
isEmpty ? 'bottom-[calc(50dvh-7rem)]' : 'bottom-4 pt-24 md:pt-32'
]}
>
<ChatScreenGreeting {isEmpty} />
<div
class={[
'pointer-events-none md:sticky fixed mt-auto transition-all duration-200',
device.isStandalone
? 'bottom-6 right-4 left-4'
: device.isIOSSafari
? 'bottom-1 left-2 right-2'
: 'bottom-2 right-2 left-2',
isEmpty ? 'md:bottom-[calc(50dvh-7rem)] 2xl:bottom-[calc(50dvh-4rem)]' : 'md:bottom-4'
]}
style:padding-top={!isEmpty ? 'var(--chat-form-padding-top)' : undefined}
>
<ChatScreenGreeting {isEmpty} />
<ChatScreenActionScrollDown
container={chatScrollContainer}
hasProcessingInfoVisible={processingInfoVisible}
/>
<ChatScreenServerError />
<ChatScreenProcessingInfo onVisibilityChange={handleProcessingInfoVisibility} />
<ChatScreenServerError />
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
<ChatScreenForm
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
bind:uploadedFiles
<div class="pointer-events-none flex flex-col gap-6 items-center w-full">
{#if (isMobile.current ? mobileScrollDownHint || isMobileUserScrolledUp : autoScroll.userScrolledUp) && page.url.hash.includes(ROUTES.CHAT) && page.params.id}
<ChatScreenActionScrollDown
onclick={() => {
mobileScrollDownHint = false;
scroll.chatScrollContainer?.scrollTo({
top: scroll.chatScrollContainer.scrollHeight,
behavior: 'smooth'
});
}}
/>
</div>
{/if}
{#if showProcessingInfo}
<ChatScreenProcessingInfo />
{/if}
</div>
<ChatScreenForm
class="pointer-events-auto conversation-chat-form"
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={fileUpload.handleFileRemove}
onFileUpload={fileUpload.handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
bind:uploadedFiles={fileUpload.uploadedFiles}
/>
</div>
</div>
{/if}
<DialogFileUploadError bind:open={showFileErrorDialog} {fileErrorData} />
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleDeleteConfirm}
onCancel={() => (showDeleteDialog = false)}
/>
<DialogEmptyFileAlert
bind:open={showEmptyFileDialog}
emptyFiles={emptyFileNames}
onOpenChange={(open) => {
if (!open) {
emptyFileNames = [];
}
}}
/>
<DialogChatError
message={activeErrorDialog?.message ?? ''}
contextInfo={activeErrorDialog?.contextInfo}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
<ChatScreenDialogsAndAlerts
{showDeleteDialog}
{handleDeleteConfirm}
{showEmptyFileDialog}
{emptyFileNames}
{activeErrorDialog}
{handleErrorDialogOpenChange}
{fileUpload}
/>
@@ -1,58 +1,18 @@
<script lang="ts">
import { ArrowDown } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import ActionIcon from '$lib/components/app/actions/ActionIcon.svelte';
interface Props {
container: HTMLDivElement | undefined;
hasProcessingInfoVisible: boolean;
}
let { container, hasProcessingInfoVisible }: Props = $props();
let show = $state(false);
let buttonBottom = $derived(hasProcessingInfoVisible ? '2rem' : '0');
function checkVisibility() {
if (!container) return;
const { scrollTop, scrollHeight, clientHeight } = container;
const distanceFromBottom = scrollHeight - clientHeight - scrollTop;
show = distanceFromBottom > clientHeight * 0.5;
}
function scrollToBottom() {
if (container) {
container.scrollTo({
top: container.scrollHeight,
behavior: 'smooth'
});
}
}
$effect(() => {
const c = container;
if (c) {
c.addEventListener('scroll', checkVisibility);
checkVisibility();
return () => {
c.removeEventListener('scroll', checkVisibility);
};
}
});
let { onclick }: { onclick: (e?: MouseEvent) => void } = $props();
</script>
<div class="relative z-50 mx-auto mb-4 flex max-w-[48rem] justify-center">
<Button
onclick={scrollToBottom}
variant="secondary"
size="icon"
disabled={!show}
class="pointer-events-auto absolute h-10 w-10 rounded-full bg-background/80 shadow-lg backdrop-blur-sm transition-all duration-200 hover:bg-muted/80"
style="bottom: {buttonBottom}; transform: translateY({show ? '0' : '2rem'}); opacity: {show
? 1
: 0};"
aria-label="Scroll to bottom"
>
<ArrowDown class="h-4 w-4" />
</Button>
<div class="pointer-events-auto flex justify-center relative h-0">
<ActionIcon
icon={ArrowDown}
{onclick}
ariaLabel="Scroll to bottom"
tooltip="Scroll to bottom"
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full bg-accent text-accent-foreground absolute bottom-4 shadow-md"
/>
</div>
@@ -0,0 +1,55 @@
<script lang="ts">
import { Trash2 } from '@lucide/svelte';
import { ErrorDialogType } from '$lib/enums';
import {
DialogChatError,
DialogConfirmation,
DialogEmptyFileAlert,
DialogFileUploadError
} from '$lib/components/app';
let {
showDeleteDialog,
handleDeleteConfirm,
showEmptyFileDialog,
emptyFileNames,
activeErrorDialog,
handleErrorDialogOpenChange,
fileUpload
} = $props();
</script>
<DialogFileUploadError
bind:open={fileUpload.showFileErrorDialog}
fileErrorData={fileUpload.fileErrorData}
/>
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleDeleteConfirm}
onCancel={() => (showDeleteDialog = false)}
/>
<DialogEmptyFileAlert
bind:open={showEmptyFileDialog}
emptyFiles={emptyFileNames}
onOpenChange={(open) => {
if (!open) {
emptyFileNames = [];
}
}}
/>
<DialogChatError
message={activeErrorDialog?.message ?? ''}
contextInfo={activeErrorDialog?.contextInfo}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
/>
@@ -2,6 +2,7 @@
import { afterNavigate } from '$app/navigation';
import { page } from '$app/state';
import { ChatForm } from '$lib/components/app';
import { isMobile } from '$lib/stores/viewport.svelte';
import { onMount } from 'svelte';
import { useDraftMessages } from '$lib/hooks/use-draft-messages.svelte';
@@ -32,7 +33,30 @@
}: Props = $props();
let chatFormRef: ChatForm | undefined = $state(undefined);
let formWrapperEl: HTMLDivElement | undefined = $state();
let chatId = $derived(page.params.id as string | undefined);
$effect(() => {
if (!formWrapperEl) return;
const formEl = formWrapperEl.querySelector('form') as HTMLElement | null;
if (!formEl) return;
const updateHeight = () => {
const height = Math.round(formEl.getBoundingClientRect().height);
document.documentElement.style.setProperty('--chat-form-height', `${height}px`);
};
updateHeight();
const resizeObserver = new ResizeObserver(updateHeight);
resizeObserver.observe(formEl);
return () => {
resizeObserver.disconnect();
document.documentElement.style.removeProperty('--chat-form-height');
};
});
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
let message = $derived(initialMessage);
let previousIsLoading = $derived(isLoading);
@@ -83,12 +107,14 @@
}
onMount(() => {
setTimeout(() => chatFormRef?.focus(), 10);
if (!isMobile.current) {
setTimeout(() => chatFormRef?.focus(), 100);
}
});
afterNavigate((navigation) => {
if (navigation?.from != null) {
setTimeout(() => chatFormRef?.focus(), 10);
if (navigation?.from != null && !isMobile.current) {
setTimeout(() => chatFormRef?.focus(), 100);
}
});
@@ -108,12 +134,12 @@
});
</script>
<div class="relative mx-auto max-w-[48rem]">
<div class="chat-screen-form-wrapper" bind:this={formWrapperEl}>
<ChatForm
class="mx-auto max-w-3xl {className}"
bind:this={chatFormRef}
bind:value={message}
bind:uploadedFiles
class={className}
{disabled}
{isLoading}
showMcpPromptButton
@@ -1,5 +1,4 @@
<script lang="ts">
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { serverStore } from '$lib/stores/server.svelte';
interface Props {
@@ -11,10 +10,9 @@
<div
class={[
'pointer-events-none mb-4 hidden px-4 text-center',
isEmpty && 'pointer-events-auto block!'
'pointer-events-none mb-4 hidden px-4 text-center text-balance',
isEmpty && 'mb-[calc(50dvh-8rem)] md:mb-6 pointer-events-auto block!'
]}
use:fadeInView={{ duration: 300 }}
>
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
@@ -5,13 +5,8 @@
import { chatStore, isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { getProcessingInfoContext } from '$lib/contexts';
import { page } from '$app/state';
const processingState = useProcessingState();
const processingInfoCtx = getProcessingInfoContext();
let showProcessingInfo = $derived(processingInfoCtx.showProcessingInfo);
let isCurrentConversationLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
@@ -70,8 +65,8 @@
<div
class={[
'chat-processing-info-container pointer-events-none relative',
page.params.id && showProcessingInfo && 'visible'
'chat-processing-info-container pointer-events-none relative w-full hidden md:block',
processingVisible && 'visible'
]}
>
<div class="chat-processing-info-content absolute bottom-4 left-1/2 -translate-x-1/2">
@@ -677,13 +677,6 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
*/
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
/**
* Scroll-to-bottom action button. Displays a floating button when the user
* has scrolled up more than half a viewport height from the bottom.
* Takes the chat container element as a prop to manage scroll state internally.
*/
export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte';
/**
* Server error alert displayed when the server is unreachable.
* Shows the error message with a retry button.
@@ -3,6 +3,7 @@
import { Search, X } from '@lucide/svelte';
interface Props {
autofocus?: boolean;
value?: string;
placeholder?: string;
onInput?: (value: string) => void;
@@ -15,6 +16,7 @@
}
let {
autofocus,
value = $bindable(''),
placeholder = 'Search...',
onInput,
@@ -39,7 +41,7 @@
if (value) {
value = '';
onInput?.('');
ref?.focus();
ref?.focus({ preventScroll: true });
} else {
onClose?.();
}
@@ -52,6 +54,7 @@
/>
<Input
{autofocus}
{id}
bind:value
bind:ref
@@ -0,0 +1,15 @@
<script>
import logoMark from '$lib/assets/logo.svg?raw';
let { class: className = '', style = '' } = $props();
</script>
<div class={className} {style}>
{@html logoMark}
</div>
<style>
div :global(svg) {
width: var(--size, 1rem);
height: var(--size, 1rem);
}
</style>
@@ -51,3 +51,11 @@ export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';
* Preview button is shown only for HTML code blocks.
*/
export { default as CodeBlockActions } from './CodeBlockActions.svelte';
/**
* **Logo** - Application brand mark
*
* Inline SVG of the application logo. Accepts styling via the standard
* `class` and `style` props and inherits color via `currentColor`.
*/
export { default as Logo } from './Logo.svelte';
@@ -0,0 +1,11 @@
<script lang="ts">
let { percent }: { percent: number } = $props();
</script>
<!-- thin determinate load bar pinned to the bottom edge, pulsing while it fills -->
<div class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm">
<div
class="h-full animate-pulse bg-primary transition-[width] duration-200 ease-out"
style="width: {percent}%"
></div>
</div>
@@ -2,8 +2,10 @@
import { ChevronDown, Loader2, Package } from '@lucide/svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import { KeyboardKey } from '$lib/enums';
import { KeyboardKey, ServerModelStatus } from '$lib/enums';
import { useModelsSelector } from '$lib/hooks/use-models-selector.svelte';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
import { modelLoadFraction } from '$lib/utils';
import {
DialogModelInformation,
DropdownMenuSearchable,
@@ -11,6 +13,7 @@
ModelsSelectorList,
ModelsSelectorOption
} from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import type { ModelItem } from './utils';
interface Props {
@@ -113,6 +116,17 @@
{/if}
{:else}
{@const selectedOption = ms.getDisplayOption()}
{@const triggerModel = selectedOption?.model}
{@const triggerStatus = triggerModel
? routerModels().find((m) => m.id === triggerModel)?.status?.value
: undefined}
{@const triggerLoading =
!!triggerModel &&
(triggerStatus === ServerModelStatus.LOADING ||
modelsStore.isModelOperationInProgress(triggerModel))}
{@const triggerLoadPercent = triggerLoading
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
: 0}
{#if ms.isRouter}
<DropdownMenu.Root bind:open={isOpen} onOpenChange={ms.handleOpenChange}>
@@ -123,7 +137,7 @@
<DropdownMenu.Trigger
{...props}
class={[
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
`relative inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
!ms.isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
@@ -154,6 +168,10 @@
{:else}
<ChevronDown class="h-3 w-3.5 shrink-0" />
{/if}
{#if triggerLoading}
<ModelLoadHighlight percent={triggerLoadPercent} />
{/if}
</DropdownMenu.Trigger>
{/snippet}
</Tooltip.Trigger>
@@ -10,6 +10,7 @@
RotateCw
} from '@lucide/svelte';
import { ActionIcon, ModelId } from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import type { ModelOption } from '$lib/types/models';
import { ServerModelStatus } from '$lib/enums';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
@@ -119,11 +120,11 @@
</div>
{#if isLoading}
<div class="flex w-4 [@media(pointer:coarse)]:w-5 items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-5">
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
</div>
{:else if isFailed}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<CircleAlert
class="h-3.5 w-3.5 text-red-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
/>
@@ -140,7 +141,7 @@
</div>
</div>
{:else if isSleeping}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<span
class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -159,7 +160,7 @@
</div>
</div>
{:else if isLoaded}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<span
class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -176,7 +177,7 @@
</div>
</div>
{:else}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<span
class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -196,13 +197,6 @@
</div>
{#if isLoading}
<div
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
>
<div
class="h-full bg-primary transition-[width] duration-200 ease-out"
style="width: {loadPercent}%"
></div>
</div>
<ModelLoadHighlight percent={loadPercent} />
{/if}
</div>
@@ -8,6 +8,10 @@
ModelsSelectorList,
SearchInput
} from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import { ServerModelStatus } from '$lib/enums';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
import { modelLoadFraction } from '$lib/utils';
interface Props {
class?: string;
@@ -61,12 +65,23 @@
<p class="text-xs text-muted-foreground">No models available.</p>
{:else}
{@const selectedOption = ms.getDisplayOption()}
{@const triggerModel = selectedOption?.model}
{@const triggerStatus = triggerModel
? routerModels().find((m) => m.id === triggerModel)?.status?.value
: undefined}
{@const triggerLoading =
!!triggerModel &&
(triggerStatus === ServerModelStatus.LOADING ||
modelsStore.isModelOperationInProgress(triggerModel))}
{@const triggerLoadPercent = triggerLoading
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
: 0}
{#if ms.isRouter}
<button
type="button"
class={[
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 max-sm:px-3 max-sm:py-2 text-xs max-sm:text-sm shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
`relative inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 max-sm:px-3 max-sm:py-2 max-sm:text-sm dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
!ms.isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
@@ -99,6 +114,10 @@
{:else}
<ChevronDown class="h-3 w-3.5 shrink-0" />
{/if}
{#if triggerLoading}
<ModelLoadHighlight percent={triggerLoadPercent} />
{/if}
</button>
<Sheet.Root bind:open={sheetOpen} onOpenChange={handleSheetOpenChange}>
@@ -1,84 +0,0 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { ActionIcon } from '$lib/components/app';
import {
ICON_STRIP_TRANSITION_DURATION,
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
SIDEBAR_ACTIONS_ITEMS
} from '$lib/constants';
import { TooltipSide } from '$lib/enums';
import { fade } from 'svelte/transition';
import { circIn } from 'svelte/easing';
import { onMount } from 'svelte';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
interface Props {
sidebarOpen: boolean;
onSearchClick: () => void;
}
let { sidebarOpen = false, onSearchClick }: Props = $props();
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
let initialized = $state(false);
let showIcons = $derived(!sidebarOpen);
showIcons = false;
onMount(() => {
showIcons = !sidebarOpen;
setTimeout(() => {
initialized = true;
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
});
</script>
<svelte:window onkeydown={handleKeydown} />
<div
class="hidden shrink-0 transition-[width] duration-200 ease-linear md:block {sidebarOpen
? 'w-0'
: 'w-[calc(var(--sidebar-width-icon)+1.5rem)]'}"
></div>
<aside
class="fixed top-0 bottom-0 left-0 z-10 hidden w-[calc(var(--sidebar-width-icon)+1.5rem)] flex-col items-center justify-between py-3 transition-opacity duration-200 ease-linear md:flex {sidebarOpen
? 'pointer-events-none opacity-0'
: 'opacity-100'}"
>
<div class="mt-12 flex flex-col items-center gap-1">
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const onclick = item.route ? () => goto(item.route!) : onSearchClick}
{@const isActive = item.activeRouteId
? page.route.id === item.activeRouteId
: item.activeRoutePrefix
? !!page.route.id?.startsWith(item.activeRoutePrefix)
: false}
{#if showIcons}
<div
in:fade={{
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
>
<ActionIcon
icon={item.icon}
tooltip={item.tooltip}
tooltipSide={TooltipSide.RIGHT}
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
? 'bg-accent text-accent-foreground'
: ''}"
{onclick}
/>
</div>
{/if}
{/each}
</div>
</aside>
@@ -1,40 +1,67 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { Trash2, Pencil, Pin, X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { DialogConfirmation } from '$lib/components/app';
import SidebarNavigationActions from './SidebarNavigationActions.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
import { Checkbox } from '$lib/components/ui/checkbox';
import Label from '$lib/components/ui/label/label.svelte';
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
import * as Sidebar from '$lib/components/ui/sidebar';
import Input from '$lib/components/ui/input/input.svelte';
import { ROUTES } from '$lib/constants/routes';
import { RouterService } from '$lib/services/router.service';
import { PanelLeftClose, PanelLeftOpen, X } from '@lucide/svelte';
import {
conversationsStore,
conversations,
buildConversationTree
} from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { getPreviewText } from '$lib/utils';
import { APP_NAME } from '$lib/constants';
ActionIcon,
Logo,
SidebarNavigationConversationList,
SidebarNavigationActions
} from '$lib/components/app';
import { ROUTES } from '$lib/constants';
import { fade } from 'svelte/transition';
const sidebar = Sidebar.useSidebar();
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { RouterService } from '$lib/services/router.service';
import { isMobile } from '$lib/stores/viewport.svelte';
import { TooltipSide } from '$lib/enums';
import { device } from '$lib/stores/device.svelte';
import { circIn } from 'svelte/easing';
interface Props {
onSearchClick?: () => void;
}
let { onSearchClick = () => {} }: Props = $props();
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
let isExpandedMode = $state(false);
let hoveredTooltip = $state<string | null>(null);
let logoHovered = $state(false);
const isStripExpanded = $derived(isExpandedMode || hoveredTooltip !== null);
const isOnMobile = $derived(isMobile.current);
function toggleExpandedMode() {
isExpandedMode = !isExpandedMode;
if (!isExpandedMode) {
hoveredTooltip = null;
}
}
$effect(() => {
if (!isExpandedMode) {
isSearchModeActive = false;
searchQuery = '';
cancelMobileCollapse();
}
});
// On mobile the dedicated /search route hides the sidebar (see the aside
// render guard below). Collapse it as we enter /search so it doesn't
// reappear expanded when the user navigates back via the back button.
$effect(() => {
if (isMobile.current && page.url.hash.includes(ROUTES.SEARCH)) {
isExpandedMode = false;
}
});
let currentChatId = $derived(page.params.id);
let isSearchModeActive = $state(false);
let searchQuery = $state('');
let showDeleteDialog = $state(false);
let deleteWithForks = $state(false);
let showEditDialog = $state(false);
let selectedConversation = $state<DatabaseConversation | null>(null);
let editedName = $state('');
let selectedConversationNamePreview = $derived.by(() =>
selectedConversation ? getPreviewText(selectedConversation.name) : ''
);
let filteredConversations = $derived.by(() => {
if (isSearchModeActive) {
@@ -50,294 +77,206 @@
return conversations();
});
let conversationTree = $derived(buildConversationTree(filteredConversations));
let pinnedConversations = $derived.by(() => {
return conversationTree.filter(({ conversation }) => conversation.pinned);
});
let unpinnedConversations = $derived.by(() => {
return conversationTree.filter(({ conversation }) => !conversation.pinned);
});
let selectedConversationHasDescendants = $derived.by(() => {
if (!selectedConversation) return false;
const allConvs = conversations();
const queue = [selectedConversation.id];
while (queue.length > 0) {
const parentId = queue.pop()!;
for (const c of allConvs) {
if (c.forkedFromConversationId === parentId) return true;
}
}
return false;
});
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (conversation) {
selectedConversation = conversation;
deleteWithForks = false;
showDeleteDialog = true;
async function selectConversation(id: string) {
if (isMobile.current) {
scheduleMobileCollapse();
}
await goto(RouterService.chat(id));
}
async function handleEditConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (conversation) {
selectedConversation = conversation;
editedName = conversation.name;
showEditDialog = true;
if (!conversation) return;
const newName = window.prompt('Rename conversation', conversation.name);
if (newName && newName.trim()) {
await conversationsStore.updateConversationName(id, newName.trim());
}
}
function handleConfirmDelete() {
if (selectedConversation) {
const convId = selectedConversation.id;
const withForks = deleteWithForks;
showDeleteDialog = false;
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (!conversation) return;
setTimeout(() => {
conversationsStore.deleteConversation(convId, {
deleteWithForks: withForks
});
}, 100); // Wait for animation to finish
}
}
const confirmed = window.confirm(
`Delete "${conversation.name}"? This action cannot be undone.`
);
if (!confirmed) return;
function handleConfirmEdit() {
if (!editedName.trim() || !selectedConversation) return;
showEditDialog = false;
conversationsStore.updateConversationName(selectedConversation.id, editedName);
selectedConversation = null;
}
export function handleMobileSidebarItemClick() {
if (sidebar.isMobile) {
sidebar.toggle();
}
}
let chatSidebarActions: { activateSearch?: () => void } | undefined = $state();
let openedForSearch = $state(false);
export function activateSearchMode() {
if (!sidebar.open) {
openedForSearch = true;
}
chatSidebarActions?.activateSearch?.();
}
function handleSearchDeactivated() {
if (openedForSearch) {
openedForSearch = false;
sidebar.toggle();
}
}
$effect(() => {
if (!sidebar.open) {
isSearchModeActive = false;
searchQuery = '';
openedForSearch = false;
}
});
export function editActiveConversation() {
if (currentChatId) {
const activeConversation = filteredConversations.find((conv) => conv.id === currentChatId);
if (activeConversation) {
const event = new CustomEvent('edit-active-conversation', {
detail: { conversationId: currentChatId }
});
document.dispatchEvent(event);
}
}
}
async function selectConversation(id: string) {
if (isSearchModeActive) {
isSearchModeActive = false;
searchQuery = '';
}
handleMobileSidebarItemClick();
await goto(RouterService.chat(id));
await conversationsStore.deleteConversation(id, { deleteWithForks: false });
}
function handleStopGeneration(id: string) {
chatStore.stopGenerationForChat(id);
}
let innerWidth = $state(0);
let pendingCollapse = $state<ReturnType<typeof setTimeout> | null>(null);
function scheduleMobileCollapse() {
if (pendingCollapse) {
clearTimeout(pendingCollapse);
}
pendingCollapse = setTimeout(() => {
isExpandedMode = false;
pendingCollapse = null;
}, 100);
}
function cancelMobileCollapse() {
if (pendingCollapse) {
clearTimeout(pendingCollapse);
pendingCollapse = null;
}
}
</script>
<div class="flex h-full flex-col">
<ScrollArea class="h-full flex-1">
<Sidebar.Header class="gap-4 bg-sidebar/50 p-3 backdrop-blur-lg md:pt-4 md:pb-2">
<div class="flex items-center justify-between">
<a href={ROUTES.START} onclick={handleMobileSidebarItemClick}>
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">
{APP_NAME}
</h1>
</a>
<svelte:window onkeydown={handleKeydown} bind:innerWidth />
<Button
class="rounded-full md:hidden"
variant="ghost"
size="icon"
onclick={() => sidebar.toggle()}
>
<X class="h-4 w-4" />
<span class="sr-only">Close sidebar</span>
</Button>
{#if innerWidth > 768 || (!page.url.hash.includes(ROUTES.SETTINGS) && !page.url.hash.includes(ROUTES.MCP_SERVERS) && !page.url.hash.includes(ROUTES.SEARCH))}
<aside
class={[
// Layout & positioning
'fixed md:sticky top-2 left-2 md:left-0 md:ml-2 md:mt-2 pt-2 z-10 w-[calc(100dvw-1rem)]',
// Dimensions & overflow
'md:h-[calc(100dvh-1.125rem)]',
isExpandedMode &&
(device.isStandalone
? 'h-[calc(100dvh-2rem)]'
: device.isIOSDevice
? 'h-[calc(100dvh-0.5rem)]'
: 'h-[calc(100dvh-1rem)]'),
// Shape & depth
'rounded-3xl md:rounded-2xl',
// Flex layout
'flex flex-col justify-between',
// Transition
'md:transition-[width,padding] duration-200 ease-out',
// Expanded state: width, surface, depth
isStripExpanded && 'md:w-72 md:bg-muted/60 md:backdrop-blur-xl border-border shadow-md',
// Collapsed state
!isStripExpanded && 'md:w-12',
// Expanded mode flag (for mobile ::before overlay)
isExpandedMode && 'is-expanded'
]}
>
<div class="px-2 flex items-center justify-between">
<div
role="button"
tabindex="0"
class="relative"
onmouseenter={() => (logoHovered = true)}
onmouseleave={() => (logoHovered = false)}
>
<ActionIcon
icon={!isExpandedMode && logoHovered && innerWidth > 768 ? PanelLeftOpen : Logo}
size="lg"
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
class="{isExpandedMode
? 'bg-muted! md:bg-foreground/5!'
: 'bg-transparent!'} md:h-9 md:w-9 h-10 w-10 rounded-full md:hover:bg-foreground/10! pointer-events-auto"
href={isExpandedMode ? ROUTES.START : undefined}
onclick={isExpandedMode ? undefined : toggleExpandedMode}
tooltip={isExpandedMode ? undefined : 'Open Sidebar'}
tooltipSide={TooltipSide.RIGHT}
ariaLabel={isExpandedMode ? 'Go to start' : 'Expand navigation'}
/>
</div>
<SidebarNavigationActions
bind:this={chatSidebarActions}
{handleMobileSidebarItemClick}
bind:isSearchModeActive
bind:searchQuery
onSearchDeactivated={handleSearchDeactivated}
/>
</Sidebar.Header>
{#if !isSearchModeActive && pinnedConversations.length > 0}
<Sidebar.Group class="p-0 px-4">
<Sidebar.GroupLabel>
<div class="flex items-center gap-1">
<Pin class="h-3.5 w-3.5" />
<span>Pinned</span>
</div>
</Sidebar.GroupLabel>
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each pinnedConversations as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
</Sidebar.MenuItem>
{/each}
</Sidebar.Menu>
</Sidebar.GroupContent>
</Sidebar.Group>
{/if}
<Sidebar.Group class="mt-2 h-[calc(100vh-21rem)] space-y-2 p-0 px-3">
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
<Sidebar.GroupLabel>
{isSearchModeActive ? 'Search results' : 'Recent conversations'}
</Sidebar.GroupLabel>
{#if isExpandedMode || isOnMobile}
<div
class="flex items-center transition-all duration-150 ease-out {isMobile.current &&
!isExpandedMode
? 'opacity-0 h-0!'
: ''}"
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
out:fade={{ duration: 100 }}
>
<ActionIcon
icon={isMobile.current ? X : PanelLeftClose}
size="lg"
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
class="backdrop-blur-none md:h-9 md:w-9 h-10 w-10 rounded-full mr-1 hover:bg-accent!"
onclick={toggleExpandedMode}
tooltip="Close Sidebar"
tooltipSide={TooltipSide.LEFT}
ariaLabel="Collapse navigation"
/>
</div>
{/if}
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each isSearchModeActive ? conversationTree : unpinnedConversations as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
</Sidebar.MenuItem>
{/each}
{#if (isSearchModeActive ? conversationTree : unpinnedConversations).length === 0}
<div class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{searchQuery.length > 0
? 'No results found'
: isSearchModeActive
? 'Start typing to see results'
: 'No conversations yet'}
</p>
</div>
{/if}
</Sidebar.Menu>
</Sidebar.GroupContent>
</Sidebar.Group>
</ScrollArea>
</div>
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description={selectedConversation
? `Are you sure you want to delete "${selectedConversationNamePreview}"? This action cannot be undone and will permanently remove all messages in this conversation.`
: ''}
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleConfirmDelete}
onCancel={() => {
showDeleteDialog = false;
selectedConversation = null;
}}
>
{#if selectedConversationHasDescendants}
<div class="flex items-center gap-2 py-2">
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
</div>
{/if}
</DialogConfirmation>
<DialogConfirmation
bind:open={showEditDialog}
title="Edit Conversation Name"
description=""
confirmText="Save"
cancelText="Cancel"
icon={Pencil}
onConfirm={handleConfirmEdit}
onCancel={() => {
showEditDialog = false;
selectedConversation = null;
}}
onKeydown={(event) => {
if (event.key === 'Enter') {
event.preventDefault();
event.stopImmediatePropagation();
handleConfirmEdit();
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-1 overflow-y-auto">
<div
class="flex min-h-0 flex-1 flex-col gap-4 md:gap-1 {isMobile.current
? 'transition-[opacity,height] duration-200 ease-out'
: ''} {isMobile.current && !isExpandedMode ? 'opacity-0 !h-0' : ''}"
in:fade={{ duration: 200 }}
out:fade={{ duration: 200 }}
>
<SidebarNavigationActions
isExpandedMode={innerWidth > 768 ? isExpandedMode : true}
class="px-2"
bind:isSearchModeActive
bind:searchQuery
onSearchDeactivated={() => {
isSearchModeActive = false;
searchQuery = '';
}}
onSearchClick={() => {
isExpandedMode = true;
isSearchModeActive = true;
}}
onNewChat={() => {
if (isMobile.current) {
scheduleMobileCollapse();
}
}}
/>
{#if isExpandedMode || isOnMobile}
<SidebarNavigationConversationList
class="px-2"
{filteredConversations}
{currentChatId}
{isSearchModeActive}
{searchQuery}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
{/if}
</div>
</div>
</aside>
{/if}
<style>
aside {
@media (max-width: 768px) {
--size: 1.125rem;
}
}}
>
<Input
class="text-foreground"
placeholder="Enter a new name"
type="text"
bind:value={editedName}
/>
</DialogConfirmation>
}
@media (max-width: 768px) {
aside {
&:not(.is-expanded) {
pointer-events: none;
}
}
aside.is-expanded::before {
content: '';
position: fixed;
top: -0.5rem;
bottom: -0.25rem;
left: -0.5rem;
right: -0.5rem;
z-index: -1;
background: var(--background);
backdrop-filter: blur(1rem);
pointer-events: none;
}
}
</style>
@@ -1,39 +1,86 @@
<script lang="ts">
import { KeyboardShortcutInfo } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import type { Component } from 'svelte';
import { SearchInput } from '$lib/components/app';
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { SIDEBAR_ACTIONS_ITEMS } from '$lib/constants/ui';
import { Search } from '@lucide/svelte';
import { ActionIcon, KeyboardShortcutInfo, SearchInput } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import {
ICON_STRIP_TRANSITION_DURATION,
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
ROUTES,
SIDEBAR_ACTIONS_ITEMS
} from '$lib/constants';
import { isMobile } from '$lib/stores/viewport.svelte';
import { TooltipSide } from '$lib/enums';
import { fade } from 'svelte/transition';
import { circIn } from 'svelte/easing';
import { onMount } from 'svelte';
import type { Component } from 'svelte';
interface Props {
handleMobileSidebarItemClick: () => void;
class: string;
isExpandedMode: boolean;
isSearchModeActive: boolean;
searchQuery: string;
isCancelAlwaysVisible?: boolean;
onSearchDeactivated?: () => void;
onSearchClick?: () => void;
onNewChat?: () => void;
}
let {
handleMobileSidebarItemClick,
isSearchModeActive = $bindable(),
searchQuery = $bindable(),
isCancelAlwaysVisible = false,
onSearchDeactivated
class: className,
isExpandedMode = false,
isSearchModeActive = $bindable(false),
searchQuery = $bindable(''),
onSearchDeactivated,
onSearchClick,
onNewChat
}: Props = $props();
let initialized = $state(false);
let showIcons = $state(false);
let searchInputRef = $state<HTMLInputElement | null>(null);
const isOnMobile = $derived(isMobile.current);
$effect(() => {
if (isSearchModeActive && searchInputRef) {
searchInputRef.focus();
}
});
onMount(() => {
showIcons = true;
setTimeout(() => {
initialized = true;
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
});
function handleSearchModeDeactivate() {
isSearchModeActive = false;
searchQuery = '';
onSearchDeactivated?.();
}
export function activateSearch() {
isSearchModeActive = true;
// Focus after Svelte renders the input
queueMicrotask(() => searchInputRef?.focus());
function isItemActive(item: {
activeRouteId?: string;
activeRoutePrefix?: string;
activeUrlIncludes?: string;
}): boolean {
if (item.activeRouteId) {
return page.route.id === item.activeRouteId;
}
if (item.activeRoutePrefix) {
return !!page.route.id?.startsWith(item.activeRoutePrefix);
}
if (item.activeUrlIncludes) {
return page.url?.hash?.includes(item.activeUrlIncludes) ?? false;
}
return false;
}
</script>
@@ -41,56 +88,109 @@
<IconComponent class="h-4 w-4" />
{/snippet}
<div class="my-1 space-y-1">
{#if isSearchModeActive}
{#if isSearchModeActive}
<div class="px-4 my-2">
<SearchInput
bind:value={searchQuery}
bind:ref={searchInputRef}
onClose={handleSearchModeDeactivate}
onKeyDown={(e) => e.key === 'Escape' && handleSearchModeDeactivate()}
placeholder="Search conversations..."
{isCancelAlwaysVisible}
/>
{:else}
{#each SIDEBAR_ACTIONS_ITEMS as item (item.route)}
{#if !item.route}
<Button
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100"
onclick={activateSearch}
variant="ghost"
>
<div class="flex items-center gap-2">
{@render itemIcon(item.icon)}
</div>
{:else if isExpandedMode || isOnMobile}
<div
class="{className} flex flex-col gap-5 md:gap-1 mt-2 md:mt-0 {!isExpandedMode && isOnMobile
? 'hidden pointer-events-none'
: ''}"
>
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const isActive = isItemActive(item)}
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
{@const itemHref = isSearchOnMobile ? ROUTES.SEARCH : item.route}
{@const itemOnClick = item.route
? () => {
onNewChat?.();
goto(item.route!);
}
: isSearchOnMobile
? undefined
: onSearchClick}
{@const itemTransition = {
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
{item.tooltip}
</div>
{#if showIcons}
<div transition:fade={itemTransition}>
<Button
class="w-full min-w-9 justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {isActive
? 'bg-accent text-accent-foreground'
: ''}"
href={itemHref}
onclick={itemOnClick}
variant="ghost"
size="default"
>
<span class="flex min-w-0 items-center px-0.5 gap-2">
{@render itemIcon(item.icon)}
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
{:else}
<Button
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {(item.activeRouteId &&
page.route.id === item.activeRouteId) ||
(item.activeRoutePrefix && page.route.id?.startsWith(item.activeRoutePrefix))
? 'bg-accent text-accent-foreground'
: ''}"
href={item.route}
onclick={handleMobileSidebarItemClick}
variant="ghost"
>
<div class="flex items-center gap-2">
{@render itemIcon(item.icon)}
{#if showIcons}
<span
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
out:fade={{ duration: 100 }}
class="min-w-0 truncate">{item.tooltip}</span
>
{/if}
</span>
{item.tooltip}
</div>
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
</div>
{/if}
{/each}
{/if}
</div>
</div>
{:else}
<div class="{className} flex-col gap-1 hidden md:flex">
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const isActive = isItemActive(item)}
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
{@const itemOnClick = item.route
? () => {
onNewChat?.();
goto(item.route!);
}
: isSearchOnMobile
? undefined
: onSearchClick}
{@const itemTransition = {
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
{#if showIcons}
<div transition:fade={itemTransition}>
<ActionIcon
icon={item.icon}
tooltip={item.tooltip}
tooltipSide={TooltipSide.RIGHT}
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
? 'bg-accent text-accent-foreground'
: ''}"
onclick={itemOnClick}
/>
</div>
{/if}
{/each}
</div>
{/if}
@@ -0,0 +1,135 @@
<script lang="ts">
import { Pin } from '@lucide/svelte';
import { buildConversationTree } from '$lib/stores/conversations.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
import SidebarNavigationSearchResults from './SidebarNavigationSearchResults.svelte';
interface Props {
class: string;
filteredConversations: DatabaseConversation[];
currentChatId: string | undefined;
isSearchModeActive: boolean;
searchQuery: string;
onSelect: (id: string) => void;
onEdit: (id: string) => void;
onDelete: (id: string) => void;
onStop: (id: string) => void;
}
let {
class: className,
filteredConversations,
currentChatId,
isSearchModeActive,
searchQuery,
onSelect,
onEdit,
onDelete,
onStop
}: Props = $props();
let conversationTree = $derived(buildConversationTree(filteredConversations));
let pinnedConversations = $derived(
conversationTree.filter(({ conversation }) => conversation.pinned)
);
let unpinnedConversations = $derived(
conversationTree.filter(({ conversation }) => !conversation.pinned)
);
const recentEmptyMessage = $derived(
searchQuery.length > 0 ? 'No results found' : 'No conversations yet'
);
</script>
{#if isSearchModeActive}
<SidebarNavigationSearchResults
class={className}
{searchQuery}
{filteredConversations}
{currentChatId}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
{:else}
{#if pinnedConversations.length > 0}
<div class="py-2 flex whitespace-nowrap {className}">
<div
class="text-muted-foreground inline-flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium gap-1"
>
<Pin class="h-3.5 w-3.5" />
<span>Pinned</span>
</div>
</div>
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1 {className}">
{#each pinnedConversations as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
</ul>
{/if}
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-2 whitespace-nowrap {className}">
{#if filteredConversations.length > 0}
<div
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
>
Recent conversations
</div>
{/if}
<div class="min-h-0 flex-1 md:overflow-y-auto">
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1">
{#each unpinnedConversations as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
{#if unpinnedConversations.length === 0}
<li class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{recentEmptyMessage}
</p>
</li>
{/if}
</ul>
</div>
</div>
{/if}
@@ -16,4 +16,6 @@
}: Props = $props();
</script>
<SearchInput bind:value {placeholder} {onInput} class="mb-4 {className}" />
<div class="mb-4 px-2 {className}">
<SearchInput bind:value {placeholder} {onInput} />
</div>
@@ -0,0 +1,76 @@
<script lang="ts">
import { buildConversationTree } from '$lib/stores/conversations.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
interface Props {
class?: string;
searchQuery: string;
filteredConversations: DatabaseConversation[];
currentChatId: string | undefined;
onSelect: (id: string) => void;
onEdit: (id: string) => void;
onDelete: (id: string) => void;
onStop: (id: string) => void;
}
let {
class: className = '',
searchQuery,
filteredConversations,
currentChatId,
onSelect,
onEdit,
onDelete,
onStop
}: Props = $props();
let tree = $derived(buildConversationTree(filteredConversations));
const hasQuery = $derived(searchQuery.trim().length > 0);
const showHeader = $derived(hasQuery && filteredConversations.length > 0);
const emptyMessage = $derived(hasQuery ? 'No results found' : 'Start typing to see results');
</script>
<div class="flex min-h-0 flex-1 flex-col gap-2 whitespace-nowrap {className}">
{#if showHeader}
<div
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
>
Search results
</div>
{/if}
<div class="min-h-0 flex-1 overflow-y-auto">
<ul class="flex w-full min-w-0 flex-col gap-1">
{#each tree as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
{#if tree.length === 0}
<li class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{emptyMessage}
</p>
</li>
{/if}
</ul>
</div>
</div>
@@ -63,15 +63,6 @@ export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svel
* ```
*/
export { default as DropdownMenuActions } from './DropdownMenuActions.svelte';
/**
* **DesktopIconStrip** - Fixed icon strip for desktop sidebar
*
* Vertical icon strip shown on desktop when the sidebar is collapsed.
* Contains navigation shortcuts for new chat, search, MCP, import/export, and settings.
*/
export { default as DesktopIconStrip } from './DesktopIconStrip.svelte';
/**
* **SidebarNavigation** - Sidebar with actions menu and conversation list
*
@@ -115,13 +106,6 @@ export { default as DesktopIconStrip } from './DesktopIconStrip.svelte';
*/
export { default as SidebarNavigation } from './SidebarNavigation/SidebarNavigation.svelte';
/**
* Action buttons for sidebar header. Contains new chat button, settings button,
* and delete all conversations button. Manages dialog states for settings and
* delete confirmation.
*/
export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte';
/**
* Single conversation item in sidebar. Displays conversation title (truncated),
* last message preview, and timestamp. Shows context menu on right-click with
@@ -130,6 +114,58 @@ export { default as SidebarNavigationActions } from './SidebarNavigation/Sidebar
*/
export { default as SidebarNavigationConversationItem } from './SidebarNavigation/SidebarNavigationConversationItem.svelte';
/**
* **SidebarNavigationConversationList** - Grouped conversation list
*
* Pure-presentational list of conversations. Splits items into a Pinned
* section (when not in search mode) and a Recent Conversations / Search
* Results section with the unpinned items. Item selection, edit, delete,
* and stop-generation are delegated to the caller via callbacks.
*
* @example
* ```svelte
* <SidebarNavigationConversationList
* {filteredConversations}
* {currentChatId}
* {isSearchModeActive}
* {searchQuery}
* onSelect={...}
* onEdit={...}
* onDelete={...}
* onStop={...}
* />
* ```
*/
export { default as SidebarNavigationConversationList } from './SidebarNavigation/SidebarNavigationConversationList.svelte';
export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte';
/**
* **SidebarNavigationSearchResults** - Filtered conversation list for search.
*
* Pure-presentational rendering of the search-mode subtree: "Search results"
* header, the matching items rendered through {@link SidebarNavigationConversationItem},
* and contextual empty-state messages. Used both inline inside
* {@link SidebarNavigationConversationList} (when search mode is active in the
* sidebar) and as the body of the mobile `/search` route.
*
* The caller is expected to provide an already-filtered list via
* `filteredConversations` and a `searchQuery` for the empty-state messages.
*
* @example
* ```svelte
* <SidebarNavigationSearchResults
* {searchQuery}
* {filteredConversations}
* {currentChatId}
* onSelect={...}
* onEdit={...}
* onDelete={...}
* onStop={...}
* />
* ```
*/
export { default as SidebarNavigationSearchResults } from './SidebarNavigation/SidebarNavigationSearchResults.svelte';
/**
* Search input for filtering conversations in sidebar. Filters conversation
* list by title as user types. Shows clear button when query is not empty.
@@ -126,10 +126,7 @@
});
</script>
<div
class="mx-auto flex h-full max-h-[100dvh] w-full flex-col overflow-y-auto md:pl-8"
in:fade={{ duration: 150 }}
>
<div class="mx-auto flex h-full w-full flex-col md:pl-8" in:fade={{ duration: 150 }}>
<div class="flex flex-1 flex-col gap-4 md:flex-row">
<SettingsChatDesktopSidebar
sections={SETTINGS_CHAT_SECTIONS}
@@ -12,11 +12,13 @@
let { sections, isActive, getHref, onSectionChange }: Props = $props();
</script>
<div class="sticky top-0 hidden w-64 flex-col self-start bg-background pt-10 pb-4 md:flex">
<div class="flex items-center gap-2 pb-10">
<Settings class="h-6 w-6" />
<h1 class="text-2xl font-semibold">Settings</h1>
<div class="sticky top-2 hidden w-64 flex-col self-start bg-background py-4 md:flex gap-6">
<div class="flex items-center gap-2 py-2">
<Settings class="h-5 w-5 md:h-6 md:w-6" />
<h1 class="text-xl font-semibold md:text-2xl">Settings</h1>
</div>
<nav class="space-y-1">
{#each sections as section (section.title)}
{#if getHref}
@@ -1,17 +1,19 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { X, Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { toolsStore } from '$lib/stores/tools.svelte';
import { McpServerCard, McpServerCardSkeleton } from '$lib/components/app/mcp';
import { ActionIcon, McpServerCard, McpServerCardSkeleton } from '$lib/components/app';
import { DialogMcpServerAddNew } from '$lib/components/app/dialogs';
import { HealthCheckStatus } from '$lib/enums';
import { ROUTES } from '$lib/constants';
import { fade } from 'svelte/transition';
import { onMount } from 'svelte';
import McpLogo from '../mcp/McpLogo.svelte';
import { browser } from '$app/environment';
import { page } from '$app/state';
import { replaceState } from '$app/navigation';
import { goto, replaceState } from '$app/navigation';
interface Props {
class?: string;
@@ -24,6 +26,24 @@
let initialLoadComplete = $state(false);
let isAddingServer = $state(false);
let previousRouteId = $state<string | null>(null);
$effect(() => {
const currentId = page.route.id;
return () => {
previousRouteId = currentId;
};
});
function handleClose() {
const prevIsMcpServers = previousRouteId === '/mcp-servers';
if (browser && window.history.length > 1 && !prevIsMcpServers) {
history.back();
} else {
goto(ROUTES.START);
}
}
onMount(() => {
if (page.url.searchParams.has('add')) {
isAddingServer = true;
@@ -54,15 +74,26 @@
});
</script>
<div in:fade={{ duration: 150 }} class="h-full max-h-[100dvh] overflow-y-auto">
<div class="flex items-center gap-2 p-4 md:absolute md:top-8 md:left-8 md:px-0 md:py-2">
<McpLogo class="h-5 w-5 md:h-6 md:w-6" />
<h1 class="text-xl font-semibold md:text-2xl">MCP Servers</h1>
<div in:fade={{ duration: 150 }}>
<div class="fixed top-4.5 right-4 z-50 md:hidden">
<ActionIcon icon={X} tooltip="Close" onclick={handleClose} />
</div>
<div class="sticky top-0 z-10 mt-4 flex items-start gap-4 p-4 md:justify-end md:px-8">
<Button variant="outline" size="sm" class="shrink-0" onclick={() => (isAddingServer = true)}>
<div
class="sticky top-0 z-10 mt-4 mb-2 flex items-start gap-4 md:p-4 p-0 px-4 md:justify-between md:px-8"
>
<div class="flex items-center gap-2">
<McpLogo class="h-5 w-5 md:h-6 md:w-6" />
<h1 class="text-lg font-semibold md:text-2xl">MCP Servers</h1>
</div>
<Button
variant="outline"
size="lg"
class="shrink-0 fixed md:static bottom-6 right-6"
onclick={() => (isAddingServer = true)}
>
<Plus class="h-4 w-4" />
Add New Server
@@ -20,7 +20,7 @@
size: {
default: 'h-9 px-4 py-2 has-[>svg]:px-3',
sm: 'h-8 gap-1.5 rounded-md px-3 has-[>svg]:px-2.5',
lg: 'h-10 rounded-md px-6 has-[>svg]:px-4',
lg: 'h-10 rounded-lg px-6 has-[>svg]:px-4',
'icon-lg': 'size-10',
icon: 'size-9',
'icon-sm': 'size-5 rounded-sm'
@@ -1,7 +0,0 @@
export const SIDEBAR_COOKIE_NAME = 'sidebar:state';
export const SIDEBAR_COOKIE_MAX_AGE = 60 * 60 * 24 * 7;
export const SIDEBAR_MIN_WIDTH = '18rem';
export const SIDEBAR_MAX_WIDTH = '32rem';
export const SIDEBAR_WIDTH_MOBILE = '18rem';
export const SIDEBAR_WIDTH_ICON = '3rem';
export const SIDEBAR_KEYBOARD_SHORTCUT = 'b';
@@ -1,79 +0,0 @@
import { isMobile } from '$lib/stores/viewport.svelte.js';
import { getContext, setContext } from 'svelte';
import { SIDEBAR_KEYBOARD_SHORTCUT, SIDEBAR_MIN_WIDTH } from './constants.js';
type Getter<T> = () => T;
export type SidebarStateProps = {
/**
* A getter function that returns the current open state of the sidebar.
* We use a getter function here to support `bind:open` on the `Sidebar.Provider`
* component.
*/
open: Getter<boolean>;
/**
* A function that sets the open state of the sidebar. To support `bind:open`, we need
* a source of truth for changing the open state to ensure it will be synced throughout
* the sub-components and any `bind:` references.
*/
setOpen: (open: boolean) => void;
};
class SidebarState {
readonly props: SidebarStateProps;
open = $derived.by(() => this.props.open());
openMobile = $state(false);
sidebarWidth = $state(SIDEBAR_MIN_WIDTH);
isResizing = $state(false);
setOpen: SidebarStateProps['setOpen'];
state = $derived.by(() => (this.open ? 'expanded' : 'collapsed'));
constructor(props: SidebarStateProps) {
this.setOpen = props.setOpen;
this.props = props;
}
// Convenience getter for checking if the sidebar is mobile
// without this, we would need to use `sidebar.isMobile.current` everywhere
get isMobile() {
return isMobile.current;
}
// Event handler to apply to the `<svelte:window>`
handleShortcutKeydown = (e: KeyboardEvent) => {
if (e.key === SIDEBAR_KEYBOARD_SHORTCUT && (e.metaKey || e.ctrlKey)) {
e.preventDefault();
this.toggle();
}
};
setOpenMobile = (value: boolean) => {
this.openMobile = value;
};
toggle = () => {
this.setOpen(!this.open);
};
}
const SYMBOL_KEY = 'scn-sidebar';
/**
* Instantiates a new `SidebarState` instance and sets it in the context.
*
* @param props The constructor props for the `SidebarState` class.
* @returns The `SidebarState` instance.
*/
export function setSidebar(props: SidebarStateProps): SidebarState {
return setContext(Symbol.for(SYMBOL_KEY), new SidebarState(props));
}
/**
* Retrieves the `SidebarState` instance from the context. This is a class instance,
* so you cannot destructure it.
* @returns The `SidebarState` instance.
*/
export function useSidebar(): SidebarState {
return getContext(Symbol.for(SYMBOL_KEY));
}
@@ -1,75 +0,0 @@
import { useSidebar } from './context.svelte.js';
import Content from './sidebar-content.svelte';
import Footer from './sidebar-footer.svelte';
import GroupAction from './sidebar-group-action.svelte';
import GroupContent from './sidebar-group-content.svelte';
import GroupLabel from './sidebar-group-label.svelte';
import Group from './sidebar-group.svelte';
import Header from './sidebar-header.svelte';
import Input from './sidebar-input.svelte';
import Inset from './sidebar-inset.svelte';
import MenuAction from './sidebar-menu-action.svelte';
import MenuBadge from './sidebar-menu-badge.svelte';
import MenuButton from './sidebar-menu-button.svelte';
import MenuItem from './sidebar-menu-item.svelte';
import MenuSkeleton from './sidebar-menu-skeleton.svelte';
import MenuSubButton from './sidebar-menu-sub-button.svelte';
import MenuSubItem from './sidebar-menu-sub-item.svelte';
import MenuSub from './sidebar-menu-sub.svelte';
import Menu from './sidebar-menu.svelte';
import Provider from './sidebar-provider.svelte';
import Rail from './sidebar-rail.svelte';
import Separator from './sidebar-separator.svelte';
import Trigger from './sidebar-trigger.svelte';
import Root from './sidebar.svelte';
export {
Content,
Footer,
Group,
GroupAction,
GroupContent,
GroupLabel,
Header,
Input,
Inset,
Menu,
MenuAction,
MenuBadge,
MenuButton,
MenuItem,
MenuSkeleton,
MenuSub,
MenuSubButton,
MenuSubItem,
Provider,
Rail,
Root,
Separator,
//
Root as Sidebar,
Content as SidebarContent,
Footer as SidebarFooter,
Group as SidebarGroup,
GroupAction as SidebarGroupAction,
GroupContent as SidebarGroupContent,
GroupLabel as SidebarGroupLabel,
Header as SidebarHeader,
Input as SidebarInput,
Inset as SidebarInset,
Menu as SidebarMenu,
MenuAction as SidebarMenuAction,
MenuBadge as SidebarMenuBadge,
MenuButton as SidebarMenuButton,
MenuItem as SidebarMenuItem,
MenuSkeleton as SidebarMenuSkeleton,
MenuSub as SidebarMenuSub,
MenuSubButton as SidebarMenuSubButton,
MenuSubItem as SidebarMenuSubItem,
Provider as SidebarProvider,
Rail as SidebarRail,
Separator as SidebarSeparator,
Trigger as SidebarTrigger,
Trigger,
useSidebar
};
@@ -1,24 +0,0 @@
<script lang="ts">
import type { HTMLAttributes } from 'svelte/elements';
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
</script>
<div
bind:this={ref}
data-slot="sidebar-content"
data-sidebar="content"
class={cn(
'flex min-h-0 flex-1 flex-col gap-2 overflow-auto group-data-[collapsible=icon]:overflow-hidden',
className
)}
{...restProps}
>
{@render children?.()}
</div>
@@ -1,21 +0,0 @@
<script lang="ts">
import type { HTMLAttributes } from 'svelte/elements';
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
</script>
<div
bind:this={ref}
data-slot="sidebar-footer"
data-sidebar="footer"
class={cn('flex flex-col gap-2 p-3', className)}
{...restProps}
>
{@render children?.()}
</div>
@@ -1,36 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { Snippet } from 'svelte';
import type { HTMLButtonAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
class: className,
children,
child,
...restProps
}: WithElementRef<HTMLButtonAttributes> & {
child?: Snippet<[{ props: Record<string, unknown> }]>;
} = $props();
const mergedProps = $derived({
class: cn(
'text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground outline-hidden absolute right-3 top-3.5 flex aspect-square w-5 items-center justify-center rounded-md p-0 transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0',
// Increases the hit area of the button on mobile.
'after:absolute after:-inset-2 md:after:hidden',
'group-data-[collapsible=icon]:hidden',
className
),
'data-slot': 'sidebar-group-action',
'data-sidebar': 'group-action',
...restProps
});
</script>
{#if child}
{@render child({ props: mergedProps })}
{:else}
<button bind:this={ref} {...mergedProps}>
{@render children?.()}
</button>
{/if}

Some files were not shown because too many files have changed in this diff Show More