mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-25 07:07:39 +02:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fdb2c11c70 | |||
| 09cedfd699 | |||
| 8be759e6f7 | |||
| 894bb27af3 | |||
| fb401045cc | |||
| 51eae8cfca | |||
| 1191758c5d | |||
| 00139b660b | |||
| ef9c13d4c2 | |||
| 88636e178f | |||
| ac4105d68b |
@@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
|
||||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
|
||||
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
|
||||
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
|
||||
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
|
||||
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
|
||||
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
|
||||
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
|
||||
|
||||
@@ -80,8 +80,6 @@ add_library(${TARGET}
|
||||
http.h
|
||||
imatrix-loader.cpp
|
||||
imatrix-loader.h
|
||||
json-partial.cpp
|
||||
json-partial.h
|
||||
json-schema-to-grammar.cpp
|
||||
llguidance.cpp
|
||||
log.cpp
|
||||
|
||||
@@ -2758,5 +2758,9 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
|
||||
GGML_ASSERT(chat_templates != nullptr);
|
||||
GGML_ASSERT(chat_templates->template_default != nullptr);
|
||||
if (chat_templates->template_tool_use != nullptr) {
|
||||
// take the more expressive template when available
|
||||
return chat_templates->template_tool_use->caps.to_map();
|
||||
}
|
||||
return chat_templates->template_default->caps.to_map();
|
||||
}
|
||||
|
||||
@@ -1,324 +0,0 @@
|
||||
#include "json-partial.h"
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum common_json_stack_element_type {
|
||||
COMMON_JSON_STACK_ELEMENT_OBJECT,
|
||||
COMMON_JSON_STACK_ELEMENT_KEY,
|
||||
COMMON_JSON_STACK_ELEMENT_ARRAY,
|
||||
};
|
||||
|
||||
struct common_json_stack_element {
|
||||
common_json_stack_element_type type;
|
||||
std::string key;
|
||||
};
|
||||
|
||||
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
std::string::const_iterator it = input.begin();
|
||||
const auto end = input.end();
|
||||
return common_json_parse(it, end, healing_marker, out);
|
||||
}
|
||||
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||
std::size_t position;
|
||||
bool found_error;
|
||||
std::string last_token;
|
||||
std::string exception_message;
|
||||
std::vector<common_json_stack_element> stack;
|
||||
|
||||
json_error_locator() : position(0), found_error(false) {}
|
||||
|
||||
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
|
||||
this->position = position - 1;
|
||||
this->found_error = true;
|
||||
this->last_token = last_token;
|
||||
this->exception_message = ex.what();
|
||||
return false;
|
||||
}
|
||||
void close_value() {
|
||||
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
|
||||
stack.pop_back();
|
||||
}
|
||||
}
|
||||
bool null() override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool boolean(bool) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_integer(number_integer_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_unsigned(number_unsigned_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_float(number_float_t, const string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool string(string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool binary(binary_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool start_object(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_object() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool key(string_t & key) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
|
||||
return true;
|
||||
}
|
||||
bool start_array(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_array() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
};
|
||||
json_error_locator err_loc;
|
||||
auto start = it;
|
||||
json::sax_parse(it, end, &err_loc);
|
||||
|
||||
if (err_loc.found_error) {
|
||||
it = start;
|
||||
auto temptative_end = it + err_loc.position;
|
||||
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
|
||||
|
||||
auto input = std::string(it, temptative_end);
|
||||
try {
|
||||
out.json = json::parse(input);
|
||||
// out.json = json::parse(it, temptative_end);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
} catch (const std::exception & ex) {
|
||||
// No, needs healing.
|
||||
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
|
||||
}
|
||||
auto can_parse = [](const std::string & str) {
|
||||
try {
|
||||
auto _ = json::parse(str); // NOLINT
|
||||
return true;
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
if (!healing_marker.empty() && !err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
|
||||
if (last_non_sp_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
auto last_non_sp_char = str[last_non_sp_pos];
|
||||
// Used to detect stops on a number, which may not be complete.
|
||||
auto was_maybe_number = [&]() {
|
||||
if (!str.empty() && std::isspace(str.back())) {
|
||||
return false;
|
||||
}
|
||||
return std::isdigit(last_non_sp_char) ||
|
||||
last_non_sp_char == '.' ||
|
||||
last_non_sp_char == 'e' ||
|
||||
last_non_sp_char == 'E' ||
|
||||
last_non_sp_char == '-';
|
||||
};
|
||||
|
||||
std::string closing;
|
||||
for (size_t i = err_loc.stack.size(); i > 0; i--) {
|
||||
auto & el = err_loc.stack[i - 1];
|
||||
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
closing += "}";
|
||||
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
closing += "]";
|
||||
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
throw std::runtime_error("Unexpected stack element type");
|
||||
}
|
||||
}
|
||||
|
||||
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
|
||||
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
|
||||
|
||||
auto is_high_surrogate = [&](const std::string & s) {
|
||||
// Check if a partial of a high surrogate (U+D800-U+DBFF)
|
||||
return s.length() >= 4 &&
|
||||
s[0] == '\\' && s[1] == 'u' &&
|
||||
std::tolower(s[2]) == 'd' &&
|
||||
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
|
||||
};
|
||||
|
||||
// Initialize the unicode marker to a low surrogate to handle the edge case
|
||||
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
|
||||
// backslash (\)
|
||||
std::string unicode_marker_padding = "udc00";
|
||||
std::smatch last_unicode_seq;
|
||||
|
||||
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
|
||||
std::smatch second_last_seq;
|
||||
std::string prelude = str.substr(0, last_unicode_seq.position());
|
||||
|
||||
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
|
||||
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
|
||||
|
||||
if (is_high_surrogate(last_unicode_seq.str())) {
|
||||
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
|
||||
unicode_marker_padding += "\\udc00";
|
||||
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
|
||||
if (is_high_surrogate(second_last_seq.str())) {
|
||||
// If this follows a high surrogate, pad it to be a low surrogate
|
||||
if (last_unicode_seq.length() == 2) {
|
||||
unicode_marker_padding = "dc00";
|
||||
} else if (last_unicode_seq.length() == 3) {
|
||||
unicode_marker_padding = "c00";
|
||||
} else {
|
||||
// The original unicode_marker_padding is already padded with 0s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
// We're inside an object value
|
||||
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an object value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + ": 1" + closing)) {
|
||||
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
|
||||
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
|
||||
// Was about to create an object
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an object value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an object value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to opening : for object value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an array value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an array value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an array value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of("[,");
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to last [ or , for array value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
|
||||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\": 1" + closing)) {
|
||||
// Was inside an object key string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
|
||||
// Was inside an object key string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "Cutting back to last : for object key+value\n");
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
// handle unclosed top-level primitive
|
||||
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;
|
||||
if (can_parse(str + "\"")) {
|
||||
// Was inside an string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
|
||||
// Was inside an string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
|
||||
} else {
|
||||
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(it, end);
|
||||
it = end;
|
||||
return true;
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: use json_fwd.hpp when possible
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||
struct common_healing_marker {
|
||||
// Raw marker.
|
||||
std::string marker;
|
||||
|
||||
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
|
||||
std::string json_dump_marker;
|
||||
};
|
||||
|
||||
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
|
||||
struct common_json {
|
||||
nlohmann::ordered_json json;
|
||||
|
||||
common_healing_marker healing_marker;
|
||||
};
|
||||
|
||||
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
|
||||
//
|
||||
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
|
||||
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
|
||||
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
|
||||
//
|
||||
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
|
||||
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
||||
|
||||
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
||||
@@ -46,6 +46,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"DbrxForCausalLM": "dbrx",
|
||||
"DeciLMForCausalLM": "deci",
|
||||
"DeepseekForCausalLM": "deepseek",
|
||||
"DeepseekOCRForCausalLM": "deepseek",
|
||||
"DeepseekV2ForCausalLM": "deepseek",
|
||||
"DeepseekV3ForCausalLM": "deepseek",
|
||||
"DeepseekV32ForCausalLM": "deepseek",
|
||||
@@ -124,6 +125,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"LLaDAModelLM": "llada",
|
||||
"LLaMAForCausalLM": "llama",
|
||||
"Lfm25AudioTokenizer": "lfm2",
|
||||
"Lfm2BidirectionalModel": "lfm2",
|
||||
"Lfm2ForCausalLM": "lfm2",
|
||||
"Lfm2Model": "lfm2",
|
||||
"Lfm2MoeForCausalLM": "lfm2",
|
||||
@@ -232,6 +234,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"UMT5ForConditionalGeneration": "t5",
|
||||
"UMT5Model": "t5",
|
||||
"UltravoxModel": "ultravox",
|
||||
"UnlimitedOCRForCausalLM": "deepseek",
|
||||
"VLlama3ForCausalLM": "llama",
|
||||
"VoxtralForConditionalGeneration": "llama",
|
||||
"WavTokenizerDec": "wavtokenizer",
|
||||
@@ -298,6 +301,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"StepVLForConditionalGeneration": "step3",
|
||||
"Step3p7ForConditionalGeneration": "step3",
|
||||
"UltravoxModel": "ultravox",
|
||||
"UnlimitedOCRForCausalLM": "deepseek",
|
||||
"VoxtralForConditionalGeneration": "ultravox",
|
||||
"YoutuVLForConditionalGeneration": "youtuvl",
|
||||
}
|
||||
|
||||
+10
-2
@@ -14,7 +14,7 @@ from .base import MmprojModel, ModelBase, TextModel, gguf, logger
|
||||
from .qwen import QwenModel
|
||||
|
||||
|
||||
@ModelBase.register("DeepseekOCRForCausalLM")
|
||||
@ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM")
|
||||
class DeepseekOCRVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -205,6 +205,8 @@ class DeepseekModel(TextModel):
|
||||
@ModelBase.register(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekOCRForCausalLM",
|
||||
"UnlimitedOCRForCausalLM",
|
||||
"KimiVLForConditionalGeneration",
|
||||
"KimiK25ForConditionalGeneration",
|
||||
"YoutuForCausalLM",
|
||||
@@ -224,7 +226,7 @@ class DeepseekV2Model(TextModel):
|
||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||
|
||||
# special handling for Deepseek OCR
|
||||
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"):
|
||||
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"):
|
||||
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
@@ -350,6 +352,12 @@ class DeepseekV2Model(TextModel):
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
# Unlimited-OCR sliding window; written for metadata, the decoder ignores it (full MHA)
|
||||
if is_ocr:
|
||||
sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window")
|
||||
if sliding_window:
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
|
||||
if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
|
||||
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
|
||||
|
||||
+10
-3
@@ -64,11 +64,17 @@ class LFM2Model(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Lfm2Model")
|
||||
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
|
||||
class LFM2ColBertModel(LFM2Model):
|
||||
model_arch = gguf.MODEL_ARCH.LFM2
|
||||
dense_tensor_name = "dense_2"
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hf_arch == "Lfm2BidirectionalModel":
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self._try_set_pooling_type()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if not name.startswith(self.dense_tensor_name):
|
||||
name = "model." + name
|
||||
@@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# dense tensor is stored in a separate safetensors file
|
||||
# optional dense tensor is stored in a separate safetensors file
|
||||
from safetensors.torch import load_file
|
||||
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
|
||||
assert tensors_file.is_file()
|
||||
if not tensors_file.is_file():
|
||||
return
|
||||
tensor = load_file(tensors_file)["linear.weight"]
|
||||
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
|
||||
yield f"{self.dense_tensor_name}.weight", tensor.clone()
|
||||
|
||||
@@ -24,7 +24,6 @@
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
@@ -47,7 +46,6 @@
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
@@ -73,7 +71,6 @@
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "OFF",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
|
||||
@@ -266,7 +266,6 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
|
||||
"ggml: OpenCL API version to target")
|
||||
|
||||
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
|
||||
|
||||
# toolchain for vulkan-shaders-gen
|
||||
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
||||
|
||||
@@ -25,7 +25,6 @@ include(ExternalProject)
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
|
||||
|
||||
add_library(htp_iface OBJECT
|
||||
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
|
||||
@@ -72,15 +71,12 @@ function(build_htp_skel V)
|
||||
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
|
||||
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
|
||||
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
|
||||
-DDSP_VERSION=${V}
|
||||
-DPREBUILT_LIB_DIR="toolv19_${V}")
|
||||
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
|
||||
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
build_htp_skel(v68)
|
||||
build_htp_skel(v69)
|
||||
build_htp_skel(v73)
|
||||
build_htp_skel(v75)
|
||||
build_htp_skel(v79)
|
||||
|
||||
+1359
-1274
File diff suppressed because it is too large
Load Diff
@@ -5,10 +5,12 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stdio.h>
|
||||
#include "htp-ops.h"
|
||||
#include "htp/matmul-ops.h"
|
||||
|
||||
struct htp_opnode {
|
||||
ggml_tensor * node = nullptr;
|
||||
@@ -17,6 +19,13 @@ struct htp_opnode {
|
||||
|
||||
htp_op_code opcode = HTP_OP_INVALID;
|
||||
|
||||
std::vector<ggml_tensor *> extra_dsts;
|
||||
|
||||
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS] = {0};
|
||||
|
||||
htp_opnode(ggml_tensor * node = nullptr, std::vector<ggml_tensor *> fused = {}, htp_op_code opcode = HTP_OP_INVALID, std::vector<ggml_tensor *> extra_dsts = {})
|
||||
: node(node), fused(std::move(fused)), opcode(opcode), extra_dsts(std::move(extra_dsts)) {}
|
||||
|
||||
ggml_op op() const {
|
||||
return node->op;
|
||||
}
|
||||
@@ -25,6 +34,26 @@ struct htp_opnode {
|
||||
return fused.empty() ? node : fused.back();
|
||||
}
|
||||
|
||||
void add_fused(ggml_tensor * t, bool extra_dst = false) {
|
||||
fused.push_back(t);
|
||||
if (extra_dst) {
|
||||
extra_dsts.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const ggml_tensor *> get_outputs() const {
|
||||
std::vector<const ggml_tensor *> res;
|
||||
if (extra_dsts.empty()) {
|
||||
res.push_back(dst());
|
||||
} else {
|
||||
res.push_back(node);
|
||||
for (const auto * x : extra_dsts) {
|
||||
res.push_back(x);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0() const {
|
||||
return node->src[0];
|
||||
}
|
||||
@@ -37,10 +66,6 @@ struct htp_opnode {
|
||||
return ggml_op_is_empty(node->op);
|
||||
}
|
||||
|
||||
void add_fused(ggml_tensor * t) {
|
||||
fused.push_back(t);
|
||||
}
|
||||
|
||||
bool stackable() const {
|
||||
switch (this->op()) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
@@ -131,87 +156,117 @@ struct htp_opformat {
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
char kparams[128];
|
||||
|
||||
int format_tensor_dims(char * str, const struct ggml_tensor * t) {
|
||||
int format_tensor_dims(char * str, size_t max_size, const struct ggml_tensor * t) {
|
||||
if (!t) {
|
||||
return sprintf(str, "NONE");
|
||||
return snprintf(str, max_size, "NONE");
|
||||
}
|
||||
if (t->ne[2] == 1 && t->ne[3] == 1) {
|
||||
return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
|
||||
return snprintf(str, max_size, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
|
||||
} else {
|
||||
return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
|
||||
return snprintf(str, max_size, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_dims(char * str, const htp_opnode & node) {
|
||||
void format_op_dims(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += format_tensor_dims(p, inputs[0]);
|
||||
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[0]), (size_t)(p_end - p));
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += format_tensor_dims(p, inputs[i]);
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[i]), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
char self[64];
|
||||
format_tensor_dims(self, node.dst());
|
||||
p += sprintf(p, "%s", self);
|
||||
format_tensor_dims(self, sizeof(self), node.dst());
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
int format_tensor_strides(char * str, const struct ggml_tensor * t) {
|
||||
int format_tensor_strides(char * str, size_t max_size, const struct ggml_tensor * t) {
|
||||
if (!t) {
|
||||
return sprintf(str, "NONE");
|
||||
return snprintf(str, max_size, "NONE");
|
||||
}
|
||||
const char * c = ggml_is_contiguous(t) ? "" : "!";
|
||||
|
||||
if (t->ne[2] == 1 && t->ne[3] == 1) {
|
||||
return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
|
||||
return snprintf(str, max_size, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
|
||||
} else {
|
||||
return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
|
||||
return snprintf(str, max_size, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_strides(char * str, const htp_opnode & node) {
|
||||
void format_op_strides(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += format_tensor_strides(p, inputs[0]);
|
||||
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[0]), (size_t)(p_end - p));
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += format_tensor_strides(p, inputs[i]);
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[i]), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
char self[64];
|
||||
format_tensor_strides(self, node.dst());
|
||||
p += sprintf(p, "%s", self);
|
||||
format_tensor_strides(self, sizeof(self), node.dst());
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_types(char * str, const htp_opnode & node) {
|
||||
void format_op_types(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE");
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", ggml_type_name(node.dst()->type));
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", ggml_type_name(node.dst()->type)), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
const char * tensor_buff_name(const struct ggml_tensor * t) {
|
||||
@@ -221,51 +276,102 @@ struct htp_opformat {
|
||||
return "NONE";
|
||||
}
|
||||
|
||||
void format_op_buffs(char * str, const htp_opnode & node) {
|
||||
void format_op_buffs(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += sprintf(p, "%s", tensor_buff_name(inputs[0]));
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", tensor_buff_name(inputs[i]));
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[0])), (size_t)(p_end - p));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[i])), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", tensor_buff_name(node.dst()));
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(node.dst())), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_names(char * str, const htp_opnode & node) {
|
||||
void format_op_names(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE");
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? inputs[0]->name : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? inputs[i]->name : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", node.dst()->name);
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", node.dst()->name), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
|
||||
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
|
||||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
|
||||
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
|
||||
const char * path = "unknown";
|
||||
int32_t type = kparams->kernel_type;
|
||||
if (type == HTP_MM_KERNEL_HMX_2D || type == HTP_MM_KERNEL_HMX_F16_BATCHED) {
|
||||
path = "hmx-tiled";
|
||||
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || type == HTP_MM_KERNEL_HVX_F32_F32_VTCM ||
|
||||
type == HTP_MM_KERNEL_HVX_QUANT_ROW || type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) {
|
||||
path = "hvx-tiled";
|
||||
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_DDR || type == HTP_MM_KERNEL_HVX_F16_F32_DDR ||
|
||||
type == HTP_MM_KERNEL_HVX_F32_F32_DDR || type == HTP_MM_KERNEL_HVX_F32_F16_DDR ||
|
||||
type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) {
|
||||
path = "hvx-flat";
|
||||
}
|
||||
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
|
||||
} else {
|
||||
snprintf(str, max_size, "----");
|
||||
}
|
||||
}
|
||||
|
||||
void format(const htp_opnode & node) {
|
||||
format_op_dims(dims, node);
|
||||
format_op_strides(strides, node);
|
||||
format_op_types(types, node);
|
||||
format_op_buffs(buffs, node);
|
||||
format_op_names(names, node);
|
||||
format_op_dims(dims, sizeof(dims), node);
|
||||
format_op_strides(strides, sizeof(strides), node);
|
||||
format_op_types(types, sizeof(types), node);
|
||||
format_op_buffs(buffs, sizeof(buffs), node);
|
||||
format_op_names(names, sizeof(names), node);
|
||||
format_kernel_params(kparams, sizeof(kparams), node);
|
||||
}
|
||||
|
||||
htp_opformat() {}
|
||||
htp_opformat() {
|
||||
strides[0] = '\0';
|
||||
dims[0] = '\0';
|
||||
types[0] = '\0';
|
||||
buffs[0] = '\0';
|
||||
names[0] = '\0';
|
||||
kparams[0] = '\0';
|
||||
}
|
||||
htp_opformat(const htp_opnode & node) { format(node); }
|
||||
};
|
||||
|
||||
|
||||
@@ -19,43 +19,9 @@ add_library(${HTP_LIB} SHARED
|
||||
htp_iface_skel.c
|
||||
worker-pool.c
|
||||
hex-dma.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
|
||||
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
# HMX acceleration: available on v73+ architectures
|
||||
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
|
||||
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
flash-attn-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
@@ -63,7 +29,6 @@ target_sources(${HTP_LIB} PRIVATE
|
||||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
flash-attn-ops.c
|
||||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
cpy-ops.c
|
||||
@@ -79,6 +44,17 @@ target_sources(${HTP_LIB} PRIVATE
|
||||
pad-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
install(TARGETS ${HTP_LIB})
|
||||
|
||||
@@ -3,7 +3,7 @@ if (HEXAGON_TOOLCHAIN_INCLUDED)
|
||||
endif()
|
||||
set(HEXAGON_TOOLCHAIN_INCLUDED true)
|
||||
|
||||
#Cross Compiling for Hexagon
|
||||
# Cross Compiling for Hexagon
|
||||
set(HEXAGON TRUE)
|
||||
set(CMAKE_SYSTEM_NAME QURT)
|
||||
set(CMAKE_SYSTEM_PROCESSOR Hexagon)
|
||||
@@ -14,7 +14,6 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||
set(CUSTOM_RUNELF_PATH "")
|
||||
|
||||
#To fix backward compatibility with EAI addon.
|
||||
if (NOT HEXAGON_SDK_ROOT)
|
||||
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
|
||||
endif()
|
||||
@@ -31,7 +30,6 @@ endif()
|
||||
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
|
||||
|
||||
#Get the Binary extension of the Hexagon Toolchain
|
||||
if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
|
||||
set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
|
||||
endif()
|
||||
@@ -48,12 +46,12 @@ set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
|
||||
HEXAGON_TOOLS_ROOT
|
||||
)
|
||||
|
||||
#QURT Related includes and linker flags
|
||||
# QURT Related includes and linker flags
|
||||
set(V_ARCH ${HEXAGON_ARCH})
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
|
||||
if( ${TREE} MATCHES PAKMAN )
|
||||
if (${TREE} MATCHES PAKMAN)
|
||||
set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
endif()
|
||||
message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
|
||||
@@ -83,11 +81,9 @@ set(QURT_START_LINK_LIBS
|
||||
)
|
||||
STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
|
||||
|
||||
set(QURT_END_LINK_LIBS
|
||||
${TARGET_DIR}/fini.o
|
||||
)
|
||||
set(QURT_END_LINK_LIBS ${TARGET_DIR}/fini.o)
|
||||
|
||||
#Non QURT related includes and linker flags
|
||||
# Non QURT related includes and linker flags
|
||||
|
||||
set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
|
||||
|
||||
@@ -99,8 +95,10 @@ if (NOT NO_WRAP_MEM_API)
|
||||
set(WRAP_MEMALIGN -Wl,--wrap=memalign)
|
||||
endif()
|
||||
|
||||
set(ARCH_FLAGS "-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -mhmx")
|
||||
|
||||
set(PIC_SHARED_LD_FLAGS
|
||||
-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
|
||||
${ARCH_FLAGS}
|
||||
-G0
|
||||
-fpic
|
||||
-Wl,-Bsymbolic
|
||||
@@ -120,13 +118,13 @@ STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
#System include paths
|
||||
# System include paths
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
|
||||
|
||||
#LLVM toolchain setup
|
||||
#Compiler paths, options and architecture
|
||||
# LLVM toolchain setup
|
||||
# Compiler paths, options and architecture
|
||||
set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
@@ -137,8 +135,8 @@ set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
|
||||
|
||||
#Compiler Options
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
# Compiler Options
|
||||
set(COMMON_FLAGS "${ARCH_FLAGS} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hmx-ops.h"
|
||||
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx);
|
||||
|
||||
// Must be multiple of 32
|
||||
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
|
||||
@@ -633,7 +634,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// HMX path: head_dim multiple of 64, F16 KV, and no sinks
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) {
|
||||
int ret = hmx_flash_attn_ext(octx);
|
||||
@@ -642,7 +642,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
}
|
||||
// VTCM too small or other failure -> fall through to HVX path
|
||||
}
|
||||
#endif
|
||||
|
||||
struct htp_fa_context factx;
|
||||
factx.octx = octx;
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
#ifndef HEX_COMMON_H
|
||||
#define HEX_COMMON_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef SIZE_MAX
|
||||
#define SIZE_MAX ((size_t)-1)
|
||||
#endif
|
||||
|
||||
#ifndef MAX
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
#ifndef MIN
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
static inline uint32_t hex_ceil_pow2(uint32_t x) {
|
||||
if (x <= 1) { return 1; }
|
||||
int p = 2;
|
||||
x--;
|
||||
while (x >>= 1) { p <<= 1; }
|
||||
return p;
|
||||
}
|
||||
|
||||
static inline size_t hmx_ceil_div(size_t num, size_t den) {
|
||||
return (num + den - 1) / den;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_up(size_t v, size_t align) {
|
||||
return hmx_ceil_div(v, align) * align;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_down(size_t v, size_t align) {
|
||||
return (v / align) * align;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
|
||||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
static inline size_t hex_smin(size_t a, size_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_swap_ptr(void ** p1, void ** p2) {
|
||||
void * t = *p1;
|
||||
*p1 = *p2;
|
||||
*p2 = t;
|
||||
}
|
||||
|
||||
static inline bool hex_mul_overflow(size_t a, size_t b, size_t *out) {
|
||||
if (a != 0 && b > SIZE_MAX / a) return true;
|
||||
*out = a * b;
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool hex_add_overflow(size_t a, size_t b, size_t *out) {
|
||||
if (a > SIZE_MAX - b) return true;
|
||||
*out = a + b;
|
||||
return false;
|
||||
}
|
||||
|
||||
#endif // HEX_COMMON_H
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <hexagon_types.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include "hex-utils.h"
|
||||
|
||||
#include "hex-profile.h"
|
||||
|
||||
@@ -127,13 +128,8 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src)
|
||||
return p;
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ < 73
|
||||
static const uint32_t dma_src_l2_bypass_on = 1;
|
||||
static const uint32_t dma_dst_l2_bypass_on = 0;
|
||||
#else
|
||||
static const uint32_t dma_src_l2_bypass_on = 1;
|
||||
static const uint32_t dma_dst_l2_bypass_on = 1;
|
||||
#endif
|
||||
|
||||
static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) {
|
||||
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
|
||||
|
||||
@@ -11,14 +11,7 @@
|
||||
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-dump.h"
|
||||
|
||||
#ifndef MAX
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
#ifndef MIN
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
#include "hex-common.h"
|
||||
|
||||
static inline uint64_t hex_get_cycles() {
|
||||
uint64_t cycles = 0;
|
||||
@@ -32,54 +25,6 @@ static inline uint64_t hex_get_pktcnt() {
|
||||
return pktcnt;
|
||||
}
|
||||
|
||||
static inline uint32_t hex_ceil_pow2(uint32_t x) {
|
||||
if (x <= 1) { return 1; }
|
||||
int p = 2;
|
||||
x--;
|
||||
while (x >>= 1) { p <<= 1; }
|
||||
return p;
|
||||
}
|
||||
|
||||
static inline size_t hmx_ceil_div(size_t num, size_t den) {
|
||||
return (num + den - 1) / den;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_up(size_t v, size_t align) {
|
||||
return hmx_ceil_div(v, align) * align;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_down(size_t v, size_t align) {
|
||||
return (v / align) * align;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
|
||||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
static inline size_t hex_smin(size_t a, size_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_swap_ptr(void ** p1, void ** p2) {
|
||||
void * t = *p1;
|
||||
*p1 = *p2;
|
||||
*p2 = t;
|
||||
}
|
||||
|
||||
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
|
||||
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
|
||||
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
|
||||
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
|
||||
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) {
|
||||
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
|
||||
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
|
||||
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
|
||||
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
|
||||
@@ -70,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
|
||||
+ k_dma_size * 2 // K DMA x2
|
||||
+ v_dma_size * 2 // V DMA x2
|
||||
+ k_tile_size * 1 // K tiles
|
||||
+ v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
|
||||
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
|
||||
+ s_tile_size * 2 // S + P
|
||||
+ d_tile_size * 1 // D (diagonal matrix)
|
||||
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
|
||||
@@ -290,7 +290,7 @@ static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) =
|
||||
|
||||
struct hmx_fa_context {
|
||||
const struct htp_ops_context * octx;
|
||||
bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
|
||||
bool pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
|
||||
uint32_t n_threads;
|
||||
|
||||
// Op parameters
|
||||
@@ -409,7 +409,7 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
|
||||
return;
|
||||
}
|
||||
|
||||
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
|
||||
__fp16 * v_tiles_dest = factx->pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
@@ -1312,13 +1312,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS);
|
||||
|
||||
const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc;
|
||||
const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
|
||||
const bool pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
|
||||
|
||||
// Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1
|
||||
const uint32_t n_threads = use_pipeline ? n_threads_init : 1;
|
||||
const uint32_t n_threads = pipeline ? n_threads_init : 1;
|
||||
|
||||
FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu",
|
||||
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget);
|
||||
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, pipeline, vtcm_budget);
|
||||
|
||||
// ======== Build context ========
|
||||
struct hmx_fa_context factx;
|
||||
@@ -1339,7 +1339,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
factx.n_kv_blocks = n_kv_blocks;
|
||||
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32);
|
||||
factx.use_pipeline = use_pipeline;
|
||||
factx.pipeline = pipeline;
|
||||
factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1);
|
||||
|
||||
// Extract op parameters (mutable during softcap adjustment, then stored as const in factx)
|
||||
@@ -1405,7 +1405,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
|
||||
factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes);
|
||||
factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
|
||||
if (use_pipeline) {
|
||||
if (pipeline) {
|
||||
factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
|
||||
} else {
|
||||
factx.vtcm_v_tiles[1] = NULL;
|
||||
@@ -1456,7 +1456,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
// ======== HMX lock strategy ========
|
||||
// Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend.
|
||||
// Fallback: main thread holds the lock (original behavior).
|
||||
if (!factx.use_pipeline) {
|
||||
if (!factx.pipeline) {
|
||||
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
|
||||
}
|
||||
|
||||
@@ -1550,7 +1550,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t k_src_stride = size_k_row_padded / sizeof(__fp16);
|
||||
const size_t v_src_stride = size_v_row_padded / sizeof(__fp16);
|
||||
|
||||
if (factx.use_pipeline) {
|
||||
if (factx.pipeline) {
|
||||
// ==================================================================
|
||||
// Pipeline path: HVX phases ‖ HMX queue worker
|
||||
// ==================================================================
|
||||
@@ -1780,7 +1780,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
|
||||
|
||||
// HMX: O_final = diag(1/l) @ O_prev
|
||||
if (factx.use_pipeline) {
|
||||
if (factx.pipeline) {
|
||||
on_job.o_curr = o_tile_curr;
|
||||
on_job.o_prev = o_tile_prev;
|
||||
on_job.d_tiles = factx.vtcm_d_tiles;
|
||||
@@ -1826,7 +1826,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
} // end KV head loop
|
||||
} // end batch loop
|
||||
|
||||
if (factx.use_pipeline) {
|
||||
if (factx.pipeline) {
|
||||
hmx_queue_suspend(ctx->hmx_queue);
|
||||
} else {
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,6 +0,0 @@
|
||||
// HMX operations compiled as a single translation unit.
|
||||
// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO.
|
||||
|
||||
#include "hmx-queue.c"
|
||||
#include "hmx-matmul-ops.c"
|
||||
#include "hmx-flash-attn-ops.c"
|
||||
@@ -1,88 +0,0 @@
|
||||
// HMX operation entry-point declarations.
|
||||
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
|
||||
|
||||
#ifndef HMX_OPS_H
|
||||
#define HMX_OPS_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
float *dst;
|
||||
const float *activation;
|
||||
const __fp16 *permuted_weight;
|
||||
int m;
|
||||
int k;
|
||||
int n;
|
||||
int act_stride;
|
||||
int weight_stride;
|
||||
int dst_stride;
|
||||
int ne02;
|
||||
int ne03;
|
||||
int ne12;
|
||||
int ne13;
|
||||
size_t src0_nb2;
|
||||
size_t src0_nb3;
|
||||
size_t src1_nb2;
|
||||
size_t src1_nb3;
|
||||
size_t dst_nb2;
|
||||
size_t dst_nb3;
|
||||
} hmx_matmul_f16_f32_batched_params_t;
|
||||
|
||||
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
|
||||
// act_stride: activation row stride in elements (= k for contiguous, or
|
||||
// nb[1]/sizeof(float) for permuted tensors like attention Q).
|
||||
// weight_stride: weight row stride in elements (= k for compact weights, or
|
||||
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
|
||||
int hmx_matmul_f16_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const __fp16 *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int act_stride,
|
||||
int weight_stride);
|
||||
|
||||
// Batched F16 wrapper over hmx_mat_mul_f16_f32.
|
||||
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
|
||||
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);
|
||||
|
||||
// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4)
|
||||
int hmx_matmul_2d_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int act_stride,
|
||||
int weight_stride,
|
||||
int weight_type);
|
||||
|
||||
struct mmid_row_mapping;
|
||||
|
||||
int hmx_matmul_id_2d_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int ne11,
|
||||
size_t act_nb1, size_t act_nb2,
|
||||
size_t dst_nb1, size_t dst_nb2,
|
||||
int weight_stride,
|
||||
int weight_type,
|
||||
const struct mmid_row_mapping *matrix_rows,
|
||||
int cur_a,
|
||||
int mapping_stride);
|
||||
|
||||
// HMX flash attention
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // HMX_OPS_H
|
||||
@@ -13,7 +13,9 @@
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef HTP_MAX_NTHREADS
|
||||
#define HTP_MAX_NTHREADS 10
|
||||
#endif
|
||||
#define HTP_MAX_MMAPS 16
|
||||
|
||||
// Memory mapping
|
||||
@@ -42,9 +44,13 @@ struct htp_ops_context {
|
||||
|
||||
enum htp_op_code op; // FIXME: rename to opcode
|
||||
int32_t op_params[HTP_OP_MAX_PARAMS];
|
||||
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS];
|
||||
|
||||
const struct htp_tensor * src[HTP_OP_MAX_INPUTS];
|
||||
const struct htp_tensor * dst;
|
||||
union {
|
||||
const struct htp_tensor * dst;
|
||||
const struct htp_tensor * dsts[HTP_OP_MAX_OUTPUTS];
|
||||
};
|
||||
|
||||
// TODO convert these to an array
|
||||
struct htp_spad src0_spad;
|
||||
@@ -87,13 +93,13 @@ struct htp_context {
|
||||
|
||||
struct htp_ops_context octx;
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap
|
||||
#endif
|
||||
};
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_matmul_qkv(struct htp_ops_context * octx);
|
||||
int op_matmul_ffn(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
|
||||
@@ -28,18 +28,19 @@ enum htp_data_type {
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
|
||||
// types used internally for repack, dyn.quant, etc
|
||||
HTP_TYPE_Q4_0x4x2 = 200,
|
||||
HTP_TYPE_Q4_1x4x2,
|
||||
HTP_TYPE_Q8_0x4x2,
|
||||
HTP_TYPE_MXFP4x4x2,
|
||||
HTP_TYPE_Q4_0_TILED = 200,
|
||||
HTP_TYPE_Q4_1_TILED,
|
||||
HTP_TYPE_Q8_0_TILED,
|
||||
HTP_TYPE_MXFP4_TILED,
|
||||
|
||||
HTP_TYPE_INVALID
|
||||
};
|
||||
|
||||
// Constats for internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
|
||||
#define QK_Q4_0_TILED 256 // 32x32 Q4_0 tiled layout
|
||||
#define QK_Q8_0_TILED 128 // 32x32 Q8_0 tiled layout
|
||||
#define QK_MXFP4_TILED 256 // 32x32 MXFP4 tiled layout
|
||||
|
||||
|
||||
|
||||
// Mask to enable various stages of the Ops.
|
||||
@@ -57,6 +58,8 @@ enum htp_op_code {
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_MUL_MAT_QKV,
|
||||
HTP_OP_MUL_MAT_FFN,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_RMS_NORM_MUL,
|
||||
HTP_OP_UNARY_SILU,
|
||||
@@ -99,7 +102,9 @@ enum htp_op_code {
|
||||
|
||||
#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS
|
||||
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
|
||||
#define HTP_OP_MAX_OUTPUTS 4
|
||||
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS
|
||||
#define HTP_OP_MAX_KERN_PARAMS 32
|
||||
|
||||
#define HTP_OP_MAX_BUFS 16
|
||||
#define HTP_OP_MAX_REQS 256
|
||||
@@ -142,8 +147,10 @@ struct htp_op_desc {
|
||||
uint32_t opcode; // GGML/HTP Op
|
||||
uint32_t flags; // Op flags
|
||||
int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm
|
||||
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; // generic blob for host-precomputed parameters
|
||||
uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices
|
||||
uint16_t dst; // Output tensor index
|
||||
uint16_t dst[HTP_OP_MAX_OUTPUTS]; // Output tensor indices
|
||||
uint16_t pad[2]; // padding to align to 64 bits
|
||||
};
|
||||
|
||||
#ifndef HTP_MAX_NTHREADS
|
||||
|
||||
@@ -11,12 +11,13 @@ struct htp_iface_pmu_conf {
|
||||
};
|
||||
|
||||
interface htp_iface : remote_handle64 {
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem);
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 n_hmx, in uint64 max_vmem);
|
||||
AEEResult stop();
|
||||
AEEResult mmap(in uint32 fd, in uint32 size);
|
||||
AEEResult munmap(in uint32 fd);
|
||||
AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu);
|
||||
AEEResult etm(in uint32 enable);
|
||||
AEEResult hwinfo(rout uint32 n_threads, rout uint32 n_hvx, rout uint32 n_hmx, rout uint64 vtcm_size);
|
||||
};
|
||||
|
||||
#endif /* HTP_IDL */
|
||||
|
||||
@@ -170,25 +170,7 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Q6_Vsf_equals_Vw is only available on v73+.*/
|
||||
#if __HVX_ARCH__ < 73
|
||||
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
|
||||
{
|
||||
HVX_Vector const vzero = Q6_V_vzero();
|
||||
HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
|
||||
HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
|
||||
HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
|
||||
HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
|
||||
HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
|
||||
HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
|
||||
return ret;
|
||||
}
|
||||
|
||||
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
|
||||
{
|
||||
return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
|
||||
// This looks complicated.
|
||||
@@ -305,4 +287,17 @@ static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
|
||||
#endif // __HVX_ARCH__ < 79
|
||||
|
||||
static inline HVX_Vector hvx_vec_load_act_tile(const uint8_t * y_q, uint32_t kt, HVX_Vector * v_act_all) {
|
||||
if (kt % 4 == 0) {
|
||||
*v_act_all = hvx_vmem(y_q + kt * 32);
|
||||
return *v_act_all;
|
||||
} else if (kt % 4 == 1) {
|
||||
return Q6_V_vror_VR(*v_act_all, 32);
|
||||
} else if (kt % 4 == 2) {
|
||||
return Q6_V_vror_VR(*v_act_all, 64);
|
||||
} else {
|
||||
return Q6_V_vror_VR(*v_act_all, 96);
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HVX_BASE_H */
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -361,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) {
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context);
|
||||
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) {
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp_queue_id, uint32_t n_hvx, uint32_t n_hmx, uint64_t max_vmem) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
|
||||
if (!ctx) {
|
||||
@@ -395,10 +395,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
ctx->hmx_enabled = use_hmx;
|
||||
ctx->hmx_enabled = n_hmx;
|
||||
ctx->hmx_queue = NULL;
|
||||
if (use_hmx) {
|
||||
if (n_hmx) {
|
||||
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
|
||||
if (ctx->hmx_queue) {
|
||||
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
|
||||
@@ -407,8 +406,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
ctx->hmx_enabled = false;
|
||||
}
|
||||
}
|
||||
FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx);
|
||||
#endif
|
||||
FARF(HIGH, "HMX %s (n_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", n_hmx);
|
||||
|
||||
qurt_sysenv_max_hthreads_t hw_threads;
|
||||
qurt_sysenv_get_max_hw_threads(&hw_threads);
|
||||
@@ -481,13 +479,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||
dma_queue_delete(ctx->dma[i]);
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
if (ctx->hmx_queue) {
|
||||
hmx_queue_delete(ctx->hmx_queue);
|
||||
ctx->hmx_queue = NULL;
|
||||
}
|
||||
ctx->hmx_enabled = false;
|
||||
#endif
|
||||
|
||||
vtcm_free(ctx);
|
||||
|
||||
@@ -500,6 +496,36 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult htp_iface_hwinfo(remote_handle64 handle, uint32_t * n_threads, uint32_t * n_hvx, uint32_t * n_hmx, uint64_t * vtcm_size) {
|
||||
(void)handle;
|
||||
if (!n_threads || !n_hvx || !n_hmx || !vtcm_size) {
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
qurt_sysenv_max_hthreads_t hw_threads;
|
||||
qurt_sysenv_get_max_hw_threads(&hw_threads);
|
||||
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
|
||||
|
||||
uint32_t n_hvx_val = hw_nhvx;
|
||||
if (n_hvx_val > hw_threads.max_hthreads) {
|
||||
n_hvx_val = hw_threads.max_hthreads;
|
||||
}
|
||||
if (n_hvx_val > HTP_MAX_NTHREADS) {
|
||||
n_hvx_val = HTP_MAX_NTHREADS;
|
||||
}
|
||||
|
||||
// for now we force n_threads == n_hvx
|
||||
*n_threads = n_hvx_val;
|
||||
*n_hvx = n_hvx_val;
|
||||
*n_hmx = 1;
|
||||
|
||||
uint32_t vtcm_sz = 8 * 1024 * 1024; // 8MB default fallback
|
||||
HAP_compute_res_query_VTCM(0, (unsigned int *)&vtcm_sz, NULL, NULL, NULL);
|
||||
*vtcm_size = vtcm_sz;
|
||||
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context) {
|
||||
// No errors expected on the DSP.
|
||||
FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
|
||||
@@ -554,6 +580,12 @@ static int execute_op(struct htp_ops_context * octx) {
|
||||
case HTP_OP_MUL_MAT_ID:
|
||||
return op_matmul_id(octx);
|
||||
|
||||
case HTP_OP_MUL_MAT_QKV:
|
||||
return op_matmul_qkv(octx);
|
||||
|
||||
case HTP_OP_MUL_MAT_FFN:
|
||||
return op_matmul_ffn(octx);
|
||||
|
||||
case HTP_OP_MUL:
|
||||
case HTP_OP_ADD:
|
||||
case HTP_OP_SUB:
|
||||
@@ -762,8 +794,9 @@ static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, str
|
||||
}
|
||||
}
|
||||
|
||||
static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
|
||||
static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
|
||||
memcpy(octx->op_params, op->params, sizeof(octx->op_params));
|
||||
memcpy(octx->kernel_params, op->kernel_params, sizeof(octx->kernel_params));
|
||||
octx->flags = op->flags;
|
||||
octx->op = op->opcode;
|
||||
|
||||
@@ -785,22 +818,41 @@ static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens,
|
||||
src->ne[0], src->ne[1], src->ne[3], src->ne[3]);
|
||||
}
|
||||
|
||||
// Prep output tensor
|
||||
struct htp_tensor *dst = tens + op->dst;
|
||||
// Prep output tensors
|
||||
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
|
||||
uint16_t dst_idx = op->dst[i];
|
||||
if (dst_idx == 0xffff) {
|
||||
octx->dsts[i] = NULL;
|
||||
continue;
|
||||
}
|
||||
struct htp_tensor *dst = tens + dst_idx;
|
||||
octx->dsts[i] = dst;
|
||||
|
||||
octx->dst = dst;
|
||||
FARF(HIGH, "prep-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, dst_idx, (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
|
||||
}
|
||||
|
||||
FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
|
||||
int status = execute_op(octx);
|
||||
|
||||
(void) execute_op(octx);
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.src = NULL;
|
||||
octx->src2_spad.src = NULL;
|
||||
octx->src3_spad.src = NULL;
|
||||
octx->dst_spad.src = NULL;
|
||||
|
||||
// flush buffers on output
|
||||
hex_l2flush((void *) dst->data, dst->size);
|
||||
dst->flags |= HTP_TENSOR_FLUSHED;
|
||||
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
|
||||
if (octx->dsts[i]) {
|
||||
struct htp_tensor *dst = (struct htp_tensor *)octx->dsts[i];
|
||||
hex_l2flush((void *) dst->data, dst->size);
|
||||
dst->flags |= HTP_TENSOR_FLUSHED;
|
||||
|
||||
FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
|
||||
FARF(HIGH, "post-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, op->dst[i], (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
#define DSPQUEUE_POLL_TIMEOUT_USEC 100
|
||||
@@ -892,20 +944,26 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
}
|
||||
}
|
||||
|
||||
int op_status = HTP_STATUS_OK;
|
||||
uint32_t op_wakeup = n_ops / 2; // half-way throgh the batch
|
||||
|
||||
for (uint32_t i=0; i < n_ops; i++) {
|
||||
struct profile_data prof;
|
||||
|
||||
if (i == (n_ops-1)) {
|
||||
// wake up the host before starting the last op
|
||||
if (i == op_wakeup) {
|
||||
dspqueue_write_early_wakeup_noblock(queue, 0, 0);
|
||||
}
|
||||
|
||||
profile_start(ctx->profiler, &prof);
|
||||
|
||||
proc_op_req(octx, tens, i, &ops[i]);
|
||||
op_status = proc_op_req(octx, tens, i, &ops[i]);
|
||||
|
||||
profile_stop(ctx->profiler, &prof);
|
||||
|
||||
if (op_status != HTP_STATUS_OK) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ctx->profiler) {
|
||||
pds[i].opcode = ops[i].opcode;
|
||||
pds[i].usecs = prof.usecs;
|
||||
@@ -919,7 +977,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
|
||||
struct htp_opbatch_rsp rsp;
|
||||
rsp.id = req.id;
|
||||
rsp.status = HTP_STATUS_OK;
|
||||
rsp.status = op_status;
|
||||
rsp.n_bufs = n_bufs;
|
||||
rsp.n_tensors = n_tens;
|
||||
rsp.n_ops = n_ops;
|
||||
|
||||
+2729
-4117
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,508 @@
|
||||
#ifndef HTP_MATMUL_OPS_H
|
||||
#define HTP_MATMUL_OPS_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include "htp-ops.h"
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// --- HMX Tile Constraints ---
|
||||
#define HTP_MM_HMX_TILE_N_COLS 32
|
||||
#define HTP_MM_HMX_TILE_N_ROWS 32
|
||||
#define HTP_MM_HMX_TILE_SIZE (32 * 32 * sizeof(__fp16)) // 2048 bytes
|
||||
#define HTP_MM_HMX_TILE_N_ELMS 1024
|
||||
#define HTP_MM_HMX_MIN_NROWS 4
|
||||
|
||||
// --- Weight Repacked Tile Sizes ---
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_0 576
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_1 640
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_Q8_0 1088
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_IQ4_NL 576
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_MXFP4 544
|
||||
|
||||
// --- Weight Repacked Aligned Tile Sizes ---
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0 640
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1 640
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0 1152
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_IQ4_NL 640
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4 640
|
||||
|
||||
// --- Activation Tiled Block Sizes (including padding) ---
|
||||
#define HTP_MM_ACT_TILE_SIZE_Q8_0 1152
|
||||
#define HTP_MM_ACT_TILE_SIZE_Q8_1 1280
|
||||
|
||||
#define HTP_MM_MAX_PREFETCH 16
|
||||
|
||||
// --- Solver Cost Model Penalty Weights (HMX-specific) ---
|
||||
#define HTP_MM_HMX_COST_W_DEQUANT 3 // cost penalty for quantized weight loading/dequantization
|
||||
#define HTP_MM_HMX_COST_A_CONVERT 2 // cost penalty for activation loading/conversion
|
||||
|
||||
// --- DMA Activation Transfer Configuration ---
|
||||
#define HTP_MM_DMA_ACT_ROWS_PER_STEP 2
|
||||
#define HTP_MM_DMA_ACT_MULTIPLIER 4
|
||||
|
||||
enum htp_mm_kernel_type {
|
||||
HTP_MM_KERNEL_UNSUPPORTED = 0,
|
||||
|
||||
// HMX paths
|
||||
HTP_MM_KERNEL_HMX_2D,
|
||||
HTP_MM_KERNEL_HMX_F16_BATCHED,
|
||||
|
||||
// HVX floating-point paths
|
||||
HTP_MM_KERNEL_HVX_F16_F16_VTCM,
|
||||
HTP_MM_KERNEL_HVX_F16_F16_DDR,
|
||||
HTP_MM_KERNEL_HVX_F16_F32_DDR,
|
||||
|
||||
HTP_MM_KERNEL_HVX_F32_F32_VTCM,
|
||||
HTP_MM_KERNEL_HVX_F32_F32_DDR,
|
||||
HTP_MM_KERNEL_HVX_F32_F16_DDR,
|
||||
|
||||
// HVX quantized paths
|
||||
HTP_MM_KERNEL_HVX_QUANT_ROW, // standard row-wise parallel quantization
|
||||
HTP_MM_KERNEL_HVX_QUANT_BLOCK, // parallel block-wise quantization
|
||||
HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT, // row-wise fallback flat quantization
|
||||
};
|
||||
|
||||
// Op-specific struct for precomputed matmul params
|
||||
struct htp_mm_kernel_params {
|
||||
int32_t kernel_type; // enum htp_mm_kernel_type
|
||||
int32_t pipeline; // 1 = pipelined execution, 0 = standard
|
||||
int32_t m_chunk; // Row chunk size (M chunk)
|
||||
int32_t n_chunk; // Col chunk size (N chunk)
|
||||
int32_t n_threads; // Number of threads to spawn
|
||||
int32_t n_act_threads; // Number of threads for activation preparation
|
||||
int32_t n_hmx; // 1 = use HMX, 0 = use HVX
|
||||
int32_t n_prefetch; // Prefetch lookahead buffers/rows in VTCM
|
||||
int32_t tile_size; // Weight tile size
|
||||
int32_t aligned_tile_size; // Aligned weight tile size (padded to 128)
|
||||
int32_t src1_row_size; // Row size for quantized activation
|
||||
int32_t vtcm_size; // Total required scratchpad size in VTCM
|
||||
int32_t vtcm_src0_size; // src0 scratchpad size in VTCM
|
||||
int32_t vtcm_src1_size; // src1 scratchpad size in VTCM
|
||||
int32_t vtcm_src2_size; // src2 scratchpad size in VTCM (fused only)
|
||||
int32_t vtcm_src3_size; // src3 scratchpad size in VTCM (fused only)
|
||||
int32_t vtcm_dst_size; // dst scratchpad size in VTCM
|
||||
|
||||
// Precomputed division values
|
||||
struct fastdiv_values div_ne12_ne1;
|
||||
struct fastdiv_values div_ne1;
|
||||
struct fastdiv_values div_r2;
|
||||
struct fastdiv_values div_r3;
|
||||
struct fastdiv_values div_ne11;
|
||||
};
|
||||
|
||||
#if defined(__cplusplus)
|
||||
static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
|
||||
#else
|
||||
_Static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
|
||||
#endif
|
||||
|
||||
struct mmid_row_mapping {
|
||||
uint32_t i1;
|
||||
uint32_t i2;
|
||||
};
|
||||
|
||||
// Search for optimal (mc, nc) chunk sizes within VTCM budget.
|
||||
static inline int htp_mm_hmx_compute_chunks(size_t vtcm_total,
|
||||
size_t overhead,
|
||||
size_t per_n_cost,
|
||||
size_t per_m_cost,
|
||||
size_t per_mn_cost,
|
||||
size_t m,
|
||||
size_t n,
|
||||
size_t m_block_cost,
|
||||
size_t n_block_cost,
|
||||
size_t * m_chunk_out,
|
||||
size_t * n_chunk_out,
|
||||
size_t * total_out) {
|
||||
if (m == 0 || n == 0) return -1;
|
||||
if (vtcm_total <= overhead) return -1;
|
||||
if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;
|
||||
|
||||
const size_t usable = vtcm_total - overhead;
|
||||
|
||||
size_t best_cost = SIZE_MAX;
|
||||
size_t best_mn = 0;
|
||||
size_t best_m = 0, best_n = 0;
|
||||
|
||||
const size_t n_max = hex_align_down((size_t)n, HTP_MM_HMX_TILE_N_COLS);
|
||||
for (size_t nc = n_max; nc >= HTP_MM_HMX_TILE_N_COLS; nc -= HTP_MM_HMX_TILE_N_COLS) {
|
||||
size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
|
||||
if (hex_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
|
||||
if (n_fixed >= usable) goto next_nc;
|
||||
|
||||
if (hex_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc;
|
||||
if (hex_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc;
|
||||
|
||||
{
|
||||
size_t remain = usable - n_fixed;
|
||||
size_t mc = remain / mc_denom;
|
||||
mc = hex_align_down(mc, HTP_MM_HMX_TILE_N_ROWS);
|
||||
mc = hex_smin(mc, m);
|
||||
|
||||
if (mc == 0) {
|
||||
goto next_nc;
|
||||
}
|
||||
|
||||
size_t mblocks = ((size_t) m + mc - 1) / mc;
|
||||
size_t nblocks = ((size_t) n + nc - 1) / nc;
|
||||
size_t cost = mblocks * m_block_cost + nblocks * n_block_cost;
|
||||
size_t mn = mc * nc;
|
||||
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
|
||||
best_cost = cost;
|
||||
best_mn = mn;
|
||||
best_m = mc;
|
||||
best_n = nc;
|
||||
}
|
||||
}
|
||||
|
||||
next_nc:
|
||||
if (nc == HTP_MM_HMX_TILE_N_COLS) break; // avoid size_t underflow
|
||||
}
|
||||
|
||||
if (best_m == 0 || best_n == 0) return -1;
|
||||
|
||||
// Compute exact total (with overflow checks)
|
||||
size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0;
|
||||
if (hex_mul_overflow(best_n, per_n_cost, &t0)) return -1;
|
||||
if (hex_mul_overflow(best_m, per_m_cost, &t1)) return -1;
|
||||
if (hex_mul_overflow(best_m, best_n, &mn)) return -1;
|
||||
if (hex_mul_overflow(mn, per_mn_cost, &t2)) return -1;
|
||||
if (hex_add_overflow(t0, t1, &total)) return -1;
|
||||
if (hex_add_overflow(total, t2, &total)) return -1;
|
||||
if (hex_add_overflow(total, overhead, &total)) return -1;
|
||||
|
||||
*m_chunk_out = best_m;
|
||||
*n_chunk_out = best_n;
|
||||
*total_out = total;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// --- Tile Size Helpers ---
|
||||
static inline uint32_t htp_mm_get_weight_tile_size(int weight_type) {
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_Q4_0;
|
||||
case HTP_TYPE_Q4_1:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_Q4_1;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_Q8_0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_MXFP4;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static inline uint32_t htp_mm_get_weight_aligned_tile_size(int weight_type) {
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0;
|
||||
case HTP_TYPE_Q4_1:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Activation/Row Size Helpers ---
|
||||
static inline size_t htp_mm_q8_0_tiled_row_size(uint32_t ne) {
|
||||
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
|
||||
const uint32_t nb_32 = ne_padded / 32;
|
||||
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_0;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_q8_1_tiled_row_size(uint32_t ne) {
|
||||
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
|
||||
const uint32_t nb_32 = ne_padded / 32;
|
||||
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_1;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_q8_0_flat_row_size(uint32_t ne) {
|
||||
const uint32_t quants_size = hex_align_up(ne, 128);
|
||||
const uint32_t num_scales = (ne + 31) / 32;
|
||||
const uint32_t scales_size = hex_align_up(num_scales * 2, 128);
|
||||
return quants_size + scales_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_q8_1_flat_row_size(uint32_t ne) {
|
||||
const uint32_t quants_size = hex_align_up(ne, 128);
|
||||
const uint32_t num_scales = (ne + 31) / 32;
|
||||
const uint32_t scales_size = hex_align_up(num_scales * 4, 128);
|
||||
return quants_size + scales_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_get_tiled_row_stride(int weight_type, uint32_t k) {
|
||||
uint32_t nb = (k + QK_Q4_0_TILED - 1) / QK_Q4_0_TILED;
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
case HTP_TYPE_Q4_1:
|
||||
case HTP_TYPE_Q8_0:
|
||||
case HTP_TYPE_MXFP4:
|
||||
return (size_t) nb * htp_mm_get_weight_tile_size(weight_type);
|
||||
case HTP_TYPE_F16:
|
||||
return (size_t) k * sizeof(__fp16);
|
||||
case HTP_TYPE_F32:
|
||||
return (size_t) k * sizeof(float);
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_round_up(size_t n, size_t m) {
|
||||
return ((n + m - 1) / m) * m;
|
||||
}
|
||||
|
||||
static inline bool htp_mm_hmx_pipeline(uint32_t m) {
|
||||
return m > 32;
|
||||
}
|
||||
|
||||
static inline void htp_mm_hmx_get_2d_chunk_costs(
|
||||
int wtype, uint32_t k, bool pipeline, uint32_t aligned_tile_size,
|
||||
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
|
||||
) {
|
||||
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
|
||||
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
|
||||
const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0;
|
||||
|
||||
*size_per_n_out = (pipeline ? 2 : 1) * (is_quant ? qweight_row_stride : row_stride) +
|
||||
(pipeline ? 2 * vec_dot_size : vec_dot_size);
|
||||
*size_per_m_out = vec_dot_size;
|
||||
*size_per_mn_out = (pipeline ? 2 : 1) * sizeof(uint16_t);
|
||||
}
|
||||
|
||||
static inline void htp_mm_hmx_get_batched_chunk_costs(
|
||||
uint32_t k, uint32_t group_size,
|
||||
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
|
||||
) {
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
*size_per_n_out = 3 * vec_dot_size;
|
||||
*size_per_m_out = group_size * vec_dot_size;
|
||||
*size_per_mn_out = sizeof(uint16_t);
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hmx_get_2d_vtcm_size(
|
||||
int wtype, uint32_t k, size_t mc, size_t nc, bool pipeline, uint32_t act_threads, uint32_t aligned_tile_size
|
||||
) {
|
||||
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
|
||||
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
|
||||
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
|
||||
const size_t act_f32_size = htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE);
|
||||
size_t weight_area_size = is_quant
|
||||
? htp_mm_round_up((nc / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE)
|
||||
: htp_mm_round_up(nc * row_stride, HTP_MM_HMX_TILE_SIZE);
|
||||
if (pipeline) {
|
||||
weight_area_size *= 2;
|
||||
}
|
||||
const size_t act_area_size = htp_mm_round_up(mc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t output_area_size = htp_mm_round_up(mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
|
||||
|
||||
size_t scratch0_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
size_t scratch1_size = pipeline ? scratch0_size : 0;
|
||||
size_t scratch2_size = pipeline ? output_area_size : 0;
|
||||
|
||||
return weight_area_size + act_area_size + act_f32_size + output_area_size +
|
||||
scratch0_size + scratch1_size + scratch2_size + 256;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hmx_get_batched_vtcm_size(
|
||||
int wtype, uint32_t k, size_t mc, size_t nc, uint32_t group_size, bool use_dma_activation, bool pipeline, uint32_t act_threads) {
|
||||
(void)wtype;
|
||||
(void)pipeline;
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
const size_t f32_scratch_size = use_dma_activation
|
||||
? htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0;
|
||||
|
||||
const size_t act_head_stride = mc * k;
|
||||
const size_t weight_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t act_area_size = htp_mm_round_up(group_size * act_head_stride * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t output_area_size = htp_mm_round_up(group_size * mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t scratch_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
|
||||
return weight_area_size + act_area_size + output_area_size +
|
||||
2 * scratch_area_size + 256 + f32_scratch_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hvx_get_vtcm_sizes(
|
||||
int kernel_type,
|
||||
int wtype,
|
||||
uint32_t ne10, // k
|
||||
uint32_t src1_nrows, // m_total (or act_nrows)
|
||||
uint32_t n_threads,
|
||||
size_t dst_row_size,
|
||||
size_t src0_row_size,
|
||||
size_t src1_row_size,
|
||||
uint32_t n_prefetch,
|
||||
size_t * vtcm_src0_size_out,
|
||||
size_t * vtcm_src1_size_out,
|
||||
size_t * vtcm_dst_size_out
|
||||
) {
|
||||
size_t vtcm_src0_size = 0;
|
||||
size_t vtcm_src1_size = 0;
|
||||
size_t vtcm_dst_size = 0;
|
||||
|
||||
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
|
||||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
|
||||
wtype == HTP_TYPE_MXFP4);
|
||||
|
||||
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
|
||||
const size_t dst_nrows = (src1_nrows > 1) ? 0 : 1;
|
||||
|
||||
switch (kernel_type) {
|
||||
case HTP_MM_KERNEL_HVX_F16_F16_VTCM: {
|
||||
size_t f16_src1_row_size = htp_mm_round_up(ne10 * 2, 128);
|
||||
vtcm_src1_size = htp_mm_round_up(f16_src1_row_size * src1_nrows, 256);
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_F16_F32_DDR:
|
||||
case HTP_MM_KERNEL_HVX_F16_F16_DDR:
|
||||
case HTP_MM_KERNEL_HVX_F32_F32_DDR:
|
||||
case HTP_MM_KERNEL_HVX_F32_F16_DDR: {
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size, 256) * n_threads;
|
||||
vtcm_src1_size = htp_mm_round_up(n_prefetch * src1_row_size, 256) * n_threads;
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_F32_F32_VTCM: {
|
||||
size_t f32_src1_row_size = htp_mm_round_up(ne10 * 4, 128);
|
||||
vtcm_src1_size = htp_mm_round_up(f32_src1_row_size * src1_nrows, 256);
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_QUANT_BLOCK:
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
*vtcm_src0_size_out = vtcm_src0_size;
|
||||
*vtcm_src1_size_out = vtcm_src1_size;
|
||||
*vtcm_dst_size_out = vtcm_dst_size;
|
||||
|
||||
return vtcm_src0_size + vtcm_src1_size + vtcm_dst_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
|
||||
int wtype,
|
||||
uint32_t ne10, // k
|
||||
uint32_t src1_nrows,
|
||||
uint32_t n_threads,
|
||||
size_t src0_row_size, // nb01
|
||||
uint32_t n_prefetch,
|
||||
size_t * vtcm_src0_size_out,
|
||||
size_t * vtcm_src1_size_out
|
||||
) {
|
||||
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
|
||||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
|
||||
wtype == HTP_TYPE_MXFP4);
|
||||
|
||||
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
|
||||
const size_t src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10)
|
||||
: htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
|
||||
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (src0_sz_per_thread < src1_row_size_padded) {
|
||||
src0_sz_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
||||
if (is_repack) {
|
||||
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
const uint32_t n_k_tiles = ne10 / 32;
|
||||
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
src0_sz_per_thread = repacked_vtcm_size;
|
||||
}
|
||||
|
||||
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
|
||||
|
||||
*vtcm_src0_size_out = vtcm_src0_size;
|
||||
*vtcm_src1_size_out = src1_sz;
|
||||
|
||||
return vtcm_src0_size + src1_sz;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // HTP_MATMUL_OPS_H
|
||||
@@ -14,8 +14,6 @@ Drivers_Dir = 13
|
||||
1 = %DiskId%
|
||||
|
||||
[SourceDisksFiles]
|
||||
libggml-htp-v68.so = 1
|
||||
libggml-htp-v69.so = 1
|
||||
libggml-htp-v73.so = 1
|
||||
libggml-htp-v75.so = 1
|
||||
libggml-htp-v79.so = 1
|
||||
@@ -28,8 +26,6 @@ ExcludeFromSelect = *
|
||||
CopyFiles=Drivers_Dir
|
||||
|
||||
[Drivers_Dir]
|
||||
libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
|
||||
@@ -10152,14 +10152,8 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
const int ne00 = src0 ? src0->ne[0] : 0;
|
||||
const int ne01 = src0 ? src0->ne[1] : 0;
|
||||
const int ne02 = src0 ? src0->ne[2] : 0;
|
||||
const int ne03 = src0 ? src0->ne[3] : 0;
|
||||
|
||||
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
|
||||
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
|
||||
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
|
||||
GGML_TENSOR_LOCALS(int, ne0, src0, ne);
|
||||
GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
|
||||
|
||||
const int nth = MIN(64, ne00);
|
||||
|
||||
@@ -10173,11 +10167,12 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float)*nth, NULL));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
@@ -24,6 +24,7 @@ kernel void kernel_norm(
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
@@ -43,7 +44,8 @@ kernel void kernel_norm(
|
||||
// parallel sum
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
sum[get_local_id(0)] += x[i00];
|
||||
// this kernel handles float, nb00/4 translates byte offset to element offset
|
||||
sum[get_local_id(0)] += x[i00*nb00/4];
|
||||
}
|
||||
// reduce
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -60,7 +62,8 @@ kernel void kernel_norm(
|
||||
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
y[i00] = x[i00] - mean;
|
||||
// this kernel handles float, nb00/4 translates byte offset to element offset
|
||||
y[i00] = x[i00*nb00/4] - mean;
|
||||
sum[get_local_id(0)] += y[i00] * y[i00];
|
||||
}
|
||||
|
||||
|
||||
@@ -699,6 +699,7 @@ struct vk_device_struct {
|
||||
|
||||
bool add_rms_fusion;
|
||||
uint32_t partials_binding_alignment;
|
||||
uint32_t max_nodes_per_submit;
|
||||
|
||||
bool shader_64b_indexing;
|
||||
|
||||
@@ -5878,6 +5879,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
||||
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
|
||||
|
||||
// Submit at least every 100 nodes, in case there are workloads without as much matmul.
|
||||
device->max_nodes_per_submit = 100;
|
||||
const char* GGML_VK_MAX_NODES_PER_SUBMIT = getenv("GGML_VK_MAX_NODES_PER_SUBMIT");
|
||||
if (GGML_VK_MAX_NODES_PER_SUBMIT != nullptr) {
|
||||
uint32_t max_nodes_per_submit = std::stoul(GGML_VK_MAX_NODES_PER_SUBMIT);
|
||||
device->max_nodes_per_submit = std::max(max_nodes_per_submit, 1u);
|
||||
}
|
||||
|
||||
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
|
||||
|
||||
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
||||
@@ -16173,8 +16182,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
|
||||
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
|
||||
// (and scaled down based on model size, so smaller models submit earlier).
|
||||
// Also submit at least every 100 nodes, in case there are workloads without as much matmul.
|
||||
int nodes_per_submit = 100;
|
||||
int submitted_nodes = 0;
|
||||
int submit_count = 0;
|
||||
uint64_t mul_mat_bytes = 0;
|
||||
@@ -16400,7 +16407,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
|
||||
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
|
||||
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
||||
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
||||
bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) ||
|
||||
(mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
||||
(i + ctx->num_additional_fused_ops >= last_node) ||
|
||||
(almost_ready && !ctx->almost_ready_fence_pending);
|
||||
|
||||
@@ -463,6 +463,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(Sf[r][c]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
|
||||
@@ -352,6 +352,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// Compute max across the row
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <future>
|
||||
#include <queue>
|
||||
#include <condition_variable>
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
@@ -34,6 +35,9 @@
|
||||
|
||||
std::mutex lock;
|
||||
std::vector<std::pair<std::string, std::string>> shader_fnames;
|
||||
// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the
|
||||
// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393)
|
||||
static std::atomic<bool> compile_failed{false};
|
||||
std::locale c_locale("C");
|
||||
|
||||
std::string GLSLC = "glslc";
|
||||
@@ -78,7 +82,7 @@ enum MatMulIdType {
|
||||
|
||||
namespace {
|
||||
|
||||
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
int execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
#ifdef _WIN32
|
||||
HANDLE stdout_read, stdout_write;
|
||||
HANDLE stderr_read, stderr_write;
|
||||
@@ -127,8 +131,11 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
|
||||
CloseHandle(stdout_read);
|
||||
CloseHandle(stderr_read);
|
||||
WaitForSingleObject(pi.hProcess, INFINITE);
|
||||
DWORD exit_code = 1;
|
||||
GetExitCodeProcess(pi.hProcess, &exit_code);
|
||||
CloseHandle(pi.hProcess);
|
||||
CloseHandle(pi.hThread);
|
||||
return (int)exit_code;
|
||||
#else
|
||||
int stdout_pipe[2];
|
||||
int stderr_pipe[2];
|
||||
@@ -175,7 +182,9 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
|
||||
|
||||
close(stdout_pipe[0]);
|
||||
close(stderr_pipe[0]);
|
||||
waitpid(pid, nullptr, 0);
|
||||
int status = 0;
|
||||
waitpid(pid, &status, 0);
|
||||
return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -372,13 +381,14 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
|
||||
execute_command(cmd, stdout_str, stderr_str);
|
||||
if (!stderr_str.empty()) {
|
||||
std::cerr << "cannot compile " << name << "\n\n";
|
||||
int exit_code = execute_command(cmd, stdout_str, stderr_str);
|
||||
if (exit_code != 0 || !stderr_str.empty()) {
|
||||
std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n";
|
||||
for (const auto& part : cmd) {
|
||||
std::cerr << part << " ";
|
||||
}
|
||||
std::cerr << "\n\n" << stderr_str << std::endl;
|
||||
compile_failed = true;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -398,6 +408,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
shader_fnames.push_back(std::make_pair(name, out_path));
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
|
||||
compile_failed = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1271,6 +1282,11 @@ int main(int argc, char** argv) {
|
||||
|
||||
process_shaders();
|
||||
|
||||
if (compile_failed) {
|
||||
std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
write_output_files();
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
|
||||
@@ -57,19 +57,25 @@ oppoll=
|
||||
opflt=
|
||||
[ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF"
|
||||
|
||||
opfuse=
|
||||
[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC"
|
||||
|
||||
vmem=
|
||||
[ "$VM" != "" ] && vmem="GGML_HEXAGON_VMEM=$VM"
|
||||
|
||||
mbuf=
|
||||
[ "$MB" != "" ] && mbuf="GGML_HEXAGON_MBUF=$MB"
|
||||
|
||||
mmsel=
|
||||
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $vmem $mbuf \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 1024 -fa on \
|
||||
|
||||
@@ -51,6 +51,12 @@ opqueue=
|
||||
oppoll=
|
||||
[ "$OP" != "" ] && oppoll="GGML_HEXAGON_OPPOLL=$OP"
|
||||
|
||||
opfuse=
|
||||
[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC"
|
||||
|
||||
mmsel=
|
||||
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
|
||||
|
||||
set -x
|
||||
|
||||
tool=$1; shift
|
||||
@@ -59,5 +65,5 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll ./$branch/bin/$tool $@ \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \
|
||||
"
|
||||
|
||||
@@ -26,7 +26,7 @@ COL_MAP = {
|
||||
}
|
||||
|
||||
op_pattern = re.compile(
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+:\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
)
|
||||
|
||||
trace_pattern = re.compile(
|
||||
@@ -93,9 +93,40 @@ def parse_log(file_path, pmu_index=None):
|
||||
+ int(ts_match.group('us'))
|
||||
)
|
||||
|
||||
op_match = op_pattern.search(line)
|
||||
if "|" in line and "profile-op" in line:
|
||||
parts = [p.strip() for p in line.split("|")]
|
||||
prefix = parts[0]
|
||||
prefix_match = re.search(r"profile-op\s+(?P<op_name>[A-Z_0-9+]+)", prefix)
|
||||
if not prefix_match:
|
||||
continue
|
||||
|
||||
if len(parts) == 7:
|
||||
dims, types, timings = parts[2], parts[3], parts[6]
|
||||
elif len(parts) == 6:
|
||||
dims, types, timings = parts[2], parts[3], parts[5]
|
||||
else:
|
||||
continue
|
||||
|
||||
timing_match = re.search(
|
||||
r"(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?",
|
||||
timings
|
||||
)
|
||||
if not timing_match:
|
||||
continue
|
||||
|
||||
op_match = timing_match
|
||||
op_name = prefix_match.group("op_name")
|
||||
else:
|
||||
op_match = op_pattern.search(line)
|
||||
if op_match:
|
||||
op_name = op_match.group('op_name')
|
||||
dims = op_match.group('dims').strip()
|
||||
types = op_match.group('types').strip()
|
||||
else:
|
||||
op_match = None
|
||||
|
||||
if op_match:
|
||||
pmu_raw = op_match.group('pmu')
|
||||
pmu_raw = op_match.group('pmu') if 'pmu' in op_match.groupdict() else None
|
||||
pmu_val = None
|
||||
if pmu_raw and pmu_index is not None:
|
||||
try:
|
||||
@@ -105,7 +136,7 @@ def parse_log(file_path, pmu_index=None):
|
||||
except (ValueError, IndexError):
|
||||
pmu_val = None
|
||||
|
||||
evt_raw = op_match.group('evt')
|
||||
evt_raw = op_match.group('evt') if 'evt' in op_match.groupdict() else None
|
||||
evt_val = None
|
||||
if evt_raw:
|
||||
try:
|
||||
@@ -122,9 +153,9 @@ def parse_log(file_path, pmu_index=None):
|
||||
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
|
||||
|
||||
current_op = {
|
||||
'name': op_match.group('op_name'),
|
||||
'dims': op_match.group('dims').strip(),
|
||||
'types': op_match.group('types').strip(),
|
||||
'name': op_name,
|
||||
'dims': dims,
|
||||
'types': types,
|
||||
'op_text': op_text,
|
||||
'usec': int(op_match.group('usec')),
|
||||
'cycles': int(op_match.group('cycles')),
|
||||
|
||||
@@ -12,7 +12,7 @@ from collections import defaultdict
|
||||
logger = logging.getLogger("ggml-hexagon-trace")
|
||||
|
||||
op_pattern = re.compile(
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+(?P<strides>[\d:x\s\->!]+)\s+:\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+(?P<strides>[\d:x\s\->!]+?)\s+:\s+(?:(?P<params>.*?)\s+:\s+)?(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
)
|
||||
|
||||
trace_pattern = re.compile(
|
||||
@@ -66,7 +66,40 @@ def parse_log(file_path):
|
||||
|
||||
for line in f:
|
||||
line_idx += 1
|
||||
op_match = op_pattern.search(line)
|
||||
if "|" in line and "profile-op" in line:
|
||||
parts = [p.strip() for p in line.split("|")]
|
||||
prefix = parts[0]
|
||||
prefix_match = re.search(r"profile-op\s+(?P<op_name>[A-Z_0-9+]+)", prefix)
|
||||
if not prefix_match:
|
||||
continue
|
||||
|
||||
if len(parts) == 7:
|
||||
dims, types, strides, params, timings = parts[2], parts[3], parts[4], parts[5], parts[6]
|
||||
elif len(parts) == 6:
|
||||
dims, types, strides, params, timings = parts[2], parts[3], parts[4], "", parts[5]
|
||||
else:
|
||||
continue
|
||||
|
||||
timing_match = re.search(
|
||||
r"(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?",
|
||||
timings
|
||||
)
|
||||
if not timing_match:
|
||||
continue
|
||||
|
||||
op_match = timing_match
|
||||
op_name = prefix_match.group("op_name")
|
||||
else:
|
||||
op_match = op_pattern.search(line)
|
||||
if op_match:
|
||||
op_name = op_match.group('op_name')
|
||||
dims = op_match.group('dims').strip() if op_match.group('dims') else ''
|
||||
types = op_match.group('types').strip() if op_match.group('types') else ''
|
||||
strides = op_match.group('strides').strip() if op_match.group('strides') else ''
|
||||
params = op_match.group('params').strip() if ('params' in op_match.groupdict() and op_match.group('params')) else ''
|
||||
else:
|
||||
op_match = None
|
||||
|
||||
if op_match:
|
||||
cycles_start_raw = op_match.group('start')
|
||||
unwrapped_cycles_start = None
|
||||
@@ -77,10 +110,11 @@ def parse_log(file_path):
|
||||
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
|
||||
|
||||
current_op = {
|
||||
'name': op_match.group('op_name'),
|
||||
'dims': op_match.group('dims').strip() if op_match.group('dims') else '',
|
||||
'types': op_match.group('types').strip() if op_match.group('types') else '',
|
||||
'strides': op_match.group('strides').strip() if op_match.group('strides') else '',
|
||||
'name': op_name,
|
||||
'dims': dims,
|
||||
'types': types,
|
||||
'strides': strides,
|
||||
'params': params,
|
||||
'op_text': op_text,
|
||||
'usec': int(op_match.group('usec')),
|
||||
'cycles': int(op_match.group('cycles')),
|
||||
@@ -397,6 +431,8 @@ def generate_perfetto_trace(filtered_ops, output_path):
|
||||
debug_annots.append(make_debug_annotation("line", int_val=op['line_num']))
|
||||
if 'strides' in op and op['strides']:
|
||||
debug_annots.append(make_debug_annotation("strides", string_val=op['strides']))
|
||||
if 'params' in op and op['params'] and op['params'] != '----':
|
||||
debug_annots.append(make_debug_annotation("params", string_val=op['params']))
|
||||
|
||||
# Slice Begin
|
||||
evt_begin = make_track_event(1, 2, name=f"{op['name']} ({op['dims']})", category="operator", debug_annotations=debug_annots)
|
||||
|
||||
+14
-4
@@ -190,7 +190,15 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
|
||||
auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
|
||||
auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
|
||||
|
||||
bx = ggml_concat(ctx0, conv, bx, 0);
|
||||
// causal prepends the state, non-causal pads symmetrically for a centered window
|
||||
if (hparams.causal_attn) {
|
||||
bx = ggml_concat(ctx0, conv, bx, 0);
|
||||
} else {
|
||||
const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2;
|
||||
auto * left = ggml_cont(ctx0,
|
||||
ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0]));
|
||||
bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0);
|
||||
}
|
||||
GGML_ASSERT(bx->ne[0] > conv->ne[0]);
|
||||
|
||||
// last d_conv columns is a new conv state
|
||||
@@ -266,10 +274,12 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
cur = build_lora_mm(model.output, cur, model.output_s);
|
||||
cb(cur, "result_output", -1);
|
||||
if (!cparams.embeddings) {
|
||||
cur = build_lora_mm(model.output, cur, model.output_s);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
res->t_logits = cur;
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
@@ -199,7 +199,6 @@ llama_build_and_test(test-jinja.cpp)
|
||||
llama_test(test-jinja NAME test-jinja-py ARGS -py LABEL python)
|
||||
llama_build_and_test(test-chat-auto-parser.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-json-partial.cpp)
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(
|
||||
test-peg-parser.cpp
|
||||
|
||||
@@ -8420,6 +8420,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_MXFP4, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1}));
|
||||
|
||||
|
||||
#if 0
|
||||
{
|
||||
// Test paths in OpenCL
|
||||
@@ -8594,6 +8599,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
|
||||
// gpt-oss issue with Vulkan mmq_id
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q4_0, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
|
||||
|
||||
for (ggml_type type_a : all_types) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, GGML_TYPE_F32, 4, 2, false, 64, 16, 3*ggml_blck_size(type_a)));
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
#include "common.h"
|
||||
#include "json-partial.h"
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
template <class T> static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << "Actual: " << actual << std::endl;
|
||||
std::cerr << std::flush;
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
|
||||
static void test_json_healing() {
|
||||
auto parse = [](const std::string & str) {
|
||||
std::cerr << "# Parsing: " << str << '\n';
|
||||
std::string::const_iterator it = str.begin();
|
||||
const auto end = str.end();
|
||||
common_json out;
|
||||
std::string healing_marker = "$llama.cpp.json$";
|
||||
if (common_json_parse(it, end, healing_marker, out)) {
|
||||
auto dump = out.json.dump();
|
||||
std::cerr << "Parsed: " << dump << '\n';
|
||||
std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
|
||||
std::string result;
|
||||
if (!out.healing_marker.json_dump_marker.empty()) {
|
||||
auto i = dump.find(out.healing_marker.json_dump_marker);
|
||||
if (i == std::string::npos) {
|
||||
throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
|
||||
}
|
||||
result = dump.substr(0, i);
|
||||
} else {
|
||||
result = dump;
|
||||
}
|
||||
std::cerr << "Result: " << result << '\n';
|
||||
if (string_starts_with(str, result)) {
|
||||
std::cerr << "Failure!\n";
|
||||
}
|
||||
// return dump;
|
||||
} else {
|
||||
throw std::runtime_error("Failed to parse: " + str);
|
||||
}
|
||||
|
||||
};
|
||||
auto parse_all = [&](const std::string & str) {
|
||||
for (size_t i = 1; i < str.size(); i++) {
|
||||
parse(str.substr(0, i));
|
||||
}
|
||||
};
|
||||
parse_all("{\"a\": \"b\"}");
|
||||
parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
|
||||
|
||||
parse_all("[{\"a\": \"b\"}]");
|
||||
|
||||
auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
|
||||
for (const auto & input : inputs) {
|
||||
common_json out;
|
||||
assert_equals(true, common_json_parse(input, "$foo", out));
|
||||
assert_equals<std::string>(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true));
|
||||
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
|
||||
}
|
||||
};
|
||||
// No healing needed:
|
||||
test(
|
||||
{
|
||||
R"([{"a":"b"}, "y"])",
|
||||
},
|
||||
R"([{"a":"b"},"y"])",
|
||||
""
|
||||
);
|
||||
// Partial literals can't be healed:
|
||||
test(
|
||||
{
|
||||
R"([1)",
|
||||
R"([tru)",
|
||||
R"([n)",
|
||||
R"([nul)",
|
||||
R"([23.2)",
|
||||
},
|
||||
R"(["$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a": 1)",
|
||||
R"({"a": tru)",
|
||||
R"({"a": n)",
|
||||
R"({"a": nul)",
|
||||
R"({"a": 23.2)",
|
||||
},
|
||||
R"({"a":"$foo"})",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({)",
|
||||
},
|
||||
R"({"$foo":1})",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([)",
|
||||
},
|
||||
R"(["$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
// Healing right after a full literal
|
||||
test(
|
||||
{
|
||||
R"(1 )",
|
||||
},
|
||||
R"(1)",
|
||||
""
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"(true)",
|
||||
R"(true )",
|
||||
},
|
||||
R"(true)",
|
||||
""
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"(null)",
|
||||
R"(null )",
|
||||
},
|
||||
R"(null)",
|
||||
""
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([1 )",
|
||||
},
|
||||
R"([1,"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([{})",
|
||||
R"([{} )",
|
||||
},
|
||||
R"([{},"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([true)",
|
||||
},
|
||||
// TODO: detect the true/false/null literal was complete
|
||||
R"(["$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([true )",
|
||||
},
|
||||
R"([true,"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([true,)",
|
||||
},
|
||||
R"([true,"$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
// Test nesting
|
||||
test(
|
||||
{
|
||||
R"([{"a": [{"b": [{)",
|
||||
},
|
||||
R"([{"a":[{"b":[{"$foo":1}]}]}])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([{"a": [{"b": [)",
|
||||
},
|
||||
R"([{"a":[{"b":["$foo"]}]}])",
|
||||
R"("$foo)"
|
||||
);
|
||||
|
||||
test(
|
||||
{
|
||||
R"([{"a": "b"})",
|
||||
R"([{"a": "b"} )",
|
||||
},
|
||||
R"([{"a":"b"},"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([{"a": "b"},)",
|
||||
R"([{"a": "b"}, )",
|
||||
},
|
||||
R"([{"a":"b"},"$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "code)",
|
||||
},
|
||||
R"({"code$foo":1})",
|
||||
R"($foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "code\)",
|
||||
},
|
||||
R"({"code\\$foo":1})",
|
||||
R"(\$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "code")",
|
||||
},
|
||||
R"({"code":"$foo"})",
|
||||
R"(:"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "key")",
|
||||
},
|
||||
R"({"key":"$foo"})",
|
||||
R"(:"$foo)"
|
||||
);
|
||||
// Test unicode escape sequences
|
||||
test(
|
||||
{
|
||||
R"({"a":"\u)",
|
||||
},
|
||||
R"({"a":"\u0000$foo"})",
|
||||
R"(0000$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\u00)",
|
||||
},
|
||||
R"({"a":"\u0000$foo"})",
|
||||
R"(00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud300)",
|
||||
},
|
||||
R"({"a":"\ud300$foo"})",
|
||||
R"($foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(\udc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(udc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\u)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(dc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\udc00)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"($foo)"
|
||||
);
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_json_healing();
|
||||
std::cerr << "All tests passed.\n";
|
||||
return 0;
|
||||
}
|
||||
@@ -9,6 +9,7 @@ its output, and holds them against the HF model's scores.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import unicodedata
|
||||
@@ -28,6 +29,12 @@ class ModelSpec:
|
||||
mmproj_arg: str
|
||||
model_default: str
|
||||
mmproj_default: str
|
||||
prompt: str = "Free OCR. "
|
||||
n_predict: int = 512
|
||||
n_ctx: int | None = None
|
||||
# Unlimited-OCR's "document parsing" prompt emits <|det|> grounding markup that
|
||||
# the HF reference strips in result.md; drop it before scoring to match.
|
||||
strip_grounding: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -63,6 +70,20 @@ MODELS = {
|
||||
model_default="gguf_models/deepseek-ai/deepseek-ocr-2-bf16.gguf",
|
||||
mmproj_default="gguf_models/deepseek-ai/mmproj-deepseek-ocr-2-bf16.gguf",
|
||||
),
|
||||
"unlimited": ModelSpec(
|
||||
key="unlimited", label="Unlimited-OCR",
|
||||
model_arg="--llama-model-unlimited", mmproj_arg="--mmproj-unlimited",
|
||||
model_default="gguf_models/baidu/unlimited-ocr-bf16.gguf",
|
||||
mmproj_default="gguf_models/baidu/mmproj-unlimited-ocr-bf16.gguf",
|
||||
# "Free OCR." immediately emits EOS on this checkpoint; the HF reference
|
||||
# (demo/unlimited_ocr_scores.py) uses "document parsing.", which grounds.
|
||||
prompt="document parsing.",
|
||||
# Grounding emits ~3x the tokens of plain OCR, so it needs a larger budget
|
||||
# and context to reach the article body the ground truth covers.
|
||||
n_predict=4096,
|
||||
n_ctx=16384,
|
||||
strip_grounding=True,
|
||||
),
|
||||
}
|
||||
|
||||
CASES = [
|
||||
@@ -82,9 +103,26 @@ CASES = [
|
||||
# is one pixel off and lands at ~0.69 instead.
|
||||
hf_cer=0.7761, hf_chrf=28.70, cer_tol=0.12, chrf_tol=8.0,
|
||||
),
|
||||
TestCase(
|
||||
model_key="unlimited", label="single-view scan",
|
||||
image="tools/mtmd/test-1.jpeg",
|
||||
ground_truth="tools/mtmd/tests/test-1-ground-truth.txt",
|
||||
# HF reference: Unlimited-OCR scoring (gundam, bf16) on this image/ground-truth.
|
||||
# Decoder runs full MHA, not R-SWA; the band absorbs that gap + bf16 variance.
|
||||
hf_cer=0.1869, hf_chrf=75.23, cer_tol=0.06, chrf_tol=6.0,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GROUNDING_TAG_RE = re.compile(r"<\|(ref|det)\|>.*?<\|/\1\|>", re.DOTALL)
|
||||
|
||||
|
||||
def strip_grounding(text: str) -> str:
|
||||
"""Drop <|ref|>..<|/ref|> / <|det|>..<|/det|> grounding markup, matching the
|
||||
cleaned result.md the HF reference scores against."""
|
||||
return GROUNDING_TAG_RE.sub("", text)
|
||||
|
||||
|
||||
def arg_dest(flag: str) -> str:
|
||||
return flag.lstrip("-").replace("-", "_")
|
||||
|
||||
@@ -129,19 +167,19 @@ def compute_chrf(expected: str, ocr_out: str) -> float:
|
||||
return CHRF().sentence_score(ocr_out, [expected]).score
|
||||
|
||||
|
||||
def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
|
||||
def run_mtmd_cli(spec: "ModelSpec", model_path, mmproj_path, image_path, bin_path) -> str:
|
||||
"""Run mtmd-cli on the image and return its output."""
|
||||
cmd = [
|
||||
str(bin_path),
|
||||
"-m", str(model_path),
|
||||
"--mmproj", str(mmproj_path),
|
||||
"--image", str(image_path),
|
||||
"-p", "Free OCR. ",
|
||||
"-p", spec.prompt,
|
||||
"--chat-template", "deepseek-ocr",
|
||||
"--temp", "0",
|
||||
"--flash-attn", "off", # match the HF "eager" attention reference
|
||||
"--no-warmup",
|
||||
"-n", "512", # cap loops on hard images (KV would otherwise fill)
|
||||
"-n", str(spec.n_predict), # cap loops on hard images (KV would otherwise fill)
|
||||
# HF decodes with no_repeat_ngram_size; llama.cpp's analog is DRY.
|
||||
# Default DRY breakers include "\n", so they are cleared below.
|
||||
"--dry-multiplier", "0.8",
|
||||
@@ -150,6 +188,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
|
||||
"--dry-penalty-last-n", "-1",
|
||||
"--dry-sequence-breaker", "none",
|
||||
]
|
||||
if spec.n_ctx is not None:
|
||||
cmd += ["-c", str(spec.n_ctx)]
|
||||
logger.debug(f" command: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
@@ -164,6 +204,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
|
||||
raise RuntimeError(f"llama-mtmd-cli failed with code {result.returncode}")
|
||||
|
||||
output = result.stdout.decode("utf-8", errors="replace").strip()
|
||||
if spec.strip_grounding:
|
||||
output = strip_grounding(output)
|
||||
if not output:
|
||||
raise RuntimeError("llama-mtmd-cli produced no output on stdout")
|
||||
logger.info(f" output: {len(output)} chars")
|
||||
@@ -193,7 +235,7 @@ def evaluate(case: "TestCase", expected: str, ocr_out: str) -> bool:
|
||||
|
||||
logger.info("")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Free OCR evaluation:")
|
||||
logger.info("OCR evaluation:")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f" CER {cer:>7.4f} (HF {case.hf_cer:.4f}, <= {case.cer_max:>7.4f} -> {verdict(cer_pass)})")
|
||||
logger.info(f" chrF (0-100) {chrf:>7.2f} (HF {case.hf_chrf:.2f}, >= {case.chrf_min:>7.2f} -> {verdict(chrf_pass)})")
|
||||
@@ -269,9 +311,9 @@ def main() -> int:
|
||||
expected = read_expected_text(ground_truth)
|
||||
logger.info(f" Image: {case.image}")
|
||||
logger.info(f" Expected text: {len(expected)} chars")
|
||||
logger.info(" Running llama.cpp 'Free OCR'")
|
||||
logger.info(f" Running llama.cpp prompt {model_spec.prompt!r}")
|
||||
try:
|
||||
ocr_out = run_mtmd_cli(model, mmproj, image, binary)
|
||||
ocr_out = run_mtmd_cli(model_spec, model, mmproj, image, binary)
|
||||
except RuntimeError as e:
|
||||
logger.error(f" Error: {e}")
|
||||
results[title] = False
|
||||
|
||||
@@ -40,6 +40,7 @@ struct debug_options {
|
||||
bool enable_reasoning = true;
|
||||
bool debug_jinja = false;
|
||||
bool force_tool_call = false;
|
||||
bool parallel_tool_calls = true;
|
||||
output_mode mode = output_mode::BOTH;
|
||||
input_message_type input_message = input_message_type::NONE;
|
||||
};
|
||||
@@ -87,6 +88,7 @@ static void print_usage(const char * program_name) {
|
||||
LOG_ERR("\nOptions:\n");
|
||||
LOG_ERR(" --no-tools Disable tool definitions\n");
|
||||
LOG_ERR(" --force-tool-call Set tool calls to forced\n");
|
||||
LOG_ERR(" --parallel-tool-calls=0|1 Set parallel_tool_calls (default: 1)\n");
|
||||
LOG_ERR(" --generation-prompt=0|1 Set add_generation_prompt (default: 1)\n");
|
||||
LOG_ERR(" --enable-reasoning=0|1 Enable reasoning parsing (default: 1)\n");
|
||||
LOG_ERR(" --output=MODE Output mode: analysis, template, both (default: both)\n");
|
||||
@@ -121,6 +123,8 @@ static bool parse_options(int argc, char ** argv, debug_options & opts) {
|
||||
opts.debug_jinja = true;
|
||||
} else if (arg == "--no-tools") {
|
||||
opts.with_tools = false;
|
||||
} else if (arg.rfind("--parallel-tool-calls=", 0) == 0) {
|
||||
opts.parallel_tool_calls = parse_bool_option(arg.substr(22));
|
||||
} else if (arg.rfind("--generation-prompt=", 0) == 0) {
|
||||
opts.generation_prompt = parse_bool_option(arg.substr(20));
|
||||
} else if (arg.rfind("--enable-reasoning=", 0) == 0) {
|
||||
@@ -349,7 +353,7 @@ static autoparser::generation_params prepare_params(const debug_options & opts,
|
||||
params.tools = json();
|
||||
params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
}
|
||||
params.parallel_tool_calls = false;
|
||||
params.parallel_tool_calls = opts.parallel_tool_calls;
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
+1
-2
@@ -28,10 +28,9 @@ vite.config.ts.timestamp-*
|
||||
# PWA Artifacts
|
||||
apple-splash-*.png
|
||||
apple-touch-icon-*.png
|
||||
favicon.ico
|
||||
favicon-dark.ico
|
||||
maskable-icon-*.png
|
||||
pwa-*.png
|
||||
static/favicon*
|
||||
|
||||
# Storybook
|
||||
*storybook.log
|
||||
|
||||
Generated
+7
-7
@@ -35,7 +35,7 @@
|
||||
"bits-ui": "2.18.1",
|
||||
"clsx": "2.1.1",
|
||||
"dexie": "4.4.3",
|
||||
"dompurify": "3.4.5",
|
||||
"dompurify": "3.4.11",
|
||||
"eslint": "9.39.4",
|
||||
"eslint-config-prettier": "10.1.8",
|
||||
"eslint-plugin-storybook": "10.4.2",
|
||||
@@ -8653,9 +8653,9 @@
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/dompurify": {
|
||||
"version": "3.4.5",
|
||||
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.5.tgz",
|
||||
"integrity": "sha512-OrwIBKsdNSVEeubdJ1HBv/wNENRM9ytAVCv7YXt//A3vPdVMNuACRqK9mXCGCBW2ln7BT/A4X0jXHo2Gu89miA==",
|
||||
"version": "3.4.11",
|
||||
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.11.tgz",
|
||||
"integrity": "sha512-zhlUV12GsaRzMsf9q5M254YhA4+VuF0fG+QFqu6aYpoGlKtz+w8//jBcGVYBgQkR5GHjUomejY84AV+/uPbWdw==",
|
||||
"dev": true,
|
||||
"license": "(MPL-2.0 OR Apache-2.0)",
|
||||
"optionalDependencies": {
|
||||
@@ -10226,9 +10226,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.23",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.23.tgz",
|
||||
"integrity": "sha512-eIaZ9qDgu7XV0pxOCrg7/WhnQ6Ivm22UcxhXx/A3dcbqbbYgBEkc6e/J/s7j2tS96zoB0S9VBdLwQNCWwUo4LA==",
|
||||
"version": "4.12.26",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.26.tgz",
|
||||
"integrity": "sha512-uyZtpnYxM9CmQ7QsQknM4zN8EftNqhON1qYeIKM0Se67CCEe2c44xyGURwB0axX2fBDu1dqHrHAc1hmNT8ITkw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
"bits-ui": "2.18.1",
|
||||
"clsx": "2.1.1",
|
||||
"dexie": "4.4.3",
|
||||
"dompurify": "3.4.5",
|
||||
"dompurify": "3.4.11",
|
||||
"eslint": "9.39.4",
|
||||
"eslint-config-prettier": "10.1.8",
|
||||
"eslint-plugin-storybook": "10.4.2",
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
import { defineConfig } from '@vite-pwa/assets-generator/config';
|
||||
import { FAVICON_COLORS, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
|
||||
import { writeThemeFavicons } from './scripts/favicon-colorize';
|
||||
|
||||
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
});
|
||||
|
||||
export default defineConfig({
|
||||
headLinkOptions: {
|
||||
@@ -7,7 +13,8 @@ export default defineConfig({
|
||||
preset: {
|
||||
transparent: {
|
||||
sizes: [],
|
||||
favicons: [[48, 'favicon-dark.ico']]
|
||||
favicons: [[48, 'favicon-dark.ico']],
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
},
|
||||
maskable: {
|
||||
sizes: []
|
||||
|
||||
@@ -5,15 +5,32 @@ import {
|
||||
} from '@vite-pwa/assets-generator/config';
|
||||
import { readFileSync } from 'node:fs';
|
||||
import { resolve } from 'node:path';
|
||||
import { THEME_COLORS, PWA_GENERATOR_DEVICES, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
|
||||
import {
|
||||
THEME_COLORS,
|
||||
PWA_GENERATOR_DEVICES,
|
||||
PWA_ASSET_GENERATOR,
|
||||
FAVICON_COLORS
|
||||
} from './src/lib/constants/pwa';
|
||||
import { SplashOrientation } from './src/lib/enums/splash.enums';
|
||||
import { writeThemeFavicons } from './scripts/favicon-colorize';
|
||||
|
||||
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
});
|
||||
|
||||
export default defineConfig({
|
||||
headLinkOptions: {
|
||||
preset: PWA_ASSET_GENERATOR.LINK_PRESET
|
||||
},
|
||||
preset: combinePresetAndAppleSplashScreens(
|
||||
minimal2023Preset,
|
||||
{
|
||||
...minimal2023Preset,
|
||||
// tiny margin so favicon.ico / pwa-*.png breathe inside the canvas
|
||||
transparent: {
|
||||
...minimal2023Preset.transparent,
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
}
|
||||
},
|
||||
{
|
||||
padding: PWA_ASSET_GENERATOR.SPLASH_PADDING,
|
||||
resizeOptions: {
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
import { mkdirSync, readFileSync, writeFileSync } from 'node:fs';
|
||||
import { dirname, resolve } from 'node:path';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
|
||||
const HERE = dirname(fileURLToPath(import.meta.url));
|
||||
const PROJECT_ROOT = resolve(HERE, '..');
|
||||
|
||||
const DEFAULT_LOGO = resolve(PROJECT_ROOT, 'src/lib/assets/logo.svg');
|
||||
const DEFAULT_OUT_DIR = resolve(PROJECT_ROOT, 'static');
|
||||
const DEFAULT_OUT_LIGHT = resolve(DEFAULT_OUT_DIR, 'favicon.svg');
|
||||
const DEFAULT_OUT_DARK = resolve(DEFAULT_OUT_DIR, 'favicon-dark.svg');
|
||||
|
||||
const CURRENT_COLOR = 'currentColor';
|
||||
|
||||
export interface ColorizedFavicon {
|
||||
light: string;
|
||||
dark: string;
|
||||
}
|
||||
|
||||
export interface WriteThemeFaviconsOptions {
|
||||
sourcePath?: string;
|
||||
lightOutPath?: string;
|
||||
darkOutPath?: string;
|
||||
/**
|
||||
* Fraction of the icon (0..1) to leave as an even margin on each side.
|
||||
* Applied by wrapping the inner content in a `<g transform="...">` so the
|
||||
* source `src/lib/assets/logo.svg` is not modified. Pass 0 to disable.
|
||||
*/
|
||||
padding?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace every `currentColor` occurrence in the SVG with the given color.
|
||||
* Pure: no filesystem access, so it is straightforward to unit-test.
|
||||
*/
|
||||
export function colorizeFaviconSvg(
|
||||
svg: string,
|
||||
lightColor: string,
|
||||
darkColor: string
|
||||
): ColorizedFavicon {
|
||||
return {
|
||||
light: svg.replaceAll(CURRENT_COLOR, lightColor),
|
||||
dark: svg.replaceAll(CURRENT_COLOR, darkColor)
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Shrink the inner SVG content uniformly and re-center it so `padding` (a
|
||||
* 0..1 fraction) is reserved as equal margin on each side. Returns the input
|
||||
* unchanged for non-positive padding, missing/invalid `viewBox`, or unexpected
|
||||
* markup so the caller always gets a renderable SVG.
|
||||
*/
|
||||
export function padFaviconSvg(svg: string, padding: number): string {
|
||||
if (!(padding > 0) || padding >= 1) return svg;
|
||||
|
||||
const viewBoxMatch = svg.match(/viewBox\s*=\s*["']([^"']+)["']/i);
|
||||
if (!viewBoxMatch) return svg;
|
||||
|
||||
const parts = viewBoxMatch[1]
|
||||
.trim()
|
||||
.split(/[\s,]+/)
|
||||
.map(Number);
|
||||
if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) return svg;
|
||||
|
||||
const [, , width, height] = parts;
|
||||
if (width <= 0 || height <= 0) return svg;
|
||||
|
||||
const scale = 1 - padding;
|
||||
const translateX = (padding * width) / 2;
|
||||
const translateY = (padding * height) / 2;
|
||||
|
||||
const openTagStart = svg.search(/<svg\b/i);
|
||||
if (openTagStart === -1) return svg;
|
||||
const openTagEnd = svg.indexOf('>', openTagStart);
|
||||
if (openTagEnd === -1) return svg;
|
||||
const closeStart = svg.lastIndexOf('</svg');
|
||||
if (closeStart === -1 || closeStart <= openTagEnd) return svg;
|
||||
|
||||
const openTag = svg.slice(0, openTagEnd + 1);
|
||||
const inner = svg.slice(openTagEnd + 1, closeStart);
|
||||
const closeTag = svg.slice(closeStart);
|
||||
|
||||
const group = `<g transform="translate(${translateX} ${translateY}) scale(${scale})">`;
|
||||
return `${openTag}${group}${inner}</g>${closeTag}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read `src/lib/assets/logo.svg`, colorize it for both themes, and write
|
||||
* the results to the static directory so the PWA asset generator can consume
|
||||
* them. Paths can be overridden for tests.
|
||||
*/
|
||||
export function writeThemeFavicons(
|
||||
lightColor: string,
|
||||
darkColor: string,
|
||||
{
|
||||
sourcePath = DEFAULT_LOGO,
|
||||
lightOutPath = DEFAULT_OUT_LIGHT,
|
||||
darkOutPath = DEFAULT_OUT_DARK,
|
||||
padding = 0
|
||||
}: WriteThemeFaviconsOptions = {}
|
||||
): void {
|
||||
const source = readFileSync(sourcePath, 'utf-8');
|
||||
const { light, dark } = colorizeFaviconSvg(source, lightColor, darkColor);
|
||||
mkdirSync(dirname(lightOutPath), { recursive: true });
|
||||
writeFileSync(lightOutPath, padFaviconSvg(light, padding));
|
||||
writeFileSync(darkOutPath, padFaviconSvg(dark, padding));
|
||||
}
|
||||
@@ -48,6 +48,7 @@
|
||||
|
||||
--chat-form-area-height: 8rem;
|
||||
--chat-form-area-offset: 2rem;
|
||||
--chat-form-padding-top: 6rem;
|
||||
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
|
||||
}
|
||||
|
||||
@@ -55,6 +56,7 @@
|
||||
:root {
|
||||
--chat-form-area-height: 24rem;
|
||||
--chat-form-area-offset: 12rem;
|
||||
--chat-form-padding-top: 6rem;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +143,6 @@
|
||||
@apply bg-background text-foreground;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-gutter: stable;
|
||||
overflow: hidden; /* Added due to Mermaid rendering somehow causing the double scrollbar */
|
||||
}
|
||||
|
||||
/* Global scrollbar styling - visible only on hover */
|
||||
@@ -193,3 +194,7 @@
|
||||
scrollbar-width: none;
|
||||
}
|
||||
}
|
||||
|
||||
.mermaidTooltip {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
@@ -10,9 +10,9 @@ import { isElementInViewport } from '$lib/utils/viewport';
|
||||
*/
|
||||
export function fadeInView(
|
||||
node: HTMLElement,
|
||||
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
|
||||
options: { duration?: number; y?: number; delay?: number; skipIfVisible?: boolean } = {}
|
||||
) {
|
||||
const { duration = 300, y = 0, skipIfVisible = false } = options;
|
||||
const { duration = 300, y = 0, delay = 0, skipIfVisible = false } = options;
|
||||
|
||||
if (skipIfVisible && isElementInViewport(node)) {
|
||||
return;
|
||||
@@ -27,10 +27,12 @@ export function fadeInView(
|
||||
(entries) => {
|
||||
for (const entry of entries) {
|
||||
if (entry.isIntersecting) {
|
||||
requestAnimationFrame(() => {
|
||||
node.style.opacity = '1';
|
||||
node.style.transform = 'translateY(0)';
|
||||
});
|
||||
setTimeout(() => {
|
||||
requestAnimationFrame(() => {
|
||||
node.style.opacity = '1';
|
||||
node.style.transform = 'translateY(0)';
|
||||
});
|
||||
}, delay);
|
||||
observer.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M244.95 8C215.233 8 187.774 23.8591 172.923 49.5999L95.6009 183.625C60.2162 244.959 104.481 321.6 175.29 321.6H208L316.977 132.708C348.959 77.2719 308.95 8 244.95 8ZM208 321.6H351.947C415.982 321.6 456.013 390.91 424.013 446.377C409.155 472.132 381.681 488 351.947 488H271.29C200.481 488 156.216 411.359 191.601 350.026L208 321.6Z" fill="currentColor"/>
|
||||
<path d="M208 321.6H16L106.462 164.8L208 321.6Z" fill="currentColor"/>
|
||||
<path d="M388.923 8L208 321.6L253.6 8H388.923Z" fill="currentColor"/>
|
||||
<path d="M304 488H112L202.462 331.2L304 488Z" fill="currentColor"/>
|
||||
<path d="M496 321.6H208L419.399 454.4L496 321.6Z" fill="currentColor"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 771 B |
@@ -8,12 +8,13 @@
|
||||
ariaLabel?: string;
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
href?: string;
|
||||
icon: Component;
|
||||
iconSize?: string;
|
||||
onclick: (e?: MouseEvent) => void;
|
||||
onclick?: (e?: MouseEvent) => void;
|
||||
size?: ButtonSize;
|
||||
stopPropagationOnClick?: boolean;
|
||||
tooltip: string;
|
||||
tooltip?: string;
|
||||
variant?: ButtonVariant;
|
||||
tooltipSide?: TooltipSide;
|
||||
}
|
||||
@@ -22,6 +23,7 @@
|
||||
icon,
|
||||
tooltip,
|
||||
variant = 'ghost',
|
||||
href = '',
|
||||
size = 'sm',
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
@@ -31,34 +33,49 @@
|
||||
onclick,
|
||||
ariaLabel
|
||||
}: Props = $props();
|
||||
|
||||
let innerWidth = $state(0);
|
||||
const showTooltip = $derived(!!tooltip && innerWidth > 768);
|
||||
</script>
|
||||
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<!-- prevent another nested button element -->
|
||||
{#snippet child({ props })}
|
||||
<Button
|
||||
{...props}
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
onclick={(e: MouseEvent) => {
|
||||
if (stopPropagationOnClick) e.stopPropagation();
|
||||
{#snippet button(props = {})}
|
||||
<Button
|
||||
{...props}
|
||||
{href}
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
onclick={(e: MouseEvent) => {
|
||||
if (stopPropagationOnClick) e.stopPropagation();
|
||||
|
||||
onclick?.(e);
|
||||
}}
|
||||
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{#if icon}
|
||||
{@const IconComponent = icon}
|
||||
<IconComponent class={iconSize} />
|
||||
{/if}
|
||||
</Button>
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
onclick?.(e);
|
||||
}}
|
||||
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{#if icon}
|
||||
{@const IconComponent = icon}
|
||||
|
||||
<Tooltip.Content side={tooltipSide}>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
<IconComponent class={iconSize} />
|
||||
{/if}
|
||||
</Button>
|
||||
{/snippet}
|
||||
|
||||
{#if showTooltip}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<!-- prevent another nested button element -->
|
||||
{#snippet child({ props })}
|
||||
{@render button(props)}
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side={tooltipSide}>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
{@render button({ href })}
|
||||
{/if}
|
||||
|
||||
<svelte:window bind:innerWidth />
|
||||
|
||||
@@ -494,7 +494,7 @@
|
||||
/>
|
||||
|
||||
<div
|
||||
class="{INPUT_CLASSES} overflow-hidden rounded-3xl backdrop-blur-md {disabled
|
||||
class="{INPUT_CLASSES} overflow-hidden rounded-4xl md:rounded-3xl backdrop-blur-md {disabled
|
||||
? 'cursor-not-allowed opacity-60'
|
||||
: ''}"
|
||||
data-slot="input-area"
|
||||
@@ -510,7 +510,7 @@
|
||||
/>
|
||||
|
||||
<div
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
|
||||
class="flex-column relative min-h-12 items-center rounded-4xl md:rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:py-3!"
|
||||
onpaste={handlePaste}
|
||||
>
|
||||
<ChatFormTextarea
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
class="file-upload-button md:h-8 md:w-8 h-9 w-9 rounded-full p-0"
|
||||
{disabled}
|
||||
{onclick}
|
||||
variant="secondary"
|
||||
|
||||
+16
-3
@@ -15,6 +15,7 @@
|
||||
import { McpLogo } from '$lib/components/app';
|
||||
import { PencilRuler, ChevronDown, ChevronRight } from '@lucide/svelte';
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import { AttachmentAction } from '$lib/enums/attachment.enums';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -270,14 +271,22 @@
|
||||
</Collapsible.Root>
|
||||
{/if}
|
||||
|
||||
<button type="button" class={sheetItemClass} onclick={onSystemPromptClick}>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[AttachmentAction.SYSTEM_PROMPT_CLICK]()}
|
||||
>
|
||||
<MessageSquare class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>System Message</span>
|
||||
</button>
|
||||
|
||||
{#if hasMcpPromptsSupport}
|
||||
<button type="button" class={sheetItemClass} onclick={onMcpPromptClick}>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_PROMPT_CLICK]()}
|
||||
>
|
||||
<Zap class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>MCP Prompt</span>
|
||||
@@ -285,7 +294,11 @@
|
||||
{/if}
|
||||
|
||||
{#if hasMcpResourcesSupport}
|
||||
<button type="button" class={sheetItemClass} onclick={onMcpResourcesClick}>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_RESOURCES_CLICK]()}
|
||||
>
|
||||
<FolderOpen class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>MCP Resources</span>
|
||||
|
||||
+1
@@ -42,6 +42,7 @@
|
||||
{hasMcpPromptsSupport}
|
||||
{hasMcpResourcesSupport}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
{onMcpPromptClick}
|
||||
{onMcpResourcesClick}
|
||||
>
|
||||
|
||||
+1
-1
@@ -20,7 +20,7 @@
|
||||
type="submit"
|
||||
disabled={isDisabled}
|
||||
class={[
|
||||
'h-8 w-8 rounded-full p-0',
|
||||
'md:h-8 md:w-8 h-9 w-9 rounded-full p-0',
|
||||
showErrorState &&
|
||||
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
|
||||
]}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { autoResizeTextarea } from '$lib/utils';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
@@ -37,7 +38,9 @@
|
||||
}
|
||||
|
||||
export function focus() {
|
||||
textareaElement?.focus();
|
||||
if (isMobile.current) return;
|
||||
|
||||
textareaElement?.focus({ preventScroll: true });
|
||||
}
|
||||
|
||||
export function resetHeight() {
|
||||
|
||||
@@ -231,7 +231,7 @@
|
||||
editedContent = message.content;
|
||||
}
|
||||
|
||||
textareaElement?.focus();
|
||||
textareaElement?.focus({ preventScroll: true });
|
||||
editedExtras = message.extra ? [...message.extra] : [];
|
||||
editedUploadedFiles = [];
|
||||
|
||||
@@ -324,7 +324,7 @@
|
||||
}
|
||||
</script>
|
||||
|
||||
<div use:fadeInView>
|
||||
<div use:fadeInView class="chat-message">
|
||||
{#if message.role === MessageRole.SYSTEM}
|
||||
<ChatMessageSystem
|
||||
bind:textareaElement
|
||||
|
||||
+72
-5
@@ -180,6 +180,9 @@
|
||||
|
||||
let displayedModel = $derived(message.model ?? null);
|
||||
|
||||
// model being switched to while it loads, so the selector bar tracks it
|
||||
let pendingModel = $state<string | null>(null);
|
||||
|
||||
let isCurrentlyLoading = $derived(isLoading());
|
||||
let isStreaming = $derived(isChatStreaming());
|
||||
let hasNoContent = $derived(!message?.content?.trim());
|
||||
@@ -207,6 +210,42 @@
|
||||
isLastAssistantMessage
|
||||
);
|
||||
|
||||
let assistantEl: HTMLDivElement | undefined = $state();
|
||||
let lastUserMessageHeight = $state(0);
|
||||
let assistantMarginTop = $state(0);
|
||||
|
||||
$effect(() => {
|
||||
if (!assistantEl) return;
|
||||
|
||||
assistantMarginTop = Math.round(parseFloat(getComputedStyle(assistantEl).marginTop));
|
||||
|
||||
const chatMessageEl = assistantEl.closest('.chat-message');
|
||||
const previousChatMessage = chatMessageEl?.previousElementSibling;
|
||||
const userMessageEl = previousChatMessage?.querySelector(
|
||||
'.chat-message-user'
|
||||
) as HTMLElement | null;
|
||||
|
||||
if (!userMessageEl) {
|
||||
lastUserMessageHeight = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const updateHeight = () => {
|
||||
const rect = userMessageEl.getBoundingClientRect();
|
||||
const marginTop = Math.round(parseFloat(getComputedStyle(userMessageEl).marginTop));
|
||||
lastUserMessageHeight = Math.round(rect.height + marginTop);
|
||||
};
|
||||
|
||||
updateHeight();
|
||||
|
||||
const resizeObserver = new ResizeObserver(updateHeight);
|
||||
resizeObserver.observe(userMessageEl);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
});
|
||||
|
||||
function handleCopyModel() {
|
||||
void copyToClipboard(displayedModel ?? '');
|
||||
}
|
||||
@@ -219,12 +258,17 @@
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="text-md group w-full leading-7.5 {className}"
|
||||
bind:this={assistantEl}
|
||||
class="chat-message-assistant text-md group w-full leading-7.5 {className}"
|
||||
style:--last-user-message-height={lastUserMessageHeight > 0
|
||||
? `${lastUserMessageHeight}px`
|
||||
: undefined}
|
||||
style:--assistant-margin-top={assistantMarginTop > 0 ? `${assistantMarginTop}px` : undefined}
|
||||
role="group"
|
||||
aria-label="Assistant message with actions"
|
||||
>
|
||||
{#if showProcessingInfoTop}
|
||||
<div class="mt-6 w-full max-w-[48rem]" in:fade>
|
||||
<div class="mt-6 w-full max-w-3xl" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{modelLoadingText ??
|
||||
@@ -257,7 +301,7 @@
|
||||
{/if}
|
||||
|
||||
{#if showProcessingInfoBottom}
|
||||
<div class="mt-4 w-full max-w-[48rem]" in:fade>
|
||||
<div class="mt-4 w-full max-w-3xl" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{modelLoadingText ??
|
||||
@@ -277,13 +321,19 @@
|
||||
>
|
||||
{#if isRouter}
|
||||
<ModelsSelectorDropdown
|
||||
currentModel={displayedModel}
|
||||
currentModel={pendingModel ?? displayedModel}
|
||||
disabled={isLoading()}
|
||||
onModelChange={async (modelId: string, modelName: string) => {
|
||||
const status = modelsStore.getModelStatus(modelId);
|
||||
|
||||
if (status !== ServerModelStatus.LOADED) {
|
||||
await modelsStore.loadModel(modelId);
|
||||
pendingModel = modelId;
|
||||
|
||||
try {
|
||||
await modelsStore.loadModel(modelId);
|
||||
} finally {
|
||||
pendingModel = null;
|
||||
}
|
||||
}
|
||||
|
||||
onRegenerate(modelName);
|
||||
@@ -351,6 +401,23 @@
|
||||
</div>
|
||||
|
||||
<style>
|
||||
:global(.chat-message):last-child .chat-message-assistant {
|
||||
--assistant-min-height-offset: calc(
|
||||
var(--last-user-message-height, 19rem) + var(--chat-form-height, 6rem) +
|
||||
var(--chat-form-bottom-position, 0.5rem) + var(--chat-form-padding-top, 6rem) +
|
||||
var(--assistant-margin-top, 3rem)
|
||||
);
|
||||
min-height: calc(100dvh - var(--assistant-min-height-offset));
|
||||
|
||||
@media (width > 768px) {
|
||||
--assistant-min-height-offset: calc(
|
||||
var(--last-user-message-height, 18rem) + var(--chat-form-height, 6rem) +
|
||||
var(--chat-form-bottom-position, 1rem) + var(--chat-form-padding-top, 6rem) +
|
||||
var(--assistant-margin-top, 3rem)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
.processing-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
+1
-1
@@ -48,7 +48,7 @@
|
||||
|
||||
<div
|
||||
aria-label="User message with actions"
|
||||
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
|
||||
class="chat-message-user group flex flex-col items-end gap-3 md:gap-2 {className}"
|
||||
role="group"
|
||||
>
|
||||
{#if editCtx.isEditing}
|
||||
|
||||
+2
-2
@@ -19,7 +19,7 @@
|
||||
renderMarkdown = false,
|
||||
textColorClass = 'text-foreground',
|
||||
cardBgClass = 'dark:bg-primary/15',
|
||||
maxHeightStyle = 'max-height: var(--max-message-height);'
|
||||
maxHeightStyle = ''
|
||||
}: Props = $props();
|
||||
|
||||
let isMultiline = $state(false);
|
||||
@@ -59,7 +59,7 @@
|
||||
|
||||
{#if content.trim()}
|
||||
<Card
|
||||
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-[multiline]:py-2.5 {cardBgClass}"
|
||||
class="chat-message-user-bubble max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-multiline:py-2.5 {cardBgClass}"
|
||||
data-multiline={isMultiline ? '' : undefined}
|
||||
style="{maxHeightStyle} overflow-wrap: anywhere; word-break: break-word;"
|
||||
>
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
let allConversationMessages = $state<DatabaseMessage[]>([]);
|
||||
let isVisible = $state(false);
|
||||
let previousConversationId = $state<string | null>(null);
|
||||
let previousRouteId = $state<string | null>(null);
|
||||
|
||||
const currentConfig = config();
|
||||
|
||||
@@ -157,8 +158,9 @@
|
||||
});
|
||||
});
|
||||
|
||||
beforeNavigate(() => {
|
||||
beforeNavigate((navigation) => {
|
||||
isVisible = false;
|
||||
previousRouteId = navigation.from?.route.id ?? null;
|
||||
});
|
||||
|
||||
afterNavigate(() => {
|
||||
@@ -249,12 +251,13 @@
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="transition-opacity delay-300 duration-500 ease-out
|
||||
{isVisible ? 'opacity-100' : 'opacity-0'}"
|
||||
class="transition-opacity duration-500 ease-out
|
||||
{isVisible ? 'opacity-100' : 'opacity-0'}
|
||||
{previousRouteId === '/(chat)/chat/[id]' ? '' : 'delay-300'}"
|
||||
>
|
||||
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
|
||||
<ChatMessage
|
||||
class="mx-auto mt-12 w-full max-w-[48rem]"
|
||||
class="mx-auto mt-12 w-full max-w-3xl"
|
||||
{message}
|
||||
{toolMessages}
|
||||
{isLastAssistantMessage}
|
||||
|
||||
@@ -1,31 +1,28 @@
|
||||
<script lang="ts">
|
||||
import { Trash2 } from '@lucide/svelte';
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import {
|
||||
ChatScreenForm,
|
||||
ChatMessages,
|
||||
ChatScreenDragOverlay,
|
||||
ChatScreenProcessingInfo,
|
||||
ChatScreenActionScrollDown,
|
||||
DialogEmptyFileAlert,
|
||||
DialogFileUploadError,
|
||||
DialogChatError,
|
||||
ServerLoadingSplash,
|
||||
DialogConfirmation,
|
||||
ChatScreenServerError
|
||||
} from '$lib/components/app';
|
||||
import { setProcessingInfoContext } from '$lib/contexts';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import { useChatScreenActiveModel } from '$lib/hooks/use-chat-screen-active-model.svelte';
|
||||
import { useChatScreenDragAndDrop } from '$lib/hooks/use-chat-screen-drag-and-drop.svelte';
|
||||
import { useChatScreenFileUpload } from '$lib/hooks/use-chat-screen-file-upload.svelte';
|
||||
import { useChatScreenScroll } from '$lib/hooks/use-chat-screen-scroll.svelte';
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
import { device } from '$lib/stores/device.svelte';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import {
|
||||
chatStore,
|
||||
errorDialog,
|
||||
isLoading,
|
||||
isChatStreaming,
|
||||
isEditing,
|
||||
getAddFilesHandler,
|
||||
activeProcessingState
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import {
|
||||
@@ -34,138 +31,81 @@
|
||||
activeConversation
|
||||
} from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { serverLoading, serverError, isRouterMode } from '$lib/stores/server.svelte';
|
||||
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
|
||||
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
|
||||
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
|
||||
import { onMount } from 'svelte';
|
||||
import { serverLoading, serverError } from '$lib/stores/server.svelte';
|
||||
import { parseFilesToMessageExtras } from '$lib/utils/browser-only';
|
||||
import { onDestroy, onMount } from 'svelte';
|
||||
import ChatScreenGreeting from './ChatScreenGreeting.svelte';
|
||||
import ChatScreenActionScrollDown from './ChatScreenActionScrollDown.svelte';
|
||||
import ChatScreenDialogsAndAlerts from './ChatScreenDialogsAndAlerts.svelte';
|
||||
import { ROUTES } from '$lib/constants';
|
||||
|
||||
let { showCenteredEmpty = false } = $props();
|
||||
|
||||
const autoScroll = createAutoScrollController();
|
||||
|
||||
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll));
|
||||
let chatScrollContainer: HTMLDivElement | undefined = $state();
|
||||
let dragCounter = $state(0);
|
||||
let isDragOver = $state(false);
|
||||
let showFileErrorDialog = $state(false);
|
||||
let uploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
|
||||
let fileErrorData = $state<{
|
||||
generallyUnsupported: File[];
|
||||
modalityUnsupported: File[];
|
||||
modalityReasons: Record<string, string>;
|
||||
supportedTypes: string[];
|
||||
}>({
|
||||
generallyUnsupported: [],
|
||||
modalityUnsupported: [],
|
||||
modalityReasons: {},
|
||||
supportedTypes: []
|
||||
});
|
||||
|
||||
let showDeleteDialog = $state(false);
|
||||
|
||||
let showEmptyFileDialog = $state(false);
|
||||
|
||||
let processingInfoVisible = $state(false);
|
||||
|
||||
let emptyFileNames = $state<string[]>([]);
|
||||
|
||||
let initialMessage = $state('');
|
||||
|
||||
let isEmpty = $derived(
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
|
||||
let activeErrorDialog = $derived(errorDialog());
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
let hasPropsError = $derived(!!serverError());
|
||||
|
||||
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
|
||||
|
||||
let showProcessingInfo = $derived(
|
||||
isCurrentConversationLoading ||
|
||||
(config().keepStatsVisible && !!page.params.id) ||
|
||||
activeProcessingState() !== null
|
||||
);
|
||||
|
||||
let isRouter = $derived(isRouterMode());
|
||||
|
||||
let conversationModel = $derived(
|
||||
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
|
||||
);
|
||||
|
||||
let activeModelId = $derived.by(() => {
|
||||
const options = modelOptions();
|
||||
|
||||
if (!isRouter) {
|
||||
return options.length > 0 ? options[0].model : null;
|
||||
}
|
||||
|
||||
const selectedId = selectedModelId();
|
||||
if (selectedId) {
|
||||
const model = options.find((m) => m.id === selectedId);
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
if (conversationModel) {
|
||||
const model = options.find((m) => m.model === conversationModel);
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
let modelPropsVersion = $state(0);
|
||||
|
||||
setProcessingInfoContext({
|
||||
get showProcessingInfo() {
|
||||
return showProcessingInfo;
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (activeModelId) {
|
||||
const cached = modelsStore.getModelProps(activeModelId);
|
||||
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll) || isMobile.current);
|
||||
let isMobileUserScrolledUp = $state(false);
|
||||
let mobileScrollDownHint = $state(false);
|
||||
let mobileScrollDownHintLockedUntil = $state(0);
|
||||
let emptyFileNames = $state<string[]>([]);
|
||||
let initialMessage = $state('');
|
||||
let showDeleteDialog = $state(false);
|
||||
let showEmptyFileDialog = $state(false);
|
||||
let isEmpty = $derived(
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
let activeErrorDialog = $derived(errorDialog());
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
let hasPropsError = $derived(!!serverError());
|
||||
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
|
||||
let showProcessingInfo = $derived(
|
||||
isCurrentConversationLoading ||
|
||||
(config().keepStatsVisible && !!page.params.id) ||
|
||||
activeProcessingState() !== null
|
||||
);
|
||||
let chatFormBottomPosition = $derived.by(() => {
|
||||
if (!isMobile.current) return '1rem';
|
||||
if (device.isStandalone) return '1.5rem';
|
||||
if (device.isIOSSafari) return '0.25rem';
|
||||
return '0.5rem';
|
||||
});
|
||||
|
||||
if (!cached) {
|
||||
modelsStore.fetchModelProps(activeModelId).then(() => {
|
||||
modelPropsVersion++;
|
||||
});
|
||||
const autoScroll = createAutoScrollController();
|
||||
const scroll = useChatScreenScroll(autoScroll);
|
||||
const activeModel = useChatScreenActiveModel();
|
||||
const fileUpload = useChatScreenFileUpload({
|
||||
capabilities: () => ({
|
||||
hasVision: activeModel.hasVisionModality,
|
||||
hasAudio: activeModel.hasAudioModality,
|
||||
hasVideo: activeModel.hasVideoModality
|
||||
}),
|
||||
activeModelId: () => activeModel.activeModelId
|
||||
});
|
||||
const dragAndDrop = useChatScreenDragAndDrop({
|
||||
onDrop: fileUpload.handleFileUpload
|
||||
});
|
||||
const { handleKeydown } = useKeyboardShortcuts({
|
||||
deleteActiveConversation: () => {
|
||||
if (activeConversation()) {
|
||||
showDeleteDialog = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let hasAudioModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
function handleMobileScroll() {
|
||||
if (!isMobile.current) return;
|
||||
|
||||
return modelsStore.modelSupportsAudio(activeModelId);
|
||||
}
|
||||
const container = scroll.chatScrollContainer;
|
||||
if (!container) return;
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
let hasVideoModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
||||
return modelsStore.modelSupportsVideo(activeModelId);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
let hasVisionModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
||||
return modelsStore.modelSupportsVision(activeModelId);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
const distanceFromBottom =
|
||||
container.scrollHeight - container.clientHeight - container.scrollTop;
|
||||
isMobileUserScrolledUp = distanceFromBottom > 300;
|
||||
}
|
||||
|
||||
async function handleDeleteConfirm() {
|
||||
const conversation = activeConversation();
|
||||
@@ -177,27 +117,69 @@
|
||||
showDeleteDialog = false;
|
||||
}
|
||||
|
||||
function handleProcessingInfoVisibility(visible: boolean) {
|
||||
processingInfoVisible = visible;
|
||||
}
|
||||
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
|
||||
const plainFiles = files ? $state.snapshot(files) : undefined;
|
||||
const result = plainFiles
|
||||
? await parseFilesToMessageExtras(plainFiles, activeModel.activeModelId ?? undefined)
|
||||
: undefined;
|
||||
|
||||
function handleDragEnter(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
|
||||
dragCounter++;
|
||||
|
||||
if (event.dataTransfer?.types.includes('Files')) {
|
||||
isDragOver = true;
|
||||
if (result?.emptyFiles && result.emptyFiles.length > 0) {
|
||||
emptyFileNames = result.emptyFiles;
|
||||
showEmptyFileDialog = true;
|
||||
if (files) {
|
||||
const emptyFileNamesSet = new Set(result.emptyFiles);
|
||||
fileUpload.uploadedFiles = fileUpload.uploadedFiles.filter(
|
||||
(file) => !emptyFileNamesSet.has(file.name)
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
handleSendLikeScroll();
|
||||
|
||||
await chatStore.sendMessage(message, result?.extras);
|
||||
return true;
|
||||
}
|
||||
|
||||
function handleDragLeave(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
function handleSendLikeScroll() {
|
||||
if (!isMobile.current) {
|
||||
autoScroll.enable();
|
||||
}
|
||||
|
||||
dragCounter--;
|
||||
setTimeout(() => {
|
||||
const container = scroll.chatScrollContainer;
|
||||
if (!container) return;
|
||||
|
||||
if (dragCounter === 0) {
|
||||
isDragOver = false;
|
||||
const lastUserBubble = container.querySelector(
|
||||
'.chat-message:nth-last-child(2) .chat-message-user .chat-message-user-bubble'
|
||||
) as HTMLElement | null;
|
||||
|
||||
if (isMobile.current) {
|
||||
// Keep the last user message bubble just above the input on mobile
|
||||
const bubbleHeight = lastUserBubble?.scrollHeight ?? 0;
|
||||
const baseHeight = container.scrollHeight - innerHeight;
|
||||
|
||||
container.scrollTo({
|
||||
top: bubbleHeight > 0 ? baseHeight - bubbleHeight : baseHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
} else if (lastUserBubble) {
|
||||
// On desktop, place the last user message near the top of the viewport
|
||||
const topPadding = 24;
|
||||
const bubbleRect = lastUserBubble.getBoundingClientRect();
|
||||
container.scrollTo({
|
||||
top: Math.max(0, container.scrollTop + bubbleRect.top - topPadding),
|
||||
behavior: 'smooth'
|
||||
});
|
||||
} else {
|
||||
autoScroll.scrollToBottom();
|
||||
}
|
||||
}, 100);
|
||||
|
||||
if (isMobile.current) {
|
||||
autoScroll.setDisabled(disableAutoScroll);
|
||||
mobileScrollDownHint = true;
|
||||
mobileScrollDownHintLockedUntil = Date.now() + 500;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,273 +189,138 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
}
|
||||
|
||||
function handleDrop(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
|
||||
isDragOver = false;
|
||||
dragCounter = 0;
|
||||
|
||||
if (event.dataTransfer?.files) {
|
||||
const files = Array.from(event.dataTransfer.files);
|
||||
|
||||
if (isEditing()) {
|
||||
const handler = getAddFilesHandler();
|
||||
|
||||
if (handler) {
|
||||
handler(files);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
processFiles(files);
|
||||
}
|
||||
}
|
||||
|
||||
function handleFileRemove(fileId: string) {
|
||||
uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);
|
||||
}
|
||||
|
||||
function handleFileUpload(files: File[]) {
|
||||
processFiles(files);
|
||||
}
|
||||
|
||||
const { handleKeydown } = useKeyboardShortcuts({
|
||||
deleteActiveConversation: () => {
|
||||
if (activeConversation()) {
|
||||
showDeleteDialog = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
|
||||
if (draft.message || draft.files.length > 0) {
|
||||
chatStore.savePendingDraft(draft.message, draft.files);
|
||||
}
|
||||
|
||||
await chatStore.addSystemPrompt();
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
autoScroll.handleScroll();
|
||||
}
|
||||
|
||||
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
|
||||
const plainFiles = files ? $state.snapshot(files) : undefined;
|
||||
const result = plainFiles
|
||||
? await parseFilesToMessageExtras(plainFiles, activeModelId ?? undefined)
|
||||
: undefined;
|
||||
|
||||
if (result?.emptyFiles && result.emptyFiles.length > 0) {
|
||||
emptyFileNames = result.emptyFiles;
|
||||
showEmptyFileDialog = true;
|
||||
|
||||
if (files) {
|
||||
const emptyFileNamesSet = new Set(result.emptyFiles);
|
||||
uploadedFiles = uploadedFiles.filter((file) => !emptyFileNamesSet.has(file.name));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const extras = result?.extras;
|
||||
|
||||
// Enable autoscroll for user-initiated message sending
|
||||
autoScroll.enable();
|
||||
await chatStore.sendMessage(message, extras);
|
||||
autoScroll.scrollToBottom();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
async function processFiles(files: File[]) {
|
||||
const generallySupported: File[] = [];
|
||||
const generallyUnsupported: File[] = [];
|
||||
|
||||
for (const file of files) {
|
||||
if (isFileTypeSupported(file.name, file.type)) {
|
||||
generallySupported.push(file);
|
||||
} else {
|
||||
generallyUnsupported.push(file);
|
||||
}
|
||||
}
|
||||
|
||||
// Use model-specific capabilities for file validation
|
||||
const capabilities = {
|
||||
hasVision: hasVisionModality,
|
||||
hasAudio: hasAudioModality,
|
||||
hasVideo: hasVideoModality
|
||||
};
|
||||
const { supportedFiles, unsupportedFiles, modalityReasons } = filterFilesByModalities(
|
||||
generallySupported,
|
||||
capabilities
|
||||
);
|
||||
|
||||
const allUnsupportedFiles = [...generallyUnsupported, ...unsupportedFiles];
|
||||
|
||||
if (allUnsupportedFiles.length > 0) {
|
||||
const supportedTypes: string[] = ['text files', 'PDFs'];
|
||||
|
||||
if (hasVisionModality) supportedTypes.push('images');
|
||||
if (hasAudioModality) supportedTypes.push('audio files');
|
||||
if (hasVideoModality) supportedTypes.push('video files');
|
||||
|
||||
fileErrorData = {
|
||||
generallyUnsupported,
|
||||
modalityUnsupported: unsupportedFiles,
|
||||
modalityReasons,
|
||||
supportedTypes
|
||||
};
|
||||
showFileErrorDialog = true;
|
||||
}
|
||||
|
||||
if (supportedFiles.length > 0) {
|
||||
const processed = await processFilesToChatUploaded(
|
||||
supportedFiles,
|
||||
activeModelId ?? undefined
|
||||
);
|
||||
uploadedFiles = [...uploadedFiles, ...processed];
|
||||
}
|
||||
}
|
||||
|
||||
afterNavigate(() => {
|
||||
if (!disableAutoScroll) {
|
||||
$effect(() => {
|
||||
const shouldDisableAutoScroll =
|
||||
config().disableAutoScroll || (isMobile.current && isCurrentConversationLoading);
|
||||
autoScroll.setDisabled(shouldDisableAutoScroll);
|
||||
if (!shouldDisableAutoScroll) {
|
||||
autoScroll.enable();
|
||||
}
|
||||
});
|
||||
|
||||
function handleMessagesReady() {
|
||||
if (disableAutoScroll) return;
|
||||
|
||||
if (!autoScroll.userScrolledUp) {
|
||||
requestAnimationFrame(() => {
|
||||
autoScroll.scrollToBottom('instant');
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
if (pendingDraft) {
|
||||
initialMessage = pendingDraft.message;
|
||||
fileUpload.uploadedFiles = pendingDraft.files;
|
||||
}
|
||||
|
||||
autoScroll.startObserving();
|
||||
|
||||
if (!disableAutoScroll) {
|
||||
autoScroll.enable();
|
||||
}
|
||||
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
if (pendingDraft) {
|
||||
initialMessage = pendingDraft.message;
|
||||
uploadedFiles = pendingDraft.files;
|
||||
if (isMobile.current && isCurrentConversationLoading) {
|
||||
mobileScrollDownHint = true;
|
||||
mobileScrollDownHintLockedUntil = Date.now() + 500;
|
||||
}
|
||||
|
||||
handleMobileScroll();
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.setContainer(chatScrollContainer);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.setDisabled(disableAutoScroll);
|
||||
});
|
||||
onDestroy(() => autoScroll.destroy());
|
||||
</script>
|
||||
|
||||
{#if isDragOver}
|
||||
{#if dragAndDrop.isDragOver}
|
||||
<ChatScreenDragOverlay />
|
||||
{/if}
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} />
|
||||
<svelte:window
|
||||
onkeydown={handleKeydown}
|
||||
onscroll={(e) => {
|
||||
scroll.handleScroll(e);
|
||||
handleMobileScroll();
|
||||
if (e.isTrusted && Date.now() > mobileScrollDownHintLockedUntil) {
|
||||
mobileScrollDownHint = false;
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
{#if isServerLoading}
|
||||
<ServerLoadingSplash />
|
||||
{:else}
|
||||
<div
|
||||
bind:this={chatScrollContainer}
|
||||
aria-label="Chat interface with file drop zone"
|
||||
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
|
||||
ondragenter={handleDragEnter}
|
||||
ondragleave={handleDragLeave}
|
||||
ondragover={handleDragOver}
|
||||
ondrop={handleDrop}
|
||||
onscroll={handleScroll}
|
||||
class="chat-screen flex grow flex-col min-h-[calc(100dvh-1rem)] md:min-h-full px-4 md:py-0 pt-12 pb-48 md:pb-4"
|
||||
style:--chat-form-bottom-position={chatFormBottomPosition}
|
||||
ondragenter={dragAndDrop.dragHandlers.dragenter}
|
||||
ondragleave={dragAndDrop.dragHandlers.dragleave}
|
||||
ondragover={dragAndDrop.dragHandlers.dragover}
|
||||
ondrop={dragAndDrop.dragHandlers.drop}
|
||||
role="main"
|
||||
>
|
||||
<div class="flex grow flex-col pt-14">
|
||||
{#if !isEmpty}
|
||||
<ChatMessages
|
||||
messages={activeMessages()}
|
||||
onMessagesReady={handleMessagesReady}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
if (!autoScroll.userScrolledUp) {
|
||||
autoScroll.scrollToBottom();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
{#if !isEmpty}
|
||||
<ChatMessages
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
handleSendLikeScroll();
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none sticky right-4 left-4 mt-auto transition-all duration-200',
|
||||
isEmpty ? 'bottom-[calc(50dvh-7rem)]' : 'bottom-4 pt-24 md:pt-32'
|
||||
]}
|
||||
>
|
||||
<ChatScreenGreeting {isEmpty} />
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none md:sticky fixed mt-auto transition-all duration-200',
|
||||
device.isStandalone
|
||||
? 'bottom-6 right-4 left-4'
|
||||
: device.isIOSSafari
|
||||
? 'bottom-1 left-2 right-2'
|
||||
: 'bottom-2 right-2 left-2',
|
||||
isEmpty ? 'md:bottom-[calc(50dvh-7rem)] 2xl:bottom-[calc(50dvh-4rem)]' : 'md:bottom-4'
|
||||
]}
|
||||
style:padding-top={!isEmpty ? 'var(--chat-form-padding-top)' : undefined}
|
||||
>
|
||||
<ChatScreenGreeting {isEmpty} />
|
||||
|
||||
<ChatScreenActionScrollDown
|
||||
container={chatScrollContainer}
|
||||
hasProcessingInfoVisible={processingInfoVisible}
|
||||
/>
|
||||
<ChatScreenServerError />
|
||||
|
||||
<ChatScreenProcessingInfo onVisibilityChange={handleProcessingInfoVisibility} />
|
||||
|
||||
<ChatScreenServerError />
|
||||
|
||||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
|
||||
<ChatScreenForm
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
bind:uploadedFiles
|
||||
<div class="pointer-events-none flex flex-col gap-6 items-center w-full">
|
||||
{#if (isMobile.current ? mobileScrollDownHint || isMobileUserScrolledUp : autoScroll.userScrolledUp) && page.url.hash.includes(ROUTES.CHAT) && page.params.id}
|
||||
<ChatScreenActionScrollDown
|
||||
onclick={() => {
|
||||
mobileScrollDownHint = false;
|
||||
scroll.chatScrollContainer?.scrollTo({
|
||||
top: scroll.chatScrollContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if showProcessingInfo}
|
||||
<ChatScreenProcessingInfo />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<ChatScreenForm
|
||||
class="pointer-events-auto conversation-chat-form"
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={fileUpload.handleFileRemove}
|
||||
onFileUpload={fileUpload.handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
bind:uploadedFiles={fileUpload.uploadedFiles}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<DialogFileUploadError bind:open={showFileErrorDialog} {fileErrorData} />
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
variant="destructive"
|
||||
icon={Trash2}
|
||||
onConfirm={handleDeleteConfirm}
|
||||
onCancel={() => (showDeleteDialog = false)}
|
||||
/>
|
||||
|
||||
<DialogEmptyFileAlert
|
||||
bind:open={showEmptyFileDialog}
|
||||
emptyFiles={emptyFileNames}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
emptyFileNames = [];
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<DialogChatError
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
contextInfo={activeErrorDialog?.contextInfo}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
|
||||
<ChatScreenDialogsAndAlerts
|
||||
{showDeleteDialog}
|
||||
{handleDeleteConfirm}
|
||||
{showEmptyFileDialog}
|
||||
{emptyFileNames}
|
||||
{activeErrorDialog}
|
||||
{handleErrorDialogOpenChange}
|
||||
{fileUpload}
|
||||
/>
|
||||
|
||||
@@ -1,58 +1,18 @@
|
||||
<script lang="ts">
|
||||
import { ArrowDown } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import ActionIcon from '$lib/components/app/actions/ActionIcon.svelte';
|
||||
|
||||
interface Props {
|
||||
container: HTMLDivElement | undefined;
|
||||
hasProcessingInfoVisible: boolean;
|
||||
}
|
||||
|
||||
let { container, hasProcessingInfoVisible }: Props = $props();
|
||||
|
||||
let show = $state(false);
|
||||
|
||||
let buttonBottom = $derived(hasProcessingInfoVisible ? '2rem' : '0');
|
||||
|
||||
function checkVisibility() {
|
||||
if (!container) return;
|
||||
const { scrollTop, scrollHeight, clientHeight } = container;
|
||||
const distanceFromBottom = scrollHeight - clientHeight - scrollTop;
|
||||
show = distanceFromBottom > clientHeight * 0.5;
|
||||
}
|
||||
|
||||
function scrollToBottom() {
|
||||
if (container) {
|
||||
container.scrollTo({
|
||||
top: container.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
const c = container;
|
||||
if (c) {
|
||||
c.addEventListener('scroll', checkVisibility);
|
||||
checkVisibility();
|
||||
return () => {
|
||||
c.removeEventListener('scroll', checkVisibility);
|
||||
};
|
||||
}
|
||||
});
|
||||
let { onclick }: { onclick: (e?: MouseEvent) => void } = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative z-50 mx-auto mb-4 flex max-w-[48rem] justify-center">
|
||||
<Button
|
||||
onclick={scrollToBottom}
|
||||
variant="secondary"
|
||||
size="icon"
|
||||
disabled={!show}
|
||||
class="pointer-events-auto absolute h-10 w-10 rounded-full bg-background/80 shadow-lg backdrop-blur-sm transition-all duration-200 hover:bg-muted/80"
|
||||
style="bottom: {buttonBottom}; transform: translateY({show ? '0' : '2rem'}); opacity: {show
|
||||
? 1
|
||||
: 0};"
|
||||
aria-label="Scroll to bottom"
|
||||
>
|
||||
<ArrowDown class="h-4 w-4" />
|
||||
</Button>
|
||||
<div class="pointer-events-auto flex justify-center relative h-0">
|
||||
<ActionIcon
|
||||
icon={ArrowDown}
|
||||
{onclick}
|
||||
ariaLabel="Scroll to bottom"
|
||||
tooltip="Scroll to bottom"
|
||||
size="lg"
|
||||
iconSize="h-4 w-4"
|
||||
class="h-9 w-9 rounded-full bg-accent text-accent-foreground absolute bottom-4 shadow-md"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
<script lang="ts">
|
||||
import { Trash2 } from '@lucide/svelte';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import {
|
||||
DialogChatError,
|
||||
DialogConfirmation,
|
||||
DialogEmptyFileAlert,
|
||||
DialogFileUploadError
|
||||
} from '$lib/components/app';
|
||||
|
||||
let {
|
||||
showDeleteDialog,
|
||||
handleDeleteConfirm,
|
||||
showEmptyFileDialog,
|
||||
emptyFileNames,
|
||||
activeErrorDialog,
|
||||
handleErrorDialogOpenChange,
|
||||
fileUpload
|
||||
} = $props();
|
||||
</script>
|
||||
|
||||
<DialogFileUploadError
|
||||
bind:open={fileUpload.showFileErrorDialog}
|
||||
fileErrorData={fileUpload.fileErrorData}
|
||||
/>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
variant="destructive"
|
||||
icon={Trash2}
|
||||
onConfirm={handleDeleteConfirm}
|
||||
onCancel={() => (showDeleteDialog = false)}
|
||||
/>
|
||||
|
||||
<DialogEmptyFileAlert
|
||||
bind:open={showEmptyFileDialog}
|
||||
emptyFiles={emptyFileNames}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
emptyFileNames = [];
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<DialogChatError
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
contextInfo={activeErrorDialog?.contextInfo}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
|
||||
/>
|
||||
@@ -2,6 +2,7 @@
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { ChatForm } from '$lib/components/app';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { onMount } from 'svelte';
|
||||
import { useDraftMessages } from '$lib/hooks/use-draft-messages.svelte';
|
||||
|
||||
@@ -32,7 +33,30 @@
|
||||
}: Props = $props();
|
||||
|
||||
let chatFormRef: ChatForm | undefined = $state(undefined);
|
||||
let formWrapperEl: HTMLDivElement | undefined = $state();
|
||||
let chatId = $derived(page.params.id as string | undefined);
|
||||
|
||||
$effect(() => {
|
||||
if (!formWrapperEl) return;
|
||||
|
||||
const formEl = formWrapperEl.querySelector('form') as HTMLElement | null;
|
||||
if (!formEl) return;
|
||||
|
||||
const updateHeight = () => {
|
||||
const height = Math.round(formEl.getBoundingClientRect().height);
|
||||
document.documentElement.style.setProperty('--chat-form-height', `${height}px`);
|
||||
};
|
||||
|
||||
updateHeight();
|
||||
|
||||
const resizeObserver = new ResizeObserver(updateHeight);
|
||||
resizeObserver.observe(formEl);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
document.documentElement.style.removeProperty('--chat-form-height');
|
||||
};
|
||||
});
|
||||
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
|
||||
let message = $derived(initialMessage);
|
||||
let previousIsLoading = $derived(isLoading);
|
||||
@@ -83,12 +107,14 @@
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
setTimeout(() => chatFormRef?.focus(), 10);
|
||||
if (!isMobile.current) {
|
||||
setTimeout(() => chatFormRef?.focus(), 100);
|
||||
}
|
||||
});
|
||||
|
||||
afterNavigate((navigation) => {
|
||||
if (navigation?.from != null) {
|
||||
setTimeout(() => chatFormRef?.focus(), 10);
|
||||
if (navigation?.from != null && !isMobile.current) {
|
||||
setTimeout(() => chatFormRef?.focus(), 100);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -108,12 +134,12 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="relative mx-auto max-w-[48rem]">
|
||||
<div class="chat-screen-form-wrapper" bind:this={formWrapperEl}>
|
||||
<ChatForm
|
||||
class="mx-auto max-w-3xl {className}"
|
||||
bind:this={chatFormRef}
|
||||
bind:value={message}
|
||||
bind:uploadedFiles
|
||||
class={className}
|
||||
{disabled}
|
||||
{isLoading}
|
||||
showMcpPromptButton
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
<script lang="ts">
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -11,10 +10,9 @@
|
||||
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none mb-4 hidden px-4 text-center',
|
||||
isEmpty && 'pointer-events-auto block!'
|
||||
'pointer-events-none mb-4 hidden px-4 text-center text-balance',
|
||||
isEmpty && 'mb-[calc(50dvh-8rem)] md:mb-6 pointer-events-auto block!'
|
||||
]}
|
||||
use:fadeInView={{ duration: 300 }}
|
||||
>
|
||||
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
|
||||
|
||||
|
||||
@@ -5,13 +5,8 @@
|
||||
import { chatStore, isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
|
||||
import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { getProcessingInfoContext } from '$lib/contexts';
|
||||
import { page } from '$app/state';
|
||||
|
||||
const processingState = useProcessingState();
|
||||
const processingInfoCtx = getProcessingInfoContext();
|
||||
|
||||
let showProcessingInfo = $derived(processingInfoCtx.showProcessingInfo);
|
||||
|
||||
let isCurrentConversationLoading = $derived(isLoading());
|
||||
let isStreaming = $derived(isChatStreaming());
|
||||
@@ -70,8 +65,8 @@
|
||||
|
||||
<div
|
||||
class={[
|
||||
'chat-processing-info-container pointer-events-none relative',
|
||||
page.params.id && showProcessingInfo && 'visible'
|
||||
'chat-processing-info-container pointer-events-none relative w-full hidden md:block',
|
||||
processingVisible && 'visible'
|
||||
]}
|
||||
>
|
||||
<div class="chat-processing-info-content absolute bottom-4 left-1/2 -translate-x-1/2">
|
||||
|
||||
@@ -677,13 +677,6 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
|
||||
*/
|
||||
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
|
||||
|
||||
/**
|
||||
* Scroll-to-bottom action button. Displays a floating button when the user
|
||||
* has scrolled up more than half a viewport height from the bottom.
|
||||
* Takes the chat container element as a prop to manage scroll state internally.
|
||||
*/
|
||||
export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte';
|
||||
|
||||
/**
|
||||
* Server error alert displayed when the server is unreachable.
|
||||
* Shows the error message with a retry button.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { Search, X } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
autofocus?: boolean;
|
||||
value?: string;
|
||||
placeholder?: string;
|
||||
onInput?: (value: string) => void;
|
||||
@@ -15,6 +16,7 @@
|
||||
}
|
||||
|
||||
let {
|
||||
autofocus,
|
||||
value = $bindable(''),
|
||||
placeholder = 'Search...',
|
||||
onInput,
|
||||
@@ -39,7 +41,7 @@
|
||||
if (value) {
|
||||
value = '';
|
||||
onInput?.('');
|
||||
ref?.focus();
|
||||
ref?.focus({ preventScroll: true });
|
||||
} else {
|
||||
onClose?.();
|
||||
}
|
||||
@@ -52,6 +54,7 @@
|
||||
/>
|
||||
|
||||
<Input
|
||||
{autofocus}
|
||||
{id}
|
||||
bind:value
|
||||
bind:ref
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
<script>
|
||||
import logoMark from '$lib/assets/logo.svg?raw';
|
||||
let { class: className = '', style = '' } = $props();
|
||||
</script>
|
||||
|
||||
<div class={className} {style}>
|
||||
{@html logoMark}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
div :global(svg) {
|
||||
width: var(--size, 1rem);
|
||||
height: var(--size, 1rem);
|
||||
}
|
||||
</style>
|
||||
@@ -51,3 +51,11 @@ export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';
|
||||
* Preview button is shown only for HTML code blocks.
|
||||
*/
|
||||
export { default as CodeBlockActions } from './CodeBlockActions.svelte';
|
||||
|
||||
/**
|
||||
* **Logo** - Application brand mark
|
||||
*
|
||||
* Inline SVG of the application logo. Accepts styling via the standard
|
||||
* `class` and `style` props and inherits color via `currentColor`.
|
||||
*/
|
||||
export { default as Logo } from './Logo.svelte';
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
<script lang="ts">
|
||||
let { percent }: { percent: number } = $props();
|
||||
</script>
|
||||
|
||||
<!-- thin determinate load bar pinned to the bottom edge, pulsing while it fills -->
|
||||
<div class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm">
|
||||
<div
|
||||
class="h-full animate-pulse bg-primary transition-[width] duration-200 ease-out"
|
||||
style="width: {percent}%"
|
||||
></div>
|
||||
</div>
|
||||
@@ -2,8 +2,10 @@
|
||||
import { ChevronDown, Loader2, Package } from '@lucide/svelte';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { KeyboardKey } from '$lib/enums';
|
||||
import { KeyboardKey, ServerModelStatus } from '$lib/enums';
|
||||
import { useModelsSelector } from '$lib/hooks/use-models-selector.svelte';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
import { modelLoadFraction } from '$lib/utils';
|
||||
import {
|
||||
DialogModelInformation,
|
||||
DropdownMenuSearchable,
|
||||
@@ -11,6 +13,7 @@
|
||||
ModelsSelectorList,
|
||||
ModelsSelectorOption
|
||||
} from '$lib/components/app';
|
||||
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
|
||||
import type { ModelItem } from './utils';
|
||||
|
||||
interface Props {
|
||||
@@ -113,6 +116,17 @@
|
||||
{/if}
|
||||
{:else}
|
||||
{@const selectedOption = ms.getDisplayOption()}
|
||||
{@const triggerModel = selectedOption?.model}
|
||||
{@const triggerStatus = triggerModel
|
||||
? routerModels().find((m) => m.id === triggerModel)?.status?.value
|
||||
: undefined}
|
||||
{@const triggerLoading =
|
||||
!!triggerModel &&
|
||||
(triggerStatus === ServerModelStatus.LOADING ||
|
||||
modelsStore.isModelOperationInProgress(triggerModel))}
|
||||
{@const triggerLoadPercent = triggerLoading
|
||||
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
|
||||
: 0}
|
||||
|
||||
{#if ms.isRouter}
|
||||
<DropdownMenu.Root bind:open={isOpen} onOpenChange={ms.handleOpenChange}>
|
||||
@@ -123,7 +137,7 @@
|
||||
<DropdownMenu.Trigger
|
||||
{...props}
|
||||
class={[
|
||||
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
`relative inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
!ms.isCurrentModelInCache
|
||||
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
|
||||
: forceForegroundText
|
||||
@@ -154,6 +168,10 @@
|
||||
{:else}
|
||||
<ChevronDown class="h-3 w-3.5 shrink-0" />
|
||||
{/if}
|
||||
|
||||
{#if triggerLoading}
|
||||
<ModelLoadHighlight percent={triggerLoadPercent} />
|
||||
{/if}
|
||||
</DropdownMenu.Trigger>
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
RotateCw
|
||||
} from '@lucide/svelte';
|
||||
import { ActionIcon, ModelId } from '$lib/components/app';
|
||||
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
|
||||
import type { ModelOption } from '$lib/types/models';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
@@ -119,11 +120,11 @@
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-5 items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-5">
|
||||
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
{:else if isFailed}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<CircleAlert
|
||||
class="h-3.5 w-3.5 text-red-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
/>
|
||||
@@ -140,7 +141,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else if isSleeping}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
@@ -159,7 +160,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else if isLoaded}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
@@ -176,7 +177,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
@@ -196,13 +197,6 @@
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<div
|
||||
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-primary transition-[width] duration-200 ease-out"
|
||||
style="width: {loadPercent}%"
|
||||
></div>
|
||||
</div>
|
||||
<ModelLoadHighlight percent={loadPercent} />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -8,6 +8,10 @@
|
||||
ModelsSelectorList,
|
||||
SearchInput
|
||||
} from '$lib/components/app';
|
||||
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
import { modelLoadFraction } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -61,12 +65,23 @@
|
||||
<p class="text-xs text-muted-foreground">No models available.</p>
|
||||
{:else}
|
||||
{@const selectedOption = ms.getDisplayOption()}
|
||||
{@const triggerModel = selectedOption?.model}
|
||||
{@const triggerStatus = triggerModel
|
||||
? routerModels().find((m) => m.id === triggerModel)?.status?.value
|
||||
: undefined}
|
||||
{@const triggerLoading =
|
||||
!!triggerModel &&
|
||||
(triggerStatus === ServerModelStatus.LOADING ||
|
||||
modelsStore.isModelOperationInProgress(triggerModel))}
|
||||
{@const triggerLoadPercent = triggerLoading
|
||||
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
|
||||
: 0}
|
||||
|
||||
{#if ms.isRouter}
|
||||
<button
|
||||
type="button"
|
||||
class={[
|
||||
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 max-sm:px-3 max-sm:py-2 text-xs max-sm:text-sm shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
`relative inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 max-sm:px-3 max-sm:py-2 max-sm:text-sm dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
!ms.isCurrentModelInCache
|
||||
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
|
||||
: forceForegroundText
|
||||
@@ -99,6 +114,10 @@
|
||||
{:else}
|
||||
<ChevronDown class="h-3 w-3.5 shrink-0" />
|
||||
{/if}
|
||||
|
||||
{#if triggerLoading}
|
||||
<ModelLoadHighlight percent={triggerLoadPercent} />
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
<Sheet.Root bind:open={sheetOpen} onOpenChange={handleSheetOpenChange}>
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { ActionIcon } from '$lib/components/app';
|
||||
import {
|
||||
ICON_STRIP_TRANSITION_DURATION,
|
||||
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
|
||||
SIDEBAR_ACTIONS_ITEMS
|
||||
} from '$lib/constants';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { circIn } from 'svelte/easing';
|
||||
import { onMount } from 'svelte';
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
|
||||
interface Props {
|
||||
sidebarOpen: boolean;
|
||||
onSearchClick: () => void;
|
||||
}
|
||||
|
||||
let { sidebarOpen = false, onSearchClick }: Props = $props();
|
||||
|
||||
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
|
||||
|
||||
let initialized = $state(false);
|
||||
let showIcons = $derived(!sidebarOpen);
|
||||
|
||||
showIcons = false;
|
||||
|
||||
onMount(() => {
|
||||
showIcons = !sidebarOpen;
|
||||
|
||||
setTimeout(() => {
|
||||
initialized = true;
|
||||
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
|
||||
});
|
||||
</script>
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} />
|
||||
|
||||
<div
|
||||
class="hidden shrink-0 transition-[width] duration-200 ease-linear md:block {sidebarOpen
|
||||
? 'w-0'
|
||||
: 'w-[calc(var(--sidebar-width-icon)+1.5rem)]'}"
|
||||
></div>
|
||||
<aside
|
||||
class="fixed top-0 bottom-0 left-0 z-10 hidden w-[calc(var(--sidebar-width-icon)+1.5rem)] flex-col items-center justify-between py-3 transition-opacity duration-200 ease-linear md:flex {sidebarOpen
|
||||
? 'pointer-events-none opacity-0'
|
||||
: 'opacity-100'}"
|
||||
>
|
||||
<div class="mt-12 flex flex-col items-center gap-1">
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
|
||||
{@const onclick = item.route ? () => goto(item.route!) : onSearchClick}
|
||||
{@const isActive = item.activeRouteId
|
||||
? page.route.id === item.activeRouteId
|
||||
: item.activeRoutePrefix
|
||||
? !!page.route.id?.startsWith(item.activeRoutePrefix)
|
||||
: false}
|
||||
{#if showIcons}
|
||||
<div
|
||||
in:fade={{
|
||||
duration: ICON_STRIP_TRANSITION_DURATION,
|
||||
delay: !initialized
|
||||
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
|
||||
: 0,
|
||||
easing: circIn
|
||||
}}
|
||||
>
|
||||
<ActionIcon
|
||||
icon={item.icon}
|
||||
tooltip={item.tooltip}
|
||||
tooltipSide={TooltipSide.RIGHT}
|
||||
size="lg"
|
||||
iconSize="h-4 w-4"
|
||||
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
{onclick}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</aside>
|
||||
+234
-295
@@ -1,40 +1,67 @@
|
||||
<script lang="ts">
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { Trash2, Pencil, Pin, X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { DialogConfirmation } from '$lib/components/app';
|
||||
import SidebarNavigationActions from './SidebarNavigationActions.svelte';
|
||||
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
|
||||
import * as Sidebar from '$lib/components/ui/sidebar';
|
||||
import Input from '$lib/components/ui/input/input.svelte';
|
||||
import { ROUTES } from '$lib/constants/routes';
|
||||
import { RouterService } from '$lib/services/router.service';
|
||||
import { PanelLeftClose, PanelLeftOpen, X } from '@lucide/svelte';
|
||||
import {
|
||||
conversationsStore,
|
||||
conversations,
|
||||
buildConversationTree
|
||||
} from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { getPreviewText } from '$lib/utils';
|
||||
import { APP_NAME } from '$lib/constants';
|
||||
ActionIcon,
|
||||
Logo,
|
||||
SidebarNavigationConversationList,
|
||||
SidebarNavigationActions
|
||||
} from '$lib/components/app';
|
||||
import { ROUTES } from '$lib/constants';
|
||||
import { fade } from 'svelte/transition';
|
||||
|
||||
const sidebar = Sidebar.useSidebar();
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { RouterService } from '$lib/services/router.service';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
import { device } from '$lib/stores/device.svelte';
|
||||
import { circIn } from 'svelte/easing';
|
||||
|
||||
interface Props {
|
||||
onSearchClick?: () => void;
|
||||
}
|
||||
|
||||
let { onSearchClick = () => {} }: Props = $props();
|
||||
|
||||
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
|
||||
|
||||
let isExpandedMode = $state(false);
|
||||
let hoveredTooltip = $state<string | null>(null);
|
||||
let logoHovered = $state(false);
|
||||
|
||||
const isStripExpanded = $derived(isExpandedMode || hoveredTooltip !== null);
|
||||
const isOnMobile = $derived(isMobile.current);
|
||||
|
||||
function toggleExpandedMode() {
|
||||
isExpandedMode = !isExpandedMode;
|
||||
if (!isExpandedMode) {
|
||||
hoveredTooltip = null;
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!isExpandedMode) {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
cancelMobileCollapse();
|
||||
}
|
||||
});
|
||||
|
||||
// On mobile the dedicated /search route hides the sidebar (see the aside
|
||||
// render guard below). Collapse it as we enter /search so it doesn't
|
||||
// reappear expanded when the user navigates back via the back button.
|
||||
$effect(() => {
|
||||
if (isMobile.current && page.url.hash.includes(ROUTES.SEARCH)) {
|
||||
isExpandedMode = false;
|
||||
}
|
||||
});
|
||||
|
||||
let currentChatId = $derived(page.params.id);
|
||||
let isSearchModeActive = $state(false);
|
||||
let searchQuery = $state('');
|
||||
let showDeleteDialog = $state(false);
|
||||
let deleteWithForks = $state(false);
|
||||
let showEditDialog = $state(false);
|
||||
let selectedConversation = $state<DatabaseConversation | null>(null);
|
||||
let editedName = $state('');
|
||||
let selectedConversationNamePreview = $derived.by(() =>
|
||||
selectedConversation ? getPreviewText(selectedConversation.name) : ''
|
||||
);
|
||||
|
||||
let filteredConversations = $derived.by(() => {
|
||||
if (isSearchModeActive) {
|
||||
@@ -50,294 +77,206 @@
|
||||
return conversations();
|
||||
});
|
||||
|
||||
let conversationTree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
let pinnedConversations = $derived.by(() => {
|
||||
return conversationTree.filter(({ conversation }) => conversation.pinned);
|
||||
});
|
||||
|
||||
let unpinnedConversations = $derived.by(() => {
|
||||
return conversationTree.filter(({ conversation }) => !conversation.pinned);
|
||||
});
|
||||
|
||||
let selectedConversationHasDescendants = $derived.by(() => {
|
||||
if (!selectedConversation) return false;
|
||||
|
||||
const allConvs = conversations();
|
||||
const queue = [selectedConversation.id];
|
||||
|
||||
while (queue.length > 0) {
|
||||
const parentId = queue.pop()!;
|
||||
|
||||
for (const c of allConvs) {
|
||||
if (c.forkedFromConversationId === parentId) return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
async function handleDeleteConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (conversation) {
|
||||
selectedConversation = conversation;
|
||||
deleteWithForks = false;
|
||||
showDeleteDialog = true;
|
||||
async function selectConversation(id: string) {
|
||||
if (isMobile.current) {
|
||||
scheduleMobileCollapse();
|
||||
}
|
||||
await goto(RouterService.chat(id));
|
||||
}
|
||||
|
||||
async function handleEditConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (conversation) {
|
||||
selectedConversation = conversation;
|
||||
editedName = conversation.name;
|
||||
showEditDialog = true;
|
||||
if (!conversation) return;
|
||||
|
||||
const newName = window.prompt('Rename conversation', conversation.name);
|
||||
if (newName && newName.trim()) {
|
||||
await conversationsStore.updateConversationName(id, newName.trim());
|
||||
}
|
||||
}
|
||||
|
||||
function handleConfirmDelete() {
|
||||
if (selectedConversation) {
|
||||
const convId = selectedConversation.id;
|
||||
const withForks = deleteWithForks;
|
||||
showDeleteDialog = false;
|
||||
async function handleDeleteConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (!conversation) return;
|
||||
|
||||
setTimeout(() => {
|
||||
conversationsStore.deleteConversation(convId, {
|
||||
deleteWithForks: withForks
|
||||
});
|
||||
}, 100); // Wait for animation to finish
|
||||
}
|
||||
}
|
||||
const confirmed = window.confirm(
|
||||
`Delete "${conversation.name}"? This action cannot be undone.`
|
||||
);
|
||||
if (!confirmed) return;
|
||||
|
||||
function handleConfirmEdit() {
|
||||
if (!editedName.trim() || !selectedConversation) return;
|
||||
|
||||
showEditDialog = false;
|
||||
|
||||
conversationsStore.updateConversationName(selectedConversation.id, editedName);
|
||||
selectedConversation = null;
|
||||
}
|
||||
|
||||
export function handleMobileSidebarItemClick() {
|
||||
if (sidebar.isMobile) {
|
||||
sidebar.toggle();
|
||||
}
|
||||
}
|
||||
|
||||
let chatSidebarActions: { activateSearch?: () => void } | undefined = $state();
|
||||
let openedForSearch = $state(false);
|
||||
|
||||
export function activateSearchMode() {
|
||||
if (!sidebar.open) {
|
||||
openedForSearch = true;
|
||||
}
|
||||
chatSidebarActions?.activateSearch?.();
|
||||
}
|
||||
|
||||
function handleSearchDeactivated() {
|
||||
if (openedForSearch) {
|
||||
openedForSearch = false;
|
||||
sidebar.toggle();
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!sidebar.open) {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
openedForSearch = false;
|
||||
}
|
||||
});
|
||||
|
||||
export function editActiveConversation() {
|
||||
if (currentChatId) {
|
||||
const activeConversation = filteredConversations.find((conv) => conv.id === currentChatId);
|
||||
|
||||
if (activeConversation) {
|
||||
const event = new CustomEvent('edit-active-conversation', {
|
||||
detail: { conversationId: currentChatId }
|
||||
});
|
||||
document.dispatchEvent(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function selectConversation(id: string) {
|
||||
if (isSearchModeActive) {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
}
|
||||
|
||||
handleMobileSidebarItemClick();
|
||||
await goto(RouterService.chat(id));
|
||||
await conversationsStore.deleteConversation(id, { deleteWithForks: false });
|
||||
}
|
||||
|
||||
function handleStopGeneration(id: string) {
|
||||
chatStore.stopGenerationForChat(id);
|
||||
}
|
||||
|
||||
let innerWidth = $state(0);
|
||||
let pendingCollapse = $state<ReturnType<typeof setTimeout> | null>(null);
|
||||
|
||||
function scheduleMobileCollapse() {
|
||||
if (pendingCollapse) {
|
||||
clearTimeout(pendingCollapse);
|
||||
}
|
||||
pendingCollapse = setTimeout(() => {
|
||||
isExpandedMode = false;
|
||||
pendingCollapse = null;
|
||||
}, 100);
|
||||
}
|
||||
|
||||
function cancelMobileCollapse() {
|
||||
if (pendingCollapse) {
|
||||
clearTimeout(pendingCollapse);
|
||||
pendingCollapse = null;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex h-full flex-col">
|
||||
<ScrollArea class="h-full flex-1">
|
||||
<Sidebar.Header class="gap-4 bg-sidebar/50 p-3 backdrop-blur-lg md:pt-4 md:pb-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<a href={ROUTES.START} onclick={handleMobileSidebarItemClick}>
|
||||
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">
|
||||
{APP_NAME}
|
||||
</h1>
|
||||
</a>
|
||||
<svelte:window onkeydown={handleKeydown} bind:innerWidth />
|
||||
|
||||
<Button
|
||||
class="rounded-full md:hidden"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onclick={() => sidebar.toggle()}
|
||||
>
|
||||
<X class="h-4 w-4" />
|
||||
<span class="sr-only">Close sidebar</span>
|
||||
</Button>
|
||||
{#if innerWidth > 768 || (!page.url.hash.includes(ROUTES.SETTINGS) && !page.url.hash.includes(ROUTES.MCP_SERVERS) && !page.url.hash.includes(ROUTES.SEARCH))}
|
||||
<aside
|
||||
class={[
|
||||
// Layout & positioning
|
||||
'fixed md:sticky top-2 left-2 md:left-0 md:ml-2 md:mt-2 pt-2 z-10 w-[calc(100dvw-1rem)]',
|
||||
// Dimensions & overflow
|
||||
'md:h-[calc(100dvh-1.125rem)]',
|
||||
isExpandedMode &&
|
||||
(device.isStandalone
|
||||
? 'h-[calc(100dvh-2rem)]'
|
||||
: device.isIOSDevice
|
||||
? 'h-[calc(100dvh-0.5rem)]'
|
||||
: 'h-[calc(100dvh-1rem)]'),
|
||||
// Shape & depth
|
||||
'rounded-3xl md:rounded-2xl',
|
||||
// Flex layout
|
||||
'flex flex-col justify-between',
|
||||
// Transition
|
||||
'md:transition-[width,padding] duration-200 ease-out',
|
||||
// Expanded state: width, surface, depth
|
||||
isStripExpanded && 'md:w-72 md:bg-muted/60 md:backdrop-blur-xl border-border shadow-md',
|
||||
// Collapsed state
|
||||
!isStripExpanded && 'md:w-12',
|
||||
// Expanded mode flag (for mobile ::before overlay)
|
||||
isExpandedMode && 'is-expanded'
|
||||
]}
|
||||
>
|
||||
<div class="px-2 flex items-center justify-between">
|
||||
<div
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="relative"
|
||||
onmouseenter={() => (logoHovered = true)}
|
||||
onmouseleave={() => (logoHovered = false)}
|
||||
>
|
||||
<ActionIcon
|
||||
icon={!isExpandedMode && logoHovered && innerWidth > 768 ? PanelLeftOpen : Logo}
|
||||
size="lg"
|
||||
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
|
||||
class="{isExpandedMode
|
||||
? 'bg-muted! md:bg-foreground/5!'
|
||||
: 'bg-transparent!'} md:h-9 md:w-9 h-10 w-10 rounded-full md:hover:bg-foreground/10! pointer-events-auto"
|
||||
href={isExpandedMode ? ROUTES.START : undefined}
|
||||
onclick={isExpandedMode ? undefined : toggleExpandedMode}
|
||||
tooltip={isExpandedMode ? undefined : 'Open Sidebar'}
|
||||
tooltipSide={TooltipSide.RIGHT}
|
||||
ariaLabel={isExpandedMode ? 'Go to start' : 'Expand navigation'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<SidebarNavigationActions
|
||||
bind:this={chatSidebarActions}
|
||||
{handleMobileSidebarItemClick}
|
||||
bind:isSearchModeActive
|
||||
bind:searchQuery
|
||||
onSearchDeactivated={handleSearchDeactivated}
|
||||
/>
|
||||
</Sidebar.Header>
|
||||
|
||||
{#if !isSearchModeActive && pinnedConversations.length > 0}
|
||||
<Sidebar.Group class="p-0 px-4">
|
||||
<Sidebar.GroupLabel>
|
||||
<div class="flex items-center gap-1">
|
||||
<Pin class="h-3.5 w-3.5" />
|
||||
<span>Pinned</span>
|
||||
</div>
|
||||
</Sidebar.GroupLabel>
|
||||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each pinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
</Sidebar.Menu>
|
||||
</Sidebar.GroupContent>
|
||||
</Sidebar.Group>
|
||||
{/if}
|
||||
|
||||
<Sidebar.Group class="mt-2 h-[calc(100vh-21rem)] space-y-2 p-0 px-3">
|
||||
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
|
||||
<Sidebar.GroupLabel>
|
||||
{isSearchModeActive ? 'Search results' : 'Recent conversations'}
|
||||
</Sidebar.GroupLabel>
|
||||
{#if isExpandedMode || isOnMobile}
|
||||
<div
|
||||
class="flex items-center transition-all duration-150 ease-out {isMobile.current &&
|
||||
!isExpandedMode
|
||||
? 'opacity-0 h-0!'
|
||||
: ''}"
|
||||
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
|
||||
out:fade={{ duration: 100 }}
|
||||
>
|
||||
<ActionIcon
|
||||
icon={isMobile.current ? X : PanelLeftClose}
|
||||
size="lg"
|
||||
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
|
||||
class="backdrop-blur-none md:h-9 md:w-9 h-10 w-10 rounded-full mr-1 hover:bg-accent!"
|
||||
onclick={toggleExpandedMode}
|
||||
tooltip="Close Sidebar"
|
||||
tooltipSide={TooltipSide.LEFT}
|
||||
ariaLabel="Collapse navigation"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each isSearchModeActive ? conversationTree : unpinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
|
||||
{#if (isSearchModeActive ? conversationTree : unpinnedConversations).length === 0}
|
||||
<div class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{searchQuery.length > 0
|
||||
? 'No results found'
|
||||
: isSearchModeActive
|
||||
? 'Start typing to see results'
|
||||
: 'No conversations yet'}
|
||||
</p>
|
||||
</div>
|
||||
{/if}
|
||||
</Sidebar.Menu>
|
||||
</Sidebar.GroupContent>
|
||||
</Sidebar.Group>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description={selectedConversation
|
||||
? `Are you sure you want to delete "${selectedConversationNamePreview}"? This action cannot be undone and will permanently remove all messages in this conversation.`
|
||||
: ''}
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
variant="destructive"
|
||||
icon={Trash2}
|
||||
onConfirm={handleConfirmDelete}
|
||||
onCancel={() => {
|
||||
showDeleteDialog = false;
|
||||
selectedConversation = null;
|
||||
}}
|
||||
>
|
||||
{#if selectedConversationHasDescendants}
|
||||
<div class="flex items-center gap-2 py-2">
|
||||
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
|
||||
|
||||
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
|
||||
</div>
|
||||
{/if}
|
||||
</DialogConfirmation>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showEditDialog}
|
||||
title="Edit Conversation Name"
|
||||
description=""
|
||||
confirmText="Save"
|
||||
cancelText="Cancel"
|
||||
icon={Pencil}
|
||||
onConfirm={handleConfirmEdit}
|
||||
onCancel={() => {
|
||||
showEditDialog = false;
|
||||
selectedConversation = null;
|
||||
}}
|
||||
onKeydown={(event) => {
|
||||
if (event.key === 'Enter') {
|
||||
event.preventDefault();
|
||||
event.stopImmediatePropagation();
|
||||
handleConfirmEdit();
|
||||
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-1 overflow-y-auto">
|
||||
<div
|
||||
class="flex min-h-0 flex-1 flex-col gap-4 md:gap-1 {isMobile.current
|
||||
? 'transition-[opacity,height] duration-200 ease-out'
|
||||
: ''} {isMobile.current && !isExpandedMode ? 'opacity-0 !h-0' : ''}"
|
||||
in:fade={{ duration: 200 }}
|
||||
out:fade={{ duration: 200 }}
|
||||
>
|
||||
<SidebarNavigationActions
|
||||
isExpandedMode={innerWidth > 768 ? isExpandedMode : true}
|
||||
class="px-2"
|
||||
bind:isSearchModeActive
|
||||
bind:searchQuery
|
||||
onSearchDeactivated={() => {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
}}
|
||||
onSearchClick={() => {
|
||||
isExpandedMode = true;
|
||||
isSearchModeActive = true;
|
||||
}}
|
||||
onNewChat={() => {
|
||||
if (isMobile.current) {
|
||||
scheduleMobileCollapse();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
{#if isExpandedMode || isOnMobile}
|
||||
<SidebarNavigationConversationList
|
||||
class="px-2"
|
||||
{filteredConversations}
|
||||
{currentChatId}
|
||||
{isSearchModeActive}
|
||||
{searchQuery}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</aside>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
aside {
|
||||
@media (max-width: 768px) {
|
||||
--size: 1.125rem;
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Input
|
||||
class="text-foreground"
|
||||
placeholder="Enter a new name"
|
||||
type="text"
|
||||
bind:value={editedName}
|
||||
/>
|
||||
</DialogConfirmation>
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
aside {
|
||||
&:not(.is-expanded) {
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
|
||||
aside.is-expanded::before {
|
||||
content: '';
|
||||
position: fixed;
|
||||
top: -0.5rem;
|
||||
bottom: -0.25rem;
|
||||
left: -0.5rem;
|
||||
right: -0.5rem;
|
||||
z-index: -1;
|
||||
background: var(--background);
|
||||
backdrop-filter: blur(1rem);
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
+157
-57
@@ -1,39 +1,86 @@
|
||||
<script lang="ts">
|
||||
import { KeyboardShortcutInfo } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import type { Component } from 'svelte';
|
||||
import { SearchInput } from '$lib/components/app';
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { SIDEBAR_ACTIONS_ITEMS } from '$lib/constants/ui';
|
||||
import { Search } from '@lucide/svelte';
|
||||
import { ActionIcon, KeyboardShortcutInfo, SearchInput } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import {
|
||||
ICON_STRIP_TRANSITION_DURATION,
|
||||
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
|
||||
ROUTES,
|
||||
SIDEBAR_ACTIONS_ITEMS
|
||||
} from '$lib/constants';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { circIn } from 'svelte/easing';
|
||||
import { onMount } from 'svelte';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
handleMobileSidebarItemClick: () => void;
|
||||
class: string;
|
||||
isExpandedMode: boolean;
|
||||
isSearchModeActive: boolean;
|
||||
searchQuery: string;
|
||||
isCancelAlwaysVisible?: boolean;
|
||||
onSearchDeactivated?: () => void;
|
||||
onSearchClick?: () => void;
|
||||
onNewChat?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
handleMobileSidebarItemClick,
|
||||
isSearchModeActive = $bindable(),
|
||||
searchQuery = $bindable(),
|
||||
isCancelAlwaysVisible = false,
|
||||
onSearchDeactivated
|
||||
class: className,
|
||||
isExpandedMode = false,
|
||||
isSearchModeActive = $bindable(false),
|
||||
searchQuery = $bindable(''),
|
||||
onSearchDeactivated,
|
||||
onSearchClick,
|
||||
onNewChat
|
||||
}: Props = $props();
|
||||
|
||||
let initialized = $state(false);
|
||||
let showIcons = $state(false);
|
||||
let searchInputRef = $state<HTMLInputElement | null>(null);
|
||||
|
||||
const isOnMobile = $derived(isMobile.current);
|
||||
|
||||
$effect(() => {
|
||||
if (isSearchModeActive && searchInputRef) {
|
||||
searchInputRef.focus();
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
showIcons = true;
|
||||
|
||||
setTimeout(() => {
|
||||
initialized = true;
|
||||
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
|
||||
});
|
||||
|
||||
function handleSearchModeDeactivate() {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
onSearchDeactivated?.();
|
||||
}
|
||||
|
||||
export function activateSearch() {
|
||||
isSearchModeActive = true;
|
||||
// Focus after Svelte renders the input
|
||||
queueMicrotask(() => searchInputRef?.focus());
|
||||
function isItemActive(item: {
|
||||
activeRouteId?: string;
|
||||
activeRoutePrefix?: string;
|
||||
activeUrlIncludes?: string;
|
||||
}): boolean {
|
||||
if (item.activeRouteId) {
|
||||
return page.route.id === item.activeRouteId;
|
||||
}
|
||||
|
||||
if (item.activeRoutePrefix) {
|
||||
return !!page.route.id?.startsWith(item.activeRoutePrefix);
|
||||
}
|
||||
|
||||
if (item.activeUrlIncludes) {
|
||||
return page.url?.hash?.includes(item.activeUrlIncludes) ?? false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -41,56 +88,109 @@
|
||||
<IconComponent class="h-4 w-4" />
|
||||
{/snippet}
|
||||
|
||||
<div class="my-1 space-y-1">
|
||||
{#if isSearchModeActive}
|
||||
{#if isSearchModeActive}
|
||||
<div class="px-4 my-2">
|
||||
<SearchInput
|
||||
bind:value={searchQuery}
|
||||
bind:ref={searchInputRef}
|
||||
onClose={handleSearchModeDeactivate}
|
||||
onKeyDown={(e) => e.key === 'Escape' && handleSearchModeDeactivate()}
|
||||
placeholder="Search conversations..."
|
||||
{isCancelAlwaysVisible}
|
||||
/>
|
||||
{:else}
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item (item.route)}
|
||||
{#if !item.route}
|
||||
<Button
|
||||
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100"
|
||||
onclick={activateSearch}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
{@render itemIcon(item.icon)}
|
||||
</div>
|
||||
{:else if isExpandedMode || isOnMobile}
|
||||
<div
|
||||
class="{className} flex flex-col gap-5 md:gap-1 mt-2 md:mt-0 {!isExpandedMode && isOnMobile
|
||||
? 'hidden pointer-events-none'
|
||||
: ''}"
|
||||
>
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
|
||||
{@const isActive = isItemActive(item)}
|
||||
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
|
||||
{@const itemHref = isSearchOnMobile ? ROUTES.SEARCH : item.route}
|
||||
{@const itemOnClick = item.route
|
||||
? () => {
|
||||
onNewChat?.();
|
||||
goto(item.route!);
|
||||
}
|
||||
: isSearchOnMobile
|
||||
? undefined
|
||||
: onSearchClick}
|
||||
{@const itemTransition = {
|
||||
duration: ICON_STRIP_TRANSITION_DURATION,
|
||||
delay: !initialized
|
||||
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
|
||||
: 0,
|
||||
easing: circIn
|
||||
}}
|
||||
|
||||
{item.tooltip}
|
||||
</div>
|
||||
{#if showIcons}
|
||||
<div transition:fade={itemTransition}>
|
||||
<Button
|
||||
class="w-full min-w-9 justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {isActive
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
href={itemHref}
|
||||
onclick={itemOnClick}
|
||||
variant="ghost"
|
||||
size="default"
|
||||
>
|
||||
<span class="flex min-w-0 items-center px-0.5 gap-2">
|
||||
{@render itemIcon(item.icon)}
|
||||
|
||||
{#if item.keys}
|
||||
<KeyboardShortcutInfo keys={item.keys} />
|
||||
{/if}
|
||||
</Button>
|
||||
{:else}
|
||||
<Button
|
||||
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {(item.activeRouteId &&
|
||||
page.route.id === item.activeRouteId) ||
|
||||
(item.activeRoutePrefix && page.route.id?.startsWith(item.activeRoutePrefix))
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
href={item.route}
|
||||
onclick={handleMobileSidebarItemClick}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
{@render itemIcon(item.icon)}
|
||||
{#if showIcons}
|
||||
<span
|
||||
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
|
||||
out:fade={{ duration: 100 }}
|
||||
class="min-w-0 truncate">{item.tooltip}</span
|
||||
>
|
||||
{/if}
|
||||
</span>
|
||||
|
||||
{item.tooltip}
|
||||
</div>
|
||||
|
||||
{#if item.keys}
|
||||
<KeyboardShortcutInfo keys={item.keys} />
|
||||
{/if}
|
||||
</Button>
|
||||
{#if item.keys}
|
||||
<KeyboardShortcutInfo keys={item.keys} />
|
||||
{/if}
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="{className} flex-col gap-1 hidden md:flex">
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
|
||||
{@const isActive = isItemActive(item)}
|
||||
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
|
||||
{@const itemOnClick = item.route
|
||||
? () => {
|
||||
onNewChat?.();
|
||||
goto(item.route!);
|
||||
}
|
||||
: isSearchOnMobile
|
||||
? undefined
|
||||
: onSearchClick}
|
||||
{@const itemTransition = {
|
||||
duration: ICON_STRIP_TRANSITION_DURATION,
|
||||
delay: !initialized
|
||||
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
|
||||
: 0,
|
||||
easing: circIn
|
||||
}}
|
||||
|
||||
{#if showIcons}
|
||||
<div transition:fade={itemTransition}>
|
||||
<ActionIcon
|
||||
icon={item.icon}
|
||||
tooltip={item.tooltip}
|
||||
tooltipSide={TooltipSide.RIGHT}
|
||||
size="lg"
|
||||
iconSize="h-4 w-4"
|
||||
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
onclick={itemOnClick}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
+135
@@ -0,0 +1,135 @@
|
||||
<script lang="ts">
|
||||
import { Pin } from '@lucide/svelte';
|
||||
import { buildConversationTree } from '$lib/stores/conversations.svelte';
|
||||
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
|
||||
import SidebarNavigationSearchResults from './SidebarNavigationSearchResults.svelte';
|
||||
|
||||
interface Props {
|
||||
class: string;
|
||||
filteredConversations: DatabaseConversation[];
|
||||
currentChatId: string | undefined;
|
||||
isSearchModeActive: boolean;
|
||||
searchQuery: string;
|
||||
onSelect: (id: string) => void;
|
||||
onEdit: (id: string) => void;
|
||||
onDelete: (id: string) => void;
|
||||
onStop: (id: string) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className,
|
||||
filteredConversations,
|
||||
currentChatId,
|
||||
isSearchModeActive,
|
||||
searchQuery,
|
||||
onSelect,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onStop
|
||||
}: Props = $props();
|
||||
|
||||
let conversationTree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
let pinnedConversations = $derived(
|
||||
conversationTree.filter(({ conversation }) => conversation.pinned)
|
||||
);
|
||||
|
||||
let unpinnedConversations = $derived(
|
||||
conversationTree.filter(({ conversation }) => !conversation.pinned)
|
||||
);
|
||||
|
||||
const recentEmptyMessage = $derived(
|
||||
searchQuery.length > 0 ? 'No results found' : 'No conversations yet'
|
||||
);
|
||||
</script>
|
||||
|
||||
{#if isSearchModeActive}
|
||||
<SidebarNavigationSearchResults
|
||||
class={className}
|
||||
{searchQuery}
|
||||
{filteredConversations}
|
||||
{currentChatId}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
{:else}
|
||||
{#if pinnedConversations.length > 0}
|
||||
<div class="py-2 flex whitespace-nowrap {className}">
|
||||
<div
|
||||
class="text-muted-foreground inline-flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium gap-1"
|
||||
>
|
||||
<Pin class="h-3.5 w-3.5" />
|
||||
|
||||
<span>Pinned</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1 {className}">
|
||||
{#each pinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<li class="group/item relative mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
</li>
|
||||
{/each}
|
||||
</ul>
|
||||
{/if}
|
||||
|
||||
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-2 whitespace-nowrap {className}">
|
||||
{#if filteredConversations.length > 0}
|
||||
<div
|
||||
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
|
||||
>
|
||||
Recent conversations
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="min-h-0 flex-1 md:overflow-y-auto">
|
||||
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1">
|
||||
{#each unpinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<li class="group/item relative mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
</li>
|
||||
{/each}
|
||||
|
||||
{#if unpinnedConversations.length === 0}
|
||||
<li class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{recentEmptyMessage}
|
||||
</p>
|
||||
</li>
|
||||
{/if}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
+3
-1
@@ -16,4 +16,6 @@
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<SearchInput bind:value {placeholder} {onInput} class="mb-4 {className}" />
|
||||
<div class="mb-4 px-2 {className}">
|
||||
<SearchInput bind:value {placeholder} {onInput} />
|
||||
</div>
|
||||
|
||||
+76
@@ -0,0 +1,76 @@
|
||||
<script lang="ts">
|
||||
import { buildConversationTree } from '$lib/stores/conversations.svelte';
|
||||
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
searchQuery: string;
|
||||
filteredConversations: DatabaseConversation[];
|
||||
currentChatId: string | undefined;
|
||||
onSelect: (id: string) => void;
|
||||
onEdit: (id: string) => void;
|
||||
onDelete: (id: string) => void;
|
||||
onStop: (id: string) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
searchQuery,
|
||||
filteredConversations,
|
||||
currentChatId,
|
||||
onSelect,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onStop
|
||||
}: Props = $props();
|
||||
|
||||
let tree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
const hasQuery = $derived(searchQuery.trim().length > 0);
|
||||
const showHeader = $derived(hasQuery && filteredConversations.length > 0);
|
||||
|
||||
const emptyMessage = $derived(hasQuery ? 'No results found' : 'Start typing to see results');
|
||||
</script>
|
||||
|
||||
<div class="flex min-h-0 flex-1 flex-col gap-2 whitespace-nowrap {className}">
|
||||
{#if showHeader}
|
||||
<div
|
||||
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
|
||||
>
|
||||
Search results
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="min-h-0 flex-1 overflow-y-auto">
|
||||
<ul class="flex w-full min-w-0 flex-col gap-1">
|
||||
{#each tree as { conversation, depth } (conversation.id)}
|
||||
<li class="group/item relative mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
</li>
|
||||
{/each}
|
||||
|
||||
{#if tree.length === 0}
|
||||
<li class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{emptyMessage}
|
||||
</p>
|
||||
</li>
|
||||
{/if}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
@@ -63,15 +63,6 @@ export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svel
|
||||
* ```
|
||||
*/
|
||||
export { default as DropdownMenuActions } from './DropdownMenuActions.svelte';
|
||||
|
||||
/**
|
||||
* **DesktopIconStrip** - Fixed icon strip for desktop sidebar
|
||||
*
|
||||
* Vertical icon strip shown on desktop when the sidebar is collapsed.
|
||||
* Contains navigation shortcuts for new chat, search, MCP, import/export, and settings.
|
||||
*/
|
||||
export { default as DesktopIconStrip } from './DesktopIconStrip.svelte';
|
||||
|
||||
/**
|
||||
* **SidebarNavigation** - Sidebar with actions menu and conversation list
|
||||
*
|
||||
@@ -115,13 +106,6 @@ export { default as DesktopIconStrip } from './DesktopIconStrip.svelte';
|
||||
*/
|
||||
export { default as SidebarNavigation } from './SidebarNavigation/SidebarNavigation.svelte';
|
||||
|
||||
/**
|
||||
* Action buttons for sidebar header. Contains new chat button, settings button,
|
||||
* and delete all conversations button. Manages dialog states for settings and
|
||||
* delete confirmation.
|
||||
*/
|
||||
export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte';
|
||||
|
||||
/**
|
||||
* Single conversation item in sidebar. Displays conversation title (truncated),
|
||||
* last message preview, and timestamp. Shows context menu on right-click with
|
||||
@@ -130,6 +114,58 @@ export { default as SidebarNavigationActions } from './SidebarNavigation/Sidebar
|
||||
*/
|
||||
export { default as SidebarNavigationConversationItem } from './SidebarNavigation/SidebarNavigationConversationItem.svelte';
|
||||
|
||||
/**
|
||||
* **SidebarNavigationConversationList** - Grouped conversation list
|
||||
*
|
||||
* Pure-presentational list of conversations. Splits items into a Pinned
|
||||
* section (when not in search mode) and a Recent Conversations / Search
|
||||
* Results section with the unpinned items. Item selection, edit, delete,
|
||||
* and stop-generation are delegated to the caller via callbacks.
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <SidebarNavigationConversationList
|
||||
* {filteredConversations}
|
||||
* {currentChatId}
|
||||
* {isSearchModeActive}
|
||||
* {searchQuery}
|
||||
* onSelect={...}
|
||||
* onEdit={...}
|
||||
* onDelete={...}
|
||||
* onStop={...}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
export { default as SidebarNavigationConversationList } from './SidebarNavigation/SidebarNavigationConversationList.svelte';
|
||||
export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte';
|
||||
|
||||
/**
|
||||
* **SidebarNavigationSearchResults** - Filtered conversation list for search.
|
||||
*
|
||||
* Pure-presentational rendering of the search-mode subtree: "Search results"
|
||||
* header, the matching items rendered through {@link SidebarNavigationConversationItem},
|
||||
* and contextual empty-state messages. Used both inline inside
|
||||
* {@link SidebarNavigationConversationList} (when search mode is active in the
|
||||
* sidebar) and as the body of the mobile `/search` route.
|
||||
*
|
||||
* The caller is expected to provide an already-filtered list via
|
||||
* `filteredConversations` and a `searchQuery` for the empty-state messages.
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <SidebarNavigationSearchResults
|
||||
* {searchQuery}
|
||||
* {filteredConversations}
|
||||
* {currentChatId}
|
||||
* onSelect={...}
|
||||
* onEdit={...}
|
||||
* onDelete={...}
|
||||
* onStop={...}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
export { default as SidebarNavigationSearchResults } from './SidebarNavigation/SidebarNavigationSearchResults.svelte';
|
||||
|
||||
/**
|
||||
* Search input for filtering conversations in sidebar. Filters conversation
|
||||
* list by title as user types. Shows clear button when query is not empty.
|
||||
|
||||
@@ -126,10 +126,7 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="mx-auto flex h-full max-h-[100dvh] w-full flex-col overflow-y-auto md:pl-8"
|
||||
in:fade={{ duration: 150 }}
|
||||
>
|
||||
<div class="mx-auto flex h-full w-full flex-col md:pl-8" in:fade={{ duration: 150 }}>
|
||||
<div class="flex flex-1 flex-col gap-4 md:flex-row">
|
||||
<SettingsChatDesktopSidebar
|
||||
sections={SETTINGS_CHAT_SECTIONS}
|
||||
|
||||
@@ -12,11 +12,13 @@
|
||||
let { sections, isActive, getHref, onSectionChange }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="sticky top-0 hidden w-64 flex-col self-start bg-background pt-10 pb-4 md:flex">
|
||||
<div class="flex items-center gap-2 pb-10">
|
||||
<Settings class="h-6 w-6" />
|
||||
<h1 class="text-2xl font-semibold">Settings</h1>
|
||||
<div class="sticky top-2 hidden w-64 flex-col self-start bg-background py-4 md:flex gap-6">
|
||||
<div class="flex items-center gap-2 py-2">
|
||||
<Settings class="h-5 w-5 md:h-6 md:w-6" />
|
||||
|
||||
<h1 class="text-xl font-semibold md:text-2xl">Settings</h1>
|
||||
</div>
|
||||
|
||||
<nav class="space-y-1">
|
||||
{#each sections as section (section.title)}
|
||||
{#if getHref}
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
<script lang="ts">
|
||||
import { Plus } from '@lucide/svelte';
|
||||
import { X, Plus } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { toolsStore } from '$lib/stores/tools.svelte';
|
||||
import { McpServerCard, McpServerCardSkeleton } from '$lib/components/app/mcp';
|
||||
import { ActionIcon, McpServerCard, McpServerCardSkeleton } from '$lib/components/app';
|
||||
import { DialogMcpServerAddNew } from '$lib/components/app/dialogs';
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import { ROUTES } from '$lib/constants';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { onMount } from 'svelte';
|
||||
import McpLogo from '../mcp/McpLogo.svelte';
|
||||
import { browser } from '$app/environment';
|
||||
import { page } from '$app/state';
|
||||
import { replaceState } from '$app/navigation';
|
||||
import { goto, replaceState } from '$app/navigation';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -24,6 +26,24 @@
|
||||
let initialLoadComplete = $state(false);
|
||||
let isAddingServer = $state(false);
|
||||
|
||||
let previousRouteId = $state<string | null>(null);
|
||||
|
||||
$effect(() => {
|
||||
const currentId = page.route.id;
|
||||
return () => {
|
||||
previousRouteId = currentId;
|
||||
};
|
||||
});
|
||||
|
||||
function handleClose() {
|
||||
const prevIsMcpServers = previousRouteId === '/mcp-servers';
|
||||
if (browser && window.history.length > 1 && !prevIsMcpServers) {
|
||||
history.back();
|
||||
} else {
|
||||
goto(ROUTES.START);
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
if (page.url.searchParams.has('add')) {
|
||||
isAddingServer = true;
|
||||
@@ -54,15 +74,26 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div in:fade={{ duration: 150 }} class="h-full max-h-[100dvh] overflow-y-auto">
|
||||
<div class="flex items-center gap-2 p-4 md:absolute md:top-8 md:left-8 md:px-0 md:py-2">
|
||||
<McpLogo class="h-5 w-5 md:h-6 md:w-6" />
|
||||
|
||||
<h1 class="text-xl font-semibold md:text-2xl">MCP Servers</h1>
|
||||
<div in:fade={{ duration: 150 }}>
|
||||
<div class="fixed top-4.5 right-4 z-50 md:hidden">
|
||||
<ActionIcon icon={X} tooltip="Close" onclick={handleClose} />
|
||||
</div>
|
||||
|
||||
<div class="sticky top-0 z-10 mt-4 flex items-start gap-4 p-4 md:justify-end md:px-8">
|
||||
<Button variant="outline" size="sm" class="shrink-0" onclick={() => (isAddingServer = true)}>
|
||||
<div
|
||||
class="sticky top-0 z-10 mt-4 mb-2 flex items-start gap-4 md:p-4 p-0 px-4 md:justify-between md:px-8"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<McpLogo class="h-5 w-5 md:h-6 md:w-6" />
|
||||
|
||||
<h1 class="text-lg font-semibold md:text-2xl">MCP Servers</h1>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
variant="outline"
|
||||
size="lg"
|
||||
class="shrink-0 fixed md:static bottom-6 right-6"
|
||||
onclick={() => (isAddingServer = true)}
|
||||
>
|
||||
<Plus class="h-4 w-4" />
|
||||
|
||||
Add New Server
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
size: {
|
||||
default: 'h-9 px-4 py-2 has-[>svg]:px-3',
|
||||
sm: 'h-8 gap-1.5 rounded-md px-3 has-[>svg]:px-2.5',
|
||||
lg: 'h-10 rounded-md px-6 has-[>svg]:px-4',
|
||||
lg: 'h-10 rounded-lg px-6 has-[>svg]:px-4',
|
||||
'icon-lg': 'size-10',
|
||||
icon: 'size-9',
|
||||
'icon-sm': 'size-5 rounded-sm'
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
export const SIDEBAR_COOKIE_NAME = 'sidebar:state';
|
||||
export const SIDEBAR_COOKIE_MAX_AGE = 60 * 60 * 24 * 7;
|
||||
export const SIDEBAR_MIN_WIDTH = '18rem';
|
||||
export const SIDEBAR_MAX_WIDTH = '32rem';
|
||||
export const SIDEBAR_WIDTH_MOBILE = '18rem';
|
||||
export const SIDEBAR_WIDTH_ICON = '3rem';
|
||||
export const SIDEBAR_KEYBOARD_SHORTCUT = 'b';
|
||||
@@ -1,79 +0,0 @@
|
||||
import { isMobile } from '$lib/stores/viewport.svelte.js';
|
||||
import { getContext, setContext } from 'svelte';
|
||||
import { SIDEBAR_KEYBOARD_SHORTCUT, SIDEBAR_MIN_WIDTH } from './constants.js';
|
||||
|
||||
type Getter<T> = () => T;
|
||||
|
||||
export type SidebarStateProps = {
|
||||
/**
|
||||
* A getter function that returns the current open state of the sidebar.
|
||||
* We use a getter function here to support `bind:open` on the `Sidebar.Provider`
|
||||
* component.
|
||||
*/
|
||||
open: Getter<boolean>;
|
||||
|
||||
/**
|
||||
* A function that sets the open state of the sidebar. To support `bind:open`, we need
|
||||
* a source of truth for changing the open state to ensure it will be synced throughout
|
||||
* the sub-components and any `bind:` references.
|
||||
*/
|
||||
setOpen: (open: boolean) => void;
|
||||
};
|
||||
|
||||
class SidebarState {
|
||||
readonly props: SidebarStateProps;
|
||||
open = $derived.by(() => this.props.open());
|
||||
openMobile = $state(false);
|
||||
sidebarWidth = $state(SIDEBAR_MIN_WIDTH);
|
||||
isResizing = $state(false);
|
||||
setOpen: SidebarStateProps['setOpen'];
|
||||
state = $derived.by(() => (this.open ? 'expanded' : 'collapsed'));
|
||||
|
||||
constructor(props: SidebarStateProps) {
|
||||
this.setOpen = props.setOpen;
|
||||
this.props = props;
|
||||
}
|
||||
|
||||
// Convenience getter for checking if the sidebar is mobile
|
||||
// without this, we would need to use `sidebar.isMobile.current` everywhere
|
||||
get isMobile() {
|
||||
return isMobile.current;
|
||||
}
|
||||
|
||||
// Event handler to apply to the `<svelte:window>`
|
||||
handleShortcutKeydown = (e: KeyboardEvent) => {
|
||||
if (e.key === SIDEBAR_KEYBOARD_SHORTCUT && (e.metaKey || e.ctrlKey)) {
|
||||
e.preventDefault();
|
||||
this.toggle();
|
||||
}
|
||||
};
|
||||
|
||||
setOpenMobile = (value: boolean) => {
|
||||
this.openMobile = value;
|
||||
};
|
||||
|
||||
toggle = () => {
|
||||
this.setOpen(!this.open);
|
||||
};
|
||||
}
|
||||
|
||||
const SYMBOL_KEY = 'scn-sidebar';
|
||||
|
||||
/**
|
||||
* Instantiates a new `SidebarState` instance and sets it in the context.
|
||||
*
|
||||
* @param props The constructor props for the `SidebarState` class.
|
||||
* @returns The `SidebarState` instance.
|
||||
*/
|
||||
export function setSidebar(props: SidebarStateProps): SidebarState {
|
||||
return setContext(Symbol.for(SYMBOL_KEY), new SidebarState(props));
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the `SidebarState` instance from the context. This is a class instance,
|
||||
* so you cannot destructure it.
|
||||
* @returns The `SidebarState` instance.
|
||||
*/
|
||||
export function useSidebar(): SidebarState {
|
||||
return getContext(Symbol.for(SYMBOL_KEY));
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
import { useSidebar } from './context.svelte.js';
|
||||
import Content from './sidebar-content.svelte';
|
||||
import Footer from './sidebar-footer.svelte';
|
||||
import GroupAction from './sidebar-group-action.svelte';
|
||||
import GroupContent from './sidebar-group-content.svelte';
|
||||
import GroupLabel from './sidebar-group-label.svelte';
|
||||
import Group from './sidebar-group.svelte';
|
||||
import Header from './sidebar-header.svelte';
|
||||
import Input from './sidebar-input.svelte';
|
||||
import Inset from './sidebar-inset.svelte';
|
||||
import MenuAction from './sidebar-menu-action.svelte';
|
||||
import MenuBadge from './sidebar-menu-badge.svelte';
|
||||
import MenuButton from './sidebar-menu-button.svelte';
|
||||
import MenuItem from './sidebar-menu-item.svelte';
|
||||
import MenuSkeleton from './sidebar-menu-skeleton.svelte';
|
||||
import MenuSubButton from './sidebar-menu-sub-button.svelte';
|
||||
import MenuSubItem from './sidebar-menu-sub-item.svelte';
|
||||
import MenuSub from './sidebar-menu-sub.svelte';
|
||||
import Menu from './sidebar-menu.svelte';
|
||||
import Provider from './sidebar-provider.svelte';
|
||||
import Rail from './sidebar-rail.svelte';
|
||||
import Separator from './sidebar-separator.svelte';
|
||||
import Trigger from './sidebar-trigger.svelte';
|
||||
import Root from './sidebar.svelte';
|
||||
|
||||
export {
|
||||
Content,
|
||||
Footer,
|
||||
Group,
|
||||
GroupAction,
|
||||
GroupContent,
|
||||
GroupLabel,
|
||||
Header,
|
||||
Input,
|
||||
Inset,
|
||||
Menu,
|
||||
MenuAction,
|
||||
MenuBadge,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuSkeleton,
|
||||
MenuSub,
|
||||
MenuSubButton,
|
||||
MenuSubItem,
|
||||
Provider,
|
||||
Rail,
|
||||
Root,
|
||||
Separator,
|
||||
//
|
||||
Root as Sidebar,
|
||||
Content as SidebarContent,
|
||||
Footer as SidebarFooter,
|
||||
Group as SidebarGroup,
|
||||
GroupAction as SidebarGroupAction,
|
||||
GroupContent as SidebarGroupContent,
|
||||
GroupLabel as SidebarGroupLabel,
|
||||
Header as SidebarHeader,
|
||||
Input as SidebarInput,
|
||||
Inset as SidebarInset,
|
||||
Menu as SidebarMenu,
|
||||
MenuAction as SidebarMenuAction,
|
||||
MenuBadge as SidebarMenuBadge,
|
||||
MenuButton as SidebarMenuButton,
|
||||
MenuItem as SidebarMenuItem,
|
||||
MenuSkeleton as SidebarMenuSkeleton,
|
||||
MenuSub as SidebarMenuSub,
|
||||
MenuSubButton as SidebarMenuSubButton,
|
||||
MenuSubItem as SidebarMenuSubItem,
|
||||
Provider as SidebarProvider,
|
||||
Rail as SidebarRail,
|
||||
Separator as SidebarSeparator,
|
||||
Trigger as SidebarTrigger,
|
||||
Trigger,
|
||||
useSidebar
|
||||
};
|
||||
@@ -1,24 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
|
||||
</script>
|
||||
|
||||
<div
|
||||
bind:this={ref}
|
||||
data-slot="sidebar-content"
|
||||
data-sidebar="content"
|
||||
class={cn(
|
||||
'flex min-h-0 flex-1 flex-col gap-2 overflow-auto group-data-[collapsible=icon]:overflow-hidden',
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user