Compare commits

...

20 Commits

Author SHA1 Message Date
matteo b55dcdef5d server: save generated text for the /slots endpoint (for LLAMA_SERVER_SLOTS_DEBUG=1) (#19622)
* save generated text for the /slots endpoint

* update debug_generated_text only when LLAMA_SERVER_SLOTS_DEBUG > 0

* Apply suggestions from code review

---------

Co-authored-by: Matteo <matteo@matteo>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-02-18 18:53:37 +01:00
Xuan-Son Nguyen eeef3cfced model: support GLM-OCR (#19677)
* model: support GLM-OCR

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-18 17:51:40 +01:00
Maciej Lisowski e99f1083a0 docs: Fix broken links for preparing models in Backends (#19684) 2026-02-18 23:50:23 +08:00
Reese Levine 238856ec8f ggml webgpu: shader library organization (#19530)
* Basic JIT compilation for mul_mat, get_rows, and scale (#17)

* scale jit working

* preliminary working jit for getrows and mulmat, needs refining

* simplified mul_mat preprocessing switch statement

* get_rows fixes, mul_mat refinement

* formatted + last edits

* removed some extraneous prints

* fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish

* small fix

* some changes, working

* get_rows and mul_mat jit fixed and working

* Update formatting

* formatting

* Add header

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on all-encompassing shader library

* refactor argmax, set_rows

* Refactor all but flashattention, mat mul

* flashattention and matrix multiplication moved to new format

* clean up preprocessing

* Formatting

* remove duplicate constants

* Split large shaders into multiple static strings

---------

Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>
2026-02-18 07:51:02 -07:00
Aleksander Grygier ea003229d3 Pre-MCP UI and architecture cleanup (#19689) 2026-02-18 12:02:02 +01:00
Jeff Bolz d0061be838 vulkan: split mul_mat into multiple dispatches to avoid overflow (#19509)
* vulkan: split mul_mat into multiple dispatches to avoid overflow

The batch dimensions can be greater than the max workgroup count limit,
in which case we need to split into multiple dispatches and pass the base
index through a push constant.

Fall back for the less common p021 and nc variants.

* address feedback
2026-02-18 10:47:10 +01:00
Adrien Gallouët a569bda445 common : make small string helpers as inline functions (#19693)
Also use string_view when it make sense and fix some corner cases.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-18 08:03:01 +01:00
shaofeiqi e2f19b320f opencl: refactor expm1 and softplus (#19404)
* opencl: refactor expm1

* opencl: refactor softplus

* opencl: use h for half literals

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-02-17 14:47:18 -08:00
shaofeiqi 983559d24b opencl: optimize mean and sum_row kernels (#19614)
* opencl: optimize mean and sum_row kernels

* opencl: add comment for max subgroups

* opencl: format

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-02-17 13:56:09 -08:00
Daniel Bevenius 2b089c7758 model-conversion : add option to print tensor values (#19692)
This commit updates the tensor-info.py script to support the option to
print the first N values of a tensor when displaying its information.

The motivation for this is that it can be useful to inspect some actual
values in addition to the shapes of the tensors.
2026-02-17 20:43:22 +01:00
Aleksander Grygier afa6bfe4f7 Pre-MCP UI and architecture cleanup (#19685)
* webui: extract non-MCP changes from mcp-mvp review split

* webui: extract additional pre-MCP UI and architecture cleanup

* chore: update webui build output
2026-02-17 13:47:45 +01:00
Talha Can Havadar ae2d3f28a8 ggml: ggml-cpu: force-no-lto-for-cpu-feats (#19609)
When LTO enabled in build environments it forces all builds to have LTO
in place. But feature detection logic is fragile, and causing Illegal
instruction errors with lto. This disables LTO for the feature
detection code to prevent cross-module optimization from inlining
architecture-specific instructions into the score function. Without this,
LTO can cause SIGILL when loading backends on older CPUs (e.g., loading
power10 backend on power9 crashes before feature check runs).
2026-02-17 13:22:46 +02:00
Georgi Gerganov ad8207af77 cuda : enable CUDA graphs for MMID 1 <= BS <= 4 (#19645)
* cuda : enable CUDA graphs for MMID BS <= 4

* cont : add stream capture check

Co-authored-by: Oliver Simons <osimons@nvidia.com>

* cont : add MMVQ_MMID_MAX_BATCH_SIZE

---------

Co-authored-by: Oliver Simons <osimons@nvidia.com>
2026-02-17 12:31:49 +02:00
Daniel Bevenius 667b694278 model-conversion : make printing of config values optional (#19681)
* model-conversion : make printing of config values optional

This commit updates run-org-model.py to make the printing of model
configuration values optional.

The motivation for this change is that not all models have these
configuration values defined and those that do not will error when
running this script. With these changes we only print the values if they
exist or a default value.

We could optionally just remove them but it can be useful to see these
values when running the original model.
2026-02-17 10:46:53 +01:00
Sigbjørn Skjæret e48349a49d ci : bump komac version (#19682) 2026-02-17 09:30:31 +01:00
Adrien Gallouët ae46a61e41 build : link ws2_32 as PUBLIC on Windows (#19666)
Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2026-02-17 08:37:07 +01:00
Adrien Gallouët 65cede7c70 build : cleanup library linking logic (#19665)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-17 08:36:45 +01:00
DAN™ 05fa625eac convert : add JoyAI-LLM-Flash (#19651)
* convert_hf_to_gguf: add JoyAI-LLM-Flash tokenizer hash mapping to deepseek-v3

* llama-vocab: create a new pre-tokenizer name for joyai-llm.

* add missing vocab type section

* Update convert_hf_to_gguf_update.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-16 22:49:57 +01:00
AesSedai d612901116 perplexity: add proper batching (#19661) 2026-02-16 18:44:44 +02:00
Ivan Chikish cceb1b4e33 common : inline functions (#18639) 2026-02-16 17:52:24 +02:00
110 changed files with 8328 additions and 6997 deletions
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
- name: Install komac
run: |
cargo binstall komac@2.11.2 -y
cargo binstall komac@2.15.0 -y
- name: Find latest release
id: find_latest_release
+11 -23
View File
@@ -5,7 +5,6 @@ find_package(Threads REQUIRED)
llama_add_compile_flags()
# Build info header
#
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git")
@@ -110,29 +109,16 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
set(LLAMA_COMMON_EXTRA_LIBS build_info)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
target_link_libraries(${TARGET} PRIVATE
build_info
cpp-httplib
)
if (LLAMA_LLGUIDANCE)
include(ExternalProject)
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
# Set the correct library file extension based on platform
if (WIN32)
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
# Add Windows-specific libraries
set(LLGUIDANCE_PLATFORM_LIBS
ws2_32 # Windows Sockets API
userenv # For GetUserProfileDirectoryW
ntdll # For NT functions
bcrypt # For BCryptGenRandom
)
else()
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
set(LLGUIDANCE_PLATFORM_LIBS "")
endif()
set(LLGUIDANCE_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}llguidance${CMAKE_STATIC_LIBRARY_SUFFIX}")
ExternalProject_Add(llguidance_ext
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
@@ -154,8 +140,10 @@ if (LLAMA_LLGUIDANCE)
add_dependencies(llguidance llguidance_ext)
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
# Add platform libraries to the main target
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
endif ()
target_link_libraries(${TARGET} PRIVATE llguidance)
if (WIN32)
target_link_libraries(${TARGET} PRIVATE ws2_32 userenv ntdll bcrypt)
endif()
endif()
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
target_link_libraries(${TARGET} PUBLIC llama Threads::Threads)
-28
View File
@@ -452,34 +452,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
bool has_suffix = string_ends_with(str, suffix);
if (has_suffix) {
str = str.substr(0, str.size() - suffix.size());
}
return has_suffix;
}
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const auto current_partial = stop.substr(0, char_index + 1);
if (string_ends_with(str, current_partial)) {
return str.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$&");
+41 -16
View File
@@ -670,30 +670,55 @@ static std::vector<T> string_split(const std::string & str, char delim) {
}
template<>
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
inline std::vector<std::string> string_split<std::string>(const std::string & str, char delim)
{
std::vector<std::string> parts;
size_t begin_pos = 0;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
size_t delim_pos = str.find(delim);
while (delim_pos != std::string::npos) {
std::string part = str.substr(begin_pos, delim_pos - begin_pos);
parts.emplace_back(part);
begin_pos = separator_pos + 1;
separator_pos = input.find(separator, begin_pos);
begin_pos = delim_pos + 1;
delim_pos = str.find(delim, begin_pos);
}
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
parts.emplace_back(str.substr(begin_pos));
return parts;
}
static bool string_starts_with(const std::string & str,
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
return str.rfind(prefix, 0) == 0;
// remove when moving to c++20
inline bool string_starts_with(std::string_view str, std::string_view prefix) {
return str.size() >= prefix.size() &&
str.compare(0, prefix.size(), prefix) == 0;
}
// While we wait for C++20's std::string::ends_with...
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
// remove when moving to c++20
inline bool string_ends_with(std::string_view str, std::string_view suffix) {
return str.size() >= suffix.size() &&
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}
inline bool string_remove_suffix(std::string & str, std::string_view suffix) {
if (string_ends_with(str, suffix)) {
str.resize(str.size() - suffix.size());
return true;
}
return false;
}
inline size_t string_find_partial_stop(std::string_view str, std::string_view stop) {
if (!str.empty() && !stop.empty()) {
const size_t max_len = std::min(str.size(), stop.size());
const char last_char = str.back();
for (size_t len = max_len; len > 0; --len) {
if (stop[len - 1] == last_char) {
if (string_ends_with(str, stop.substr(0, len))) {
return str.size() - len;
}
}
}
}
return std::string::npos;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
@@ -870,11 +895,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
static std::string llm_ffn_exps_block_regex(int idx) {
inline std::string llm_ffn_exps_block_regex(int idx) {
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
}
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
}
+29 -5
View File
@@ -1049,6 +1049,9 @@ class TextModel(ModelBase):
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
# ref: https://huggingface.co/zai-org/GLM-4.5-Air
res = "glm4"
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
res = "glm4"
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
@@ -1082,9 +1085,6 @@ class TextModel(ModelBase):
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
# ref: https://huggingface.co/aari1995/German_Semantic_V3
res = "jina-v2-de"
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
res = "glm4"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@@ -1268,6 +1268,9 @@ class TextModel(ModelBase):
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
res = "qwen35"
if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d":
# ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash
res = "joyai-llm"
if res is None:
logger.warning("\n")
@@ -4581,7 +4584,7 @@ class Qwen3VLVisionModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration")
class Glm4VVisionModel(Qwen3VLVisionModel):
def set_gguf_parameters(self):
MmprojModel.set_gguf_parameters(self) # skip Qwen3VLVisionModel parameters
@@ -8773,7 +8776,7 @@ class Glm4Model(TextModel):
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams["num_key_value_heads"]
n_embd = self.hparams["hidden_size"]
head_dim = n_embd // n_head
head_dim = self.hparams.get("head_dim", n_embd // n_head)
# because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor)
@@ -8782,6 +8785,27 @@ class Glm4Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GlmOcrForConditionalGeneration")
class GlmOCRModel(Glm4Model):
model_arch = gguf.MODEL_ARCH.GLM4
use_mrope = False
partial_rotary_factor = 0.5
# Note: GLM-OCR is the same as GLM4, but with an extra NextN/MTP prediction layer
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# GLM-OCR has num_hidden_layers + 1 actual layers (including NextN layer)
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_gguf_parameters(self):
super().set_gguf_parameters()
# NextN/MTP prediction layers
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
@ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration")
class Glm4MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE
+3 -2
View File
@@ -149,7 +149,8 @@ models = [
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
@@ -159,6 +160,7 @@ pre_computed_hashes = [
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
@@ -172,7 +174,6 @@ pre_computed_hashes = [
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
# jina-v2-de variants
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
]
+1 -1
View File
@@ -246,7 +246,7 @@ cmake --build build --config release
1. **Retrieve and prepare model**
You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration.
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model prepration.
**Notes**:
+2 -2
View File
@@ -281,7 +281,7 @@ as `-cl-fp32-correctly-rounded-divide-sqrt`
#### Retrieve and prepare model
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
##### Check device
@@ -569,7 +569,7 @@ Once it is completed, final results will be in **build/Release/bin**
#### Retrieve and prepare model
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
##### Check device
@@ -42,11 +42,15 @@ def load_model_and_tokenizer(model_path, device="auto"):
config = config.text_config
multimodal = True
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
def print_if_exists(label, obj, attr, default="N/A"):
val = getattr(obj, attr) if hasattr(obj, attr) else default
print(f"{label}", val)
print_if_exists("Vocab size: ", config, "vocab_size")
print_if_exists("Hidden size: ", config, "hidden_size")
print_if_exists("Number of layers: ", config, "num_hidden_layers")
print_if_exists("BOS token id: ", config, "bos_token_id")
print_if_exists("EOS token id: ", config, "eos_token_id")
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
if unreleased_model_name:
@@ -78,7 +78,7 @@ def list_all_tensors(model_path: Path, unique: bool = False):
print(tensor_name)
def print_tensor_info(model_path: Path, tensor_name: str):
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
tensor_file = find_tensor_file(model_path, tensor_name)
if tensor_file is None:
@@ -96,6 +96,12 @@ def print_tensor_info(model_path: Path, tensor_name: str):
print(f"Tensor: {tensor_name}")
print(f"File: {tensor_file}")
print(f"Shape: {shape}")
if num_values is not None:
tensor = f.get_tensor(tensor_name)
print(f"Dtype: {tensor.dtype}")
flat = tensor.flatten()
n = min(num_values, flat.numel())
print(f"Values: {flat[:n].tolist()}")
else:
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
sys.exit(1)
@@ -127,6 +133,15 @@ def main():
action="store_true",
help="List unique tensor patterns in the model (layer numbers replaced with #)"
)
parser.add_argument(
"-n", "--num-values",
nargs="?",
const=10,
default=None,
type=int,
metavar="N",
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
)
args = parser.parse_args()
@@ -152,7 +167,7 @@ def main():
if args.tensor_name is None:
print("Error: tensor_name is required when not using --list")
sys.exit(1)
print_tensor_info(model_path, args.tensor_name)
print_tensor_info(model_path, args.tensor_name, args.num_values)
if __name__ == "__main__":
+5
View File
@@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch)
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
# Disable LTO for the feature detection code to prevent cross-module optimization
# from inlining architecture-specific instructions into the score function.
# Without this, LTO can cause SIGILL when loading backends on older CPUs
# (e.g., loading power10 backend on power9 crashes before feature check runs).
target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)
target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})
endfunction()
+10 -33
View File
@@ -2278,11 +2278,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) {
if (ne2 <= 4) {
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
return;
}
@@ -2305,6 +2306,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
// note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
// TODO: add asserts to verify this. should work with CUDA, HIP, etc.
cudaStream_t stream = ctx.stream();
GGML_ASSERT(nb12 % nb11 == 0);
@@ -2865,15 +2868,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
bool use_cuda_graph = true;
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
const std::string delta_net_prefix = "dnet_add";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2888,31 +2882,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
#endif
}
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
if (node->op == GGML_OP_ADD &&
node->src[1] && node->src[1]->ne[1] > 1 &&
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 &&
strncmp(node->name, delta_net_prefix.c_str(), delta_net_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
+1
View File
@@ -1,6 +1,7 @@
#include "common.cuh"
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
+179 -187
View File
@@ -484,7 +484,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_scale_f32, kernel_scale_f32_4;
cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
cl_kernel kernel_mean_f32;
cl_kernel kernel_mean_f32, kernel_mean_f32_4;
cl_kernel kernel_silu, kernel_silu_4;
cl_kernel kernel_gelu, kernel_gelu_4;
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
@@ -543,15 +543,15 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_solve_tri_f32;
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
cl_kernel kernel_argsort_f32_i32;
cl_kernel kernel_sum_rows_f32;
cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4;
cl_kernel kernel_repeat_f32;
cl_kernel kernel_pad;
cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc;
cl_kernel kernel_expm1_f32_nd;
cl_kernel kernel_expm1_f16_nd;
cl_kernel kernel_softplus_f32_nd;
cl_kernel kernel_softplus_f16_nd;
cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc;
cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc;
cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc;
cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc;
cl_kernel kernel_upscale;
cl_kernel kernel_upscale_bilinear;
cl_kernel kernel_concat_f32;
@@ -1837,6 +1837,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, "kernel_mean_f32_4", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
@@ -1874,6 +1875,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err));
CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32_4", &err), err));
GGML_LOG_CONT(".");
}
@@ -1978,20 +1980,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
#else
const std::string kernel_src = read_file("expm1.cl");
#endif
cl_program prog;
if (!kernel_src.empty()) {
prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err));
CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err));
GGML_LOG_CONT(".");
} else {
GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n");
prog = nullptr;
backend_ctx->kernel_expm1_f32_nd = nullptr;
backend_ctx->kernel_expm1_f16_nd = nullptr;
}
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_expm1_f32 = clCreateKernel(prog, "kernel_expm1_f32", &err), err));
CL_CHECK((backend_ctx->kernel_expm1_f32_4 = clCreateKernel(prog, "kernel_expm1_f32_4", &err), err));
CL_CHECK((backend_ctx->kernel_expm1_f32_nc = clCreateKernel(prog, "kernel_expm1_f32_nc", &err), err));
CL_CHECK((backend_ctx->kernel_expm1_f16 = clCreateKernel(prog, "kernel_expm1_f16", &err), err));
CL_CHECK((backend_ctx->kernel_expm1_f16_4 = clCreateKernel(prog, "kernel_expm1_f16_4", &err), err));
CL_CHECK((backend_ctx->kernel_expm1_f16_nc = clCreateKernel(prog, "kernel_expm1_f16_nc", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// softplus
@@ -2003,20 +2001,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
#else
const std::string kernel_src = read_file("softplus.cl");
#endif
cl_program prog;
if (!kernel_src.empty()) {
prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err));
CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err));
GGML_LOG_CONT(".");
} else {
GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n");
prog = nullptr;
backend_ctx->kernel_softplus_f32_nd = nullptr;
backend_ctx->kernel_softplus_f16_nd = nullptr;
}
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_softplus_f32 = clCreateKernel(prog, "kernel_softplus_f32", &err), err));
CL_CHECK((backend_ctx->kernel_softplus_f32_4 = clCreateKernel(prog, "kernel_softplus_f32_4", &err), err));
CL_CHECK((backend_ctx->kernel_softplus_f32_nc = clCreateKernel(prog, "kernel_softplus_f32_nc", &err), err));
CL_CHECK((backend_ctx->kernel_softplus_f16 = clCreateKernel(prog, "kernel_softplus_f16", &err), err));
CL_CHECK((backend_ctx->kernel_softplus_f16_4 = clCreateKernel(prog, "kernel_softplus_f16_4", &err), err));
CL_CHECK((backend_ctx->kernel_softplus_f16_nc = clCreateKernel(prog, "kernel_softplus_f16_nc", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// upscale
@@ -3463,11 +3457,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_UNARY_OP_TANH:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_UNARY_OP_EXPM1:
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_UNARY_OP_SOFTPLUS:
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
default:
return false;
}
@@ -3587,7 +3579,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
}
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_FLASH_ATTN_EXT:
{
const ggml_tensor * q = op->src[0];
@@ -6400,7 +6392,6 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
GGML_UNUSED(src1);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -6423,7 +6414,14 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
cl_kernel kernel = backend_ctx->kernel_mean_f32;
cl_kernel kernel;
const bool is_c4 = ne00 % 4 == 0;
if (is_c4) {
kernel = backend_ctx->kernel_mean_f32_4;
} else {
kernel = backend_ctx->kernel_mean_f32;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -6440,7 +6438,7 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
@@ -7388,18 +7386,8 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0_abs = extra0->offset + src0->view_offs;
cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
cl_kernel kernel;
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_expm1_f32_nd;
} else if (dst->type == GGML_TYPE_F16) {
kernel = backend_ctx->kernel_expm1_f16_nd;
} else {
GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1");
}
GGML_ASSERT(kernel != nullptr);
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
@@ -7411,70 +7399,74 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = dst->ne[0];
const int ne11 = dst->ne[1];
const int ne12 = dst->ne[2];
const int ne13 = dst->ne[3];
const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
const cl_ulong nb10 = dst->nb[0];
const cl_ulong nb11 = dst->nb[1];
const cl_ulong nb12 = dst->nb[2];
const cl_ulong nb13 = dst->nb[3];
cl_kernel kernel;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
size_t global_work_size[3];
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
return;
}
global_work_size[0] = (size_t)ne10;
global_work_size[1] = (size_t)ne11;
global_work_size[2] = (size_t)ne12;
size_t lws0 = 16, lws1 = 4, lws2 = 1;
if (ne10 < 16) lws0 = ne10;
if (ne11 < 4) lws1 = ne11;
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
size_t local_work_size[] = {lws0, lws1, lws2};
size_t* local_work_size_ptr = local_work_size;
if (!backend_ctx->non_uniform_workgroups) {
if (global_work_size[0] % local_work_size[0] != 0 ||
global_work_size[1] % local_work_size[1] != 0 ||
global_work_size[2] % local_work_size[2] != 0) {
local_work_size_ptr = NULL;
if (ggml_is_contiguous(src0)) {
// Handle contiguous input
int n = ggml_nelements(dst);
if (n % 4 == 0) {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_expm1_f32_4;
} else {
kernel = backend_ctx->kernel_expm1_f16_4;
}
n /= 4;
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_expm1_f32;
} else {
kernel = backend_ctx->kernel_expm1_f16;
}
}
}
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
} else {
// Handle non-contiguous input
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_expm1_f32_nc;
} else {
kernel = backend_ctx->kernel_expm1_f16_nc;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
int nth = 64;
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
}
static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7490,18 +7482,8 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0_abs = extra0->offset + src0->view_offs;
cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
cl_kernel kernel;
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_softplus_f32_nd;
} else if (dst->type == GGML_TYPE_F16) {
kernel = backend_ctx->kernel_softplus_f16_nd;
} else {
GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus");
}
GGML_ASSERT(kernel != nullptr);
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
@@ -7513,70 +7495,74 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = dst->ne[0];
const int ne11 = dst->ne[1];
const int ne12 = dst->ne[2];
const int ne13 = dst->ne[3];
const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
const cl_ulong nb10 = dst->nb[0];
const cl_ulong nb11 = dst->nb[1];
const cl_ulong nb12 = dst->nb[2];
const cl_ulong nb13 = dst->nb[3];
cl_kernel kernel;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
size_t global_work_size[3];
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
return;
}
global_work_size[0] = (size_t)ne10;
global_work_size[1] = (size_t)ne11;
global_work_size[2] = (size_t)ne12;
size_t lws0 = 16, lws1 = 4, lws2 = 1;
if (ne10 < 16) lws0 = ne10;
if (ne11 < 4) lws1 = ne11;
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
size_t local_work_size[] = {lws0, lws1, lws2};
size_t* local_work_size_ptr = local_work_size;
if (!backend_ctx->non_uniform_workgroups) {
if (global_work_size[0] % local_work_size[0] != 0 ||
global_work_size[1] % local_work_size[1] != 0 ||
global_work_size[2] % local_work_size[2] != 0) {
local_work_size_ptr = NULL;
if (ggml_is_contiguous(src0)) {
// Handle contiguous input
int n = ggml_nelements(dst);
if (n % 4 == 0) {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_softplus_f32_4;
} else {
kernel = backend_ctx->kernel_softplus_f16_4;
}
n /= 4;
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_softplus_f32;
} else {
kernel = backend_ctx->kernel_softplus_f16;
}
}
}
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
} else {
// Handle non-contiguous input
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_softplus_f32_nc;
} else {
kernel = backend_ctx->kernel_softplus_f16_nc;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
int nth = 64;
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
}
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
@@ -11088,7 +11074,6 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
GGML_UNUSED(src1);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -11111,7 +11096,14 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
cl_kernel kernel = backend_ctx->kernel_sum_rows_f32;
cl_kernel kernel;
const bool is_c4 = ne00 % 4 == 0;
if (is_c4) {
kernel = backend_ctx->kernel_sum_rows_f32_4;
} else {
kernel = backend_ctx->kernel_sum_rows_f32;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -11128,7 +11120,7 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+87 -56
View File
@@ -3,80 +3,111 @@
//------------------------------------------------------------------------------
// expm1
//------------------------------------------------------------------------------
kernel void kernel_expm1_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_expm1_f32(
global const float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
}
kernel void kernel_expm1_f32_4(
global const float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
}
kernel void kernel_expm1_f16(
global const half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
}
kernel void kernel_expm1_f16_4(
global const half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
}
kernel void kernel_expm1_f32_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
*y = exp(*x) - 1.0f;
}
}
kernel void kernel_expm1_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_expm1_f16_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
*y = exp(*x) - 1.0f;
}
}
+116 -15
View File
@@ -1,8 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// Most devices have max workgroup size of 1024, so this is enough for subgroup
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
#define MAX_SUBGROUPS 64
kernel void kernel_mean_f32(
global float * src0,
global char * src0,
ulong offset0,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
@@ -15,25 +20,121 @@ kernel void kernel_mean_f32(
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
dst = dst + offsetd;
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
dst_row[0] = row_sum / ne00;
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float sumf = 0.0f;
for (int i0 = lid; i0 < ne00; i0 += lsize) {
sumf += src_row[i0];
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf / ne00;
}
}
kernel void kernel_mean_f32_4(
global char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float4 sum_vec = (float4)0.0f;
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
sum_vec += src_row[i0];
}
float sumf = dot(sum_vec, (float4)(1.0f));
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf / ne00;
}
}
+88 -60
View File
@@ -3,86 +3,114 @@
//------------------------------------------------------------------------------
// softplus
//------------------------------------------------------------------------------
inline float softplus_f32(float x){
float ax = fabs(x);
float m = fmax(x, 0.0f);
return log1p(exp(-ax)) + m;
kernel void kernel_softplus_f32(
global const float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
}
kernel void kernel_softplus_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_softplus_f32_4(
global const float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
}
kernel void kernel_softplus_f16(
global const half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
const float x = convert_float(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f16_4(
global const half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
const float4 x = convert_float4(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f32_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = softplus_f32(*src_val_ptr);
}
*y = (*x > 20.0f) ? *x : log(1.0f + exp(*x));
}
}
kernel void kernel_softplus_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_softplus_f16_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr)));
}
const float x = convert_float(*hx);
*hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
}
+116 -15
View File
@@ -1,8 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// Most devices have max workgroup size of 1024, so this is enough for subgroup
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
#define MAX_SUBGROUPS 64
kernel void kernel_sum_rows_f32(
global float * src0,
global char * src0,
ulong offset0,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
@@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32(
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
dst = dst + offsetd;
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
dst_row[0] = row_sum;
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float sumf = 0.0f;
for (int i0 = lid; i0 < ne00; i0 += lsize) {
sumf += src_row[i0];
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf;
}
}
kernel void kernel_sum_rows_f32_4(
global char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float4 sum_vec = (float4)0.0f;
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
sum_vec += src_row[i0];
}
float sumf = dot(sum_vec, (float4)(1.0f));
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf;
}
}
+53 -27
View File
@@ -944,6 +944,7 @@ struct vk_mat_mat_push_constants {
uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t base_work_group_z; uint32_t num_batches;
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
@@ -963,6 +964,7 @@ struct vk_mat_vec_push_constants {
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t fusion_flags;
uint32_t base_work_group_y;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
@@ -6773,8 +6775,16 @@ static void ggml_vk_matmul(
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
base_work_group_z += groups_z;
}
return;
}
@@ -6788,9 +6798,17 @@ static void ggml_vk_matmul(
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
base_work_group_z += groups_z;
}
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
@@ -7186,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (qx_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
}
@@ -7484,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7579,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
uint32_t base_work_group_y = 0;
while (base_work_group_y < ne12 * ne13) {
uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags, base_work_group_y,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, groups_y, groups_z });
base_work_group_y += groups_y;
}
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
@@ -7832,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
src1->nb[2] <= src1->nb[1] &&
src1->nb[1] <= src1->nb[3] &&
src0->ne[3] == 1 &&
src1->ne[3] == 1) {
src1->ne[3] == 1 &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
!ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
// when ne12 and ne13 are one.
@@ -11560,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
}
}
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
@@ -12069,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
// y[i] = i % k;
}
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
@@ -32,6 +32,7 @@ layout (push_constant) uniform parameter
uint expert_i1;
uint nbi1;
#else
uint base_work_group_y;
uint ne02;
uint ne12;
uint broadcast2;
@@ -45,9 +46,9 @@ uint expert_id;
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
const uint expert_i0 = gl_GlobalInvocationID.y;
const uint expert_i0 = gl_WorkGroupID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;
#endif
#ifndef MUL_MAT_ID
@@ -90,6 +90,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
@@ -139,7 +141,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
@@ -149,7 +151,7 @@ void main() {
#endif
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
@@ -366,7 +368,7 @@ void main() {
const uint dc = ic * BN + warp_c * WN;
#ifndef MUL_MAT_ID
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
#ifdef COOPMAT
@@ -53,6 +53,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
@@ -197,7 +199,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
@@ -215,7 +217,7 @@ void main() {
#endif
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
@@ -255,7 +257,7 @@ void main() {
#else
uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
uint stride_a = p.stride_a / QUANT_K;
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,5 +1,4 @@
#decl(BYTE_HELPERS)
#ifdef BYTE_HELPERS
fn get_byte(value: u32, index: u32) -> u32 {
return (value >> (index * 8)) & 0xFF;
}
@@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 {
fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
#endif
#enddecl(BYTE_HELPERS)
#decl(Q4_0_T)
#ifdef Q4_0_T
struct q4_0 {
d: f16,
qs: array<f16, 8>
};
#enddecl(Q4_0_T)
#endif
#decl(Q4_1_T)
#ifdef Q4_1_T
struct q4_1 {
d: f16,
m: f16,
qs: array<u32, 4>
};
#enddecl(Q4_1_T)
#endif
#decl(Q5_0_T)
#ifdef Q5_0_T
struct q5_0 {
d: f16,
qh: array<f16, 2>,
qs: array<f16, 8>
};
#enddecl(Q5_0_T)
#endif
#decl(Q5_1_T)
#ifdef Q5_1_T
struct q5_1 {
d: f16,
m: f16,
qh: u32,
qs: array<u32, 4>
};
#enddecl(Q5_1_T)
#endif
#decl(Q8_0_T)
#ifdef Q8_0_T
struct q8_0 {
d: f16,
qs: array<f16, 16>
};
#enddecl(Q8_0_T)
#endif
#decl(Q8_1_T)
#ifdef Q8_1_T
struct q8_1 {
d: f16,
m: f16,
qs: array<u32, 8>
};
#enddecl(Q8_1_T)
#endif
#decl(Q2_K_T)
struct q2_k {
#ifdef Q2_K_T
struct q2_K {
scales: array<u32, 4>,
qs: array<u32, 16>,
d: f16,
dmin: f16
};
#enddecl(Q2_K_T)
#endif
#decl(Q3_K_T)
struct q3_k {
#ifdef Q3_K_T
struct q3_K {
hmask: array<f16, 16>,
qs: array<f16, 32>,
scales: array<f16, 6>,
d: f16
};
#enddecl(Q3_K_T)
#decl(Q45_K_SCALE_MIN)
#endif
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
if (is < 4) {
let sc_byte = get_byte(scales[is / 4], is % 4);
@@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
return vec2(f32(sc), f32(m));
}
}
#enddecl(Q45_K_SCALE_MIN)
#decl(Q4_K_T)
struct q4_k {
#endif
#ifdef Q4_K_T
struct q4_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qs: array<u32, 32>
};
#enddecl(Q4_K_T)
#endif
#decl(Q5_K_T)
struct q5_k {
#ifdef Q5_K_T
struct q5_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qh: array<u32, 8>,
qs: array<u32, 32>
};
#enddecl(Q5_K_T)
#endif
#decl(Q6_K_T)
struct q6_k {
#ifdef Q6_K_T
struct q6_K {
ql: array<f16, 64>,
qh: array<f16, 32>,
scales: array<f16, 8>,
d: f16
};
#enddecl(Q6_K_T)
#endif
#decl(IQ2_XXS_T)
#ifdef IQ2_XXS_T
struct iq2_xxs {
d: f16,
qs: array<f16, 32>
};
#enddecl(IQ2_XXS_T)
#endif
#decl(IQ2_XS_T)
#ifdef IQ2_XS_T
struct iq2_xs {
d: f16,
qs: array<f16, 32>,
scales: array<f16, 4>
};
#enddecl(IQ2_XS_T)
#endif
#decl(IQ2_S_T)
#ifdef IQ2_S_T
struct iq2_s {
d: f16,
qs: array<f16, 32>,
qh: array<f16, 4>,
scales: array<f16, 4>
};
#enddecl(IQ2_S_T)
#endif
#decl(IQ3_XSS_T)
#ifdef IQ3_XXS_T
struct iq3_xxs {
d: f16,
qs: array<f16, 48>
};
#enddecl(IQ3_XSS_T)
#endif
#decl(IQ3_S_T)
#ifdef IQ3_S_T
struct iq3_s {
d: f16,
qs: array<f16, 32>,
@@ -161,41 +156,41 @@ struct iq3_s {
signs: array<f16, 16>,
scales: array<f16, 2>
};
#enddecl(IQ3_S_T)
#endif
#decl(IQ1_S_T)
#ifdef IQ1_S_T
struct iq1_s {
d: f16,
qs: array<f16, 16>,
qh: array<f16, 8>
};
#enddecl(IQ1_S_T)
#endif
#decl(IQ1_M_T)
#ifdef IQ1_M_T
struct iq1_m {
qs: array<u32, 8>,
qh: array<u32, 4>,
scales: array<u32, 2>
};
#enddecl(IQ1_M_T)
#endif
#decl(IQ4_NL_T)
#ifdef IQ4_NL_T
struct iq4_nl {
d: f16,
qs: array<f16, 8>,
};
#enddecl(IQ4_NL_T)
#endif
#decl(IQ4_XS_T)
#ifdef IQ4_XS_T
struct iq4_xs {
d: f16,
scales_h: f16,
scales_l: u32,
qs: array<u32, 32>
};
#enddecl(IQ4_XS_T)
#endif
#decl(IQ23_TABLES)
#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
0x08040201u, // 1, 2, 4, 8
0x80402010u // 16, 32, 64, 128
@@ -211,9 +206,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
);
#enddecl(IQ23_TABLES)
#endif
#decl(IQ2_XXS_GRID)
#ifdef IQ2_XXS_GRID
const iq2xxs_grid = array<u32, 512>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
@@ -280,9 +275,9 @@ const iq2xxs_grid = array<u32, 512>(
0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
);
#enddecl(IQ2_XXS_GRID)
#endif
#decl(IQ2_XS_GRID)
#ifdef IQ2_XS_GRID
const iq2xs_grid = array<u32, 1024>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -413,9 +408,9 @@ const iq2xs_grid = array<u32, 1024>(
0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_XS_GRID)
#endif
#decl(IQ2_S_GRID)
#ifdef IQ2_S_GRID
const iq2s_grid = array<u32, 2048>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -674,10 +669,9 @@ const iq2s_grid = array<u32, 2048>(
0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_S_GRID)
#decl(IQ3_XSS_GRID)
#endif
#ifdef IQ3_XXS_GRID
const iq3xxs_grid = array<u32, 256>(
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -712,10 +706,9 @@ const iq3xxs_grid = array<u32, 256>(
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
);
#enddecl(IQ3_XSS_GRID)
#decl(IQ3_S_GRID)
#endif
#ifdef IQ3_S_GRID
const iq3s_grid = array<u32, 512>(
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
@@ -782,9 +775,9 @@ const iq3s_grid = array<u32, 512>(
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
);
#enddecl(IQ3_S_GRID)
#endif
#decl(IQ1_GRID)
#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)
const IQ1_DELTA: f32 = 0.125;
@@ -919,12 +912,12 @@ const iq1_grid = array<u32, 1024>(
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
);
#enddecl(IQ1_GRID)
#endif
#decl(IQ4_GRID)
#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)
const kvalues_iq4nl = array<i32, 16>(
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
);
#enddecl(IQ4_GRID)
#endif
@@ -56,12 +56,46 @@ def expand_includes(shader, input_dir):
return include_pattern.sub(replacer, shader)
def write_shader(shader_name, shader_code, output_dir, outfile):
def chunk_shader(shader_code, max_chunk_len=60000):
"""Split shader_code into safe raw-string sized chunks."""
return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)]
def raw_delim(shader_code):
"""Pick a raw-string delimiter that does not appear in the shader."""
delim = "wgsl"
while f"){delim}\"" in shader_code:
delim += "_x"
return delim
def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
shader_code = expand_includes(shader_code, input_dir)
if output_dir:
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
f_out.write(shader_code)
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
delim = raw_delim(shader_code)
chunks = chunk_shader(shader_code)
if len(chunks) == 1:
outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n')
else:
for idx, chunk in enumerate(chunks):
outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n')
outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n')
outfile.write(' static const std::string s = []{\n')
outfile.write(' std::string tmp;\n')
outfile.write(f' tmp.reserve({len(shader_code)});\n')
for idx in range(len(chunks)):
outfile.write(f' tmp.append(wgsl_{shader_name}_part{idx});\n')
outfile.write(' return tmp;\n')
outfile.write(' }();\n')
outfile.write(' return s;\n')
outfile.write('}\n')
outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
def generate_variants(fname, input_dir, output_dir, outfile):
@@ -74,7 +108,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
try:
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile)
write_shader(shader_base_name, text, output_dir, outfile, input_dir)
else:
try:
decls_map = parse_decls(extract_block(text, "DECLS"))
@@ -123,7 +157,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile)
write_shader(output_name, final_shader, output_dir, outfile, input_dir)
def main():
@@ -137,7 +171,8 @@ def main():
os.makedirs(args.output_dir, exist_ok=True)
with open(args.output_file, "w", encoding="utf-8") as out:
out.write("// Auto-generated shader embedding\n\n")
out.write("// Auto-generated shader embedding\n")
out.write("#include <string>\n\n")
for fname in sorted(os.listdir(args.input_dir)):
if fname.endswith(".wgsl"):
generate_variants(fname, args.input_dir, args.output_dir, out)
@@ -1,222 +1,31 @@
#define(VARIANTS)
enable f16;
#include "common_decls.tmpl"
[
{
"SHADER_SUFFIX": "f32_vec",
"REPLS": {
"TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"BLOCK_SIZE": 4
},
"DECLS": ["F32_VEC"]
},
{
"REPLS": {
"TYPE" : "f32",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F32"]
},
{
"REPLS": {
"TYPE" : "f16",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F16"]
},
{
"REPLS": {
"TYPE" : "i32",
"DST_TYPE": "i32",
"BLOCK_SIZE": 1
},
"DECLS": ["I32"]
},
{
"REPLS": {
"TYPE" : "q4_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"TYPE" : "q4_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"TYPE" : "q5_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"TYPE" : "q5_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"TYPE" : "q8_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"TYPE" : "q2_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"TYPE" : "q3_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"TYPE" : "q4_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"TYPE" : "q5_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"TYPE" : "q6_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"TYPE" : "iq2_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"TYPE" : "iq2_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"TYPE": "iq2_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"TYPE": "iq3_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"TYPE": "iq3_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"TYPE": "iq1_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"TYPE": "iq1_m",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"TYPE": "iq4_nl",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"TYPE": "iq4_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(F32_VEC)
#ifdef F32_VEC
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
}
#enddecl(F32_VEC)
#endif
#decl(F32)
#ifdef F32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(F32)
#endif
#decl(F16)
#ifdef F16
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = f32(src[src_base + offset]);
}
#enddecl(F16)
#endif
#decl(I32)
#ifdef I32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(I32)
#endif
#decl(Q4_0)
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_0 = src[src_base + offset];
let d = f32(block_q4_0.d);
@@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_0)
#endif
#decl(Q4_1)
#ifdef Q4_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_1 = src[src_base + offset];
let d = f32(block_q4_1.d);
@@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_1)
#endif
#decl(Q5_0)
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_0 = src[src_base + offset];
let d = f32(block_q5_0.d);
@@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(Q5_0)
#decl(Q5_1)
#ifdef Q5_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_1 = src[src_base + offset];
let d = f32(block_q5_1.d);
@@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q5_1)
#endif
#decl(Q8_0)
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q8_0 = src[src_base + offset];
let d = f32(block_q8_0.d);
@@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q8_0)
#endif
#decl(Q2_K)
#ifdef Q2_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q2_K)
#endif
#decl(Q3_K)
#ifdef Q3_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q3_K)
#endif
#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_K)
#endif
#decl(Q5_K)
#ifdef Q5_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q5_K)
#endif
#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
sc_b_idx += 8;
}
}
#endif
#enddecl(Q6_K)
#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ2_XXS)
#endif
#decl(IQ2_XS)
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ2_XS)
#endif
#decl(IQ2_S)
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ2_S)
#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ3_XSS)
#endif
#decl(IQ3_S)
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ3_S)
#endif
#decl(IQ1_S)
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ1_S)
#decl(IQ1_M)
#ifdef IQ1_M
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ1_M)
#decl(IQ4_NL)
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i++;
}
}
#enddecl(IQ4_NL)
#endif
#decl(IQ4_XS)
#ifdef IQ4_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i += 16;
}
}
#enddecl(IQ4_XS)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> idx: array<i32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;
struct Params {
offset_src: u32, // in elements
@@ -842,8 +638,7 @@ struct Params {
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
return;
@@ -866,9 +661,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
copy_elements(i_src_row, i_dst_row, i);
}
}
#end(SHADER)
@@ -1,195 +1,24 @@
#define(VARIANTS)
enable f16;
[
{
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q8_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q2_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q3_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q6_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_m",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_nl",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#include "common_decls.tmpl"
#end(VARIANTS)
#ifdef FLOAT
const BLOCK_SIZE = 1u;
#define(DECLS)
#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
const BLOCK_SIZE = 32u;
#decl(FLOAT)
#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
const BLOCK_SIZE = 256u;
#endif
#ifdef FLOAT
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
}
#enddecl(FLOAT)
#endif
#decl(Q4_0)
#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_0 = src0[src0_idx_base + offset];
let d = f32(block_q4_0.d);
@@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q4_0)
#endif
#decl(Q4_1)
#ifdef Q4_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_1 = src0[src0_idx_base + offset];
let d = f32(block_q4_1.d);
@@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q4_1)
#endif
#decl(Q5_0)
#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_0 = src0[src0_idx_base + offset];
let d = f32(block_q5_0.d);
@@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q5_0)
#endif
#decl(Q5_1)
#ifdef Q5_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_1 = src0[src0_idx_base + offset];
let d = f32(block_q5_1.d);
@@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q5_1)
#endif
#decl(Q8_0)
#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_0 = src0[src0_idx_base + offset];
let d = f32(block_q8_0.d);
@@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q8_0)
#endif
#decl(Q8_1)
#ifdef Q8_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_1 = src0[src0_idx_base + offset];
let d = f32(block_q8_1.d);
@@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q8_1)
#endif
#decl(Q2_K)
#ifdef Q2_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q2_K)
#decl(Q3_K)
#ifdef Q3_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q3_K)
#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q4_K)
#decl(Q5_K)
#ifdef Q5_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q5_K)
#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q6_K)
#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_XXS)
#decl(IQ2_XS)
#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_XS)
#decl(IQ2_S)
#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_S)
#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ3_XSS)
#decl(IQ3_S)
#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(IQ3_S)
#endif
#decl(IQ1_S)
#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ1_S)
#decl(IQ1_M)
#ifdef IQ1_M
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ1_M)
#decl(IQ4_NL)
#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ4_NL)
#decl(IQ4_XS)
#ifdef IQ4_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(IQ4_XS)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
struct MulMatParams {
offset_src0: u32, // in elements/blocks
@@ -864,8 +672,8 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
@group(0) @binding(3) var<uniform> params: MulMatParams;
@@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
var sum = 0.0;
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
sum += multiply_add(src0_idx_base, src1_idx_base, i);
}
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
}
#end(SHADER)
@@ -1,58 +1,65 @@
#decl(SHMEM_VEC)
#ifdef VEC
#define VEC_SIZE 4
#define SHMEM_TYPE vec4<f16>
#define DST_TYPE vec4<f32>
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
fn store_shmem(val: vec4<f16>, idx: u32) {
shmem[idx] = val.x;
shmem[idx + 1] = val.y;
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
#enddecl(SHMEM_VEC)
#endif
#ifdef SCALAR
#define VEC_SIZE 1
#define SHMEM_TYPE f16
#define DST_TYPE f32
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE
#decl(SHMEM_SCALAR)
fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val;
}
#enddecl(SHMEM_SCALAR)
#decl(INIT_SRC0_SHMEM_FLOAT)
#endif
#ifdef INIT_SRC0_SHMEM_FLOAT
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let src0_val = select( // taking a slight performance hit to avoid oob
{{SRC0_TYPE}}(0.0),
src0[src0_idx/{{VEC_SIZE}}],
SRC0_TYPE(0.0),
src0[src0_idx/VEC_SIZE],
global_m < params.m && global_k < params.k);
store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
}
}
#endif
#enddecl(INIT_SRC0_SHMEM_FLOAT)
#decl(INIT_SRC1_SHMEM)
#ifdef INIT_SRC1_SHMEM_FLOAT
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let tile_n = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_n = offset_n + tile_n;
let global_k = k_outer + tile_k;
let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
let src1_val = select(
{{SRC1_TYPE}}(0.0),
src1[src1_idx/{{VEC_SIZE}}],
SRC1_TYPE(0.0),
src1[src1_idx/VEC_SIZE],
global_n < params.n && global_k < params.k);
store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
}
}
#endif
#enddecl(INIT_SRC1_SHMEM)
#decl(INIT_SRC0_SHMEM_Q4_0)
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
@@ -93,5 +100,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
}
}
#enddecl(INIT_SRC0_SHMEM_Q4_0)
#endif
@@ -1,115 +1,19 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f32_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32_vec",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
}
]
enable f16;
#end(VARIANTS)
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#define(DECLS)
#decl(VEC)
#ifdef VEC
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
}
#enddecl(VEC)
#endif
#decl(SCALAR)
#ifdef SCALAR
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
return f32(acc[tm][tn]);
}
#enddecl(SCALAR)
#end(DECLS)
#define(SHADER)
enable f16;
#endif
struct MulMatParams {
offset_src0: u32,
@@ -130,14 +34,12 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams;
DECLS
fn get_local_n(thread_id: u32) -> u32 {
return thread_id / WORKGROUP_SIZE_M;
}
@@ -145,18 +47,9 @@ fn get_local_m(thread_id: u32) -> u32 {
return thread_id % WORKGROUP_SIZE_M;
}
// TILE_M must be multiple of 4 for vec4 loads
const TILE_M = {{WEBGPU_TILE_M}}u;
const TILE_N = {{WEBGPU_TILE_N}}u;
override WORKGROUP_SIZE_M: u32;
override WORKGROUP_SIZE_N: u32;
override TILE_K: u32;
override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
@@ -233,15 +126,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
for (var tn = 0u; tn < TILE_N; tn++) {
let global_col = output_col_base + tn;
if (global_col < params.n) {
for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
let global_row = output_row_base + tm;
if (global_row < params.m) {
let dst_idx = dst_batch_offset + global_col * params.m + global_row;
dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
}
}
}
}
}
#end(SHADER)
@@ -1,100 +1,12 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f32_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32_vec",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
}
]
diagnostic(off, chromium.subgroup_matrix_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#end(VARIANTS)
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#define(DECLS)
#decl(VEC)
#ifdef VEC
fn store_dst(shmem_idx: u32, dst_idx: u32) {
dst[dst_idx] = vec4<f32>(
f32(shmem[shmem_idx]),
@@ -103,21 +15,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) {
f32(shmem[shmem_idx + 3])
);
}
#enddecl(VEC)
#endif
#decl(SCALAR)
#ifdef SCALAR
fn store_dst(shmem_idx: u32, dst_idx: u32) {
dst[dst_idx] = f32(shmem[shmem_idx]);
}
#enddecl(SCALAR)
#end(DECLS)
#define(SHADER)
diagnostic(off, chromium.subgroup_matrix_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#endif
struct MulMatParams {
offset_src0: u32,
@@ -138,36 +42,19 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams;
DECLS
// Note: These are string interpolated at build time, cannot use override constants due to limitations in
// current Dawn version type definitions/matrix load requirements for constant memory sizes.
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
// runtime subgroup size is smaller.
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
const TILE_K = {{WEBGPU_TILE_K}}u;
const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
// runtime subgroup size is smaller.
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
@@ -285,7 +172,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let local_row = idx % WG_TILE_STRIDE;
let local_col = idx / WG_TILE_STRIDE;
@@ -294,9 +181,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_col < params.n && global_row < params.m) {
let dst_idx = dst_batch_offset + global_col * params.m + global_row;
store_dst(idx, dst_idx/{{VEC_SIZE}});
store_dst(idx, dst_idx/VEC_SIZE);
}
}
}
#end(SHADER)
@@ -1,84 +1,17 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f32_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE": "vec4<f32>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
}
]
#end(VARIANTS)
enable f16;
#define(DECLS)
#include "common_decls.tmpl"
#decl(VEC)
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
#ifdef VEC
#define VEC_SIZE 4
#define DST_TYPE vec4<f32>
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
return f32(dot(SRC1_TYPE(src0_val), src1_val));
}
fn store_val(group_base: u32) -> vec4<f32> {
@@ -87,33 +20,37 @@ fn store_val(group_base: u32) -> vec4<f32> {
partial_sums[group_base + THREADS_PER_OUTPUT * 2],
partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
}
#enddecl(VEC)
#endif
#decl(SCALAR)
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
#ifdef SCALAR
#define VEC_SIZE 1
#define DST_TYPE f32
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
return f32(src0_val) * f32(src1_val);
}
fn store_val(group_base: u32) -> f32 {
return partial_sums[group_base];
}
#enddecl(SCALAR)
#decl(MUL_ACC_FLOAT)
#endif
#ifdef MUL_ACC_FLOAT
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
let b = shared_vector[i / {{VEC_SIZE}}];
for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) {
let a = src0[(idx_base + k_outer + i) / VEC_SIZE];
let b = shared_vector[i / VEC_SIZE];
local_sum += inner_dot(a, b);
}
return local_sum;
}
#endif
#enddecl(MUL_ACC_FLOAT)
#decl(MUL_ACC_Q4_0)
#ifdef MUL_ACC_Q4_0
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
@@ -145,15 +82,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
}
return local_sum;
}
#enddecl(MUL_ACC_Q4_0)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
struct MulMatParams {
offset_src0: u32,
@@ -174,22 +103,20 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // Result vector (transposed)
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams;
override WORKGROUP_SIZE: u32;
override TILE_K: u32;
override OUTPUTS_PER_WG: u32;
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;
// Shared memory for collaborative loading and reduction
var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>; // Cache vector tile
var<workgroup> partial_sums: array<f32, WG_SIZE>; // For reduction
@compute @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>,
@@ -232,8 +159,8 @@ fn main(
let tile_size = min(TILE_K, params.k - k_tile);
// Cooperatively load vector tile into shared memory (all threads)
for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
}
workgroupBarrier();
@@ -250,7 +177,7 @@ fn main(
workgroupBarrier();
let group_base = thread_group * THREADS_PER_OUTPUT;
let thread_base = group_base + thread_in_group;
var offset = THREADS_PER_OUTPUT / 2;
var offset: u32 = THREADS_PER_OUTPUT / 2;
while (offset > 0) {
if (thread_in_group < offset) {
partial_sums[thread_base] += partial_sums[thread_base + offset];
@@ -260,8 +187,8 @@ fn main(
}
// Store back to global memory
if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) {
dst[dst_idx / VEC_SIZE] = store_val(group_base);
}
}
#end(SHADER)
@@ -1,21 +1,11 @@
#define(VARIANTS)
#ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;
[
{
"SHADER_NAME": "scale_f32",
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "scale_f32_inplace",
"DECLS": ["INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
fn store_scale(val: f32, offset: u32) {
src[offset] = val;
}
#else
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@@ -25,20 +15,7 @@ var<uniform> params: Params;
fn store_scale(val: f32, offset: u32) {
dst[offset] = val;
}
#enddecl(NOT_INPLACE)
#decl(INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
fn store_scale(val: f32, offset: u32) {
src[offset] = val;
}
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
#endif
struct Params {
offset_src: u32,
@@ -65,10 +42,7 @@ struct Params {
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
@@ -87,4 +61,3 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
store_scale(src[i_src] * params.scale + params.bias, i_dst);
}
#end(SHADER)
+7
View File
@@ -2660,6 +2660,13 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_POST_NORM,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.GLM4_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
+2
View File
@@ -1404,6 +1404,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL
"model.vision_tower.encoder.layer.{bid}.attention.q_norm", # Intern-S1
"visual.blocks.{bid}.attn.q_norm", # GLM-OCR
),
MODEL_TENSOR.V_ENC_ATTN_K: (
@@ -1422,6 +1423,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL
"model.vision_tower.encoder.layer.{bid}.attention.k_norm", # Intern-S1
"visual.blocks.{bid}.attn.k_norm", # GLM-OCR
),
MODEL_TENSOR.V_ENC_ATTN_V: (
+6
View File
@@ -1633,6 +1633,12 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_FFN_POST_NORM,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
case LLM_ARCH_GLM4_MOE:
return {
+45 -19
View File
@@ -1784,7 +1784,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
// NextN/MTP parameters (GLM-OCR)
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
switch (hparams.n_layer) {
case 17: type = LLM_TYPE_1B; break; // GLM-OCR
case 40: type = LLM_TYPE_9B; break;
case 61: type = LLM_TYPE_32B; break;
default: type = LLM_TYPE_UNKNOWN;
@@ -5410,30 +5418,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
if (layer.wqkv == nullptr) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
int flags = 0;
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
// skip all tensors in the NextN layers
flags |= TENSOR_SKIP;
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
auto & layer = layers[i];
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED);
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0);
if (layer.wqkv == nullptr) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, flags | TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED);
}
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags);
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
// Optional tensors
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED);
}
}
} break;
case LLM_ARCH_GLM4_MOE:
+5
View File
@@ -308,6 +308,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE:
case LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM:
regex_exprs = {
"\\p{N}{1,3}",
"[一-龥぀-ゟ゠-ヿ]+",
@@ -2051,6 +2052,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "hunyuan-dense") {
pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE;
clean_spaces = false;
} else if (
tokenizer_pre == "joyai-llm") {
pre_type = LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM;
clean_spaces = false;
} else if (
tokenizer_pre == "kimi-k2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
+1
View File
@@ -56,6 +56,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45,
LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46,
LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47,
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
};
struct LLM_KV;
+12 -5
View File
@@ -29,7 +29,10 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
// Only process up to last layer (skip final NextN layer)
// Final layer tensors are loaded but not processed in forward pass
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
// Pre-attention norm
@@ -100,7 +103,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
}
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@@ -130,9 +133,13 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "post_mlp_norm", il);
}
// Add residual connection after post-MLP norm
inpL = ggml_add(ctx0, cur, ffn_inp);
cb(inpL, "l_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
// Final norm
cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
+11 -3
View File
@@ -342,9 +342,17 @@ ggml_tensor * clip_graph::build_vit(
/* nb2 */ cur->nb[1],
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
// TODO: q/k norm requires row size == n_embd, while here it's d_head
// we can add support in the future if needed
GGML_ASSERT(layer.q_norm == nullptr && layer.k_norm == nullptr);
if (layer.q_norm) {
GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]);
Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
cb(Qcur, "Qcur_norm", il);
}
if (layer.k_norm) {
GGML_ASSERT(layer.k_norm->ne[0] == Kcur->ne[0]);
Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
cb(Kcur, "Kcur_norm", il);
}
} else {
// separate q, k, v
+16 -14
View File
@@ -2,7 +2,6 @@
ggml_cgraph * clip_graph_glm4v::build() {
GGML_ASSERT(model.patch_bias != nullptr);
GGML_ASSERT(model.position_embeddings != nullptr);
GGML_ASSERT(model.class_embedding == nullptr);
const int batch_size = 1;
@@ -45,19 +44,22 @@ ggml_cgraph * clip_graph_glm4v::build() {
// pos-conv norm
inp = build_norm(inp, model.norm_embd_w, model.norm_embd_b, norm_t, eps, -1);
// calculate absolute position embedding and apply
ggml_tensor * learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC);
learned_pos_embd = ggml_cont_4d(
ctx0, learned_pos_embd,
n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
learned_pos_embd = ggml_reshape_4d(
ctx0, learned_pos_embd,
n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
learned_pos_embd = ggml_cont_3d(
ctx0, learned_pos_embd,
n_embd, n_patches_x * n_patches_y, batch_size);
cb(learned_pos_embd, "learned_pos_embd", -1);
ggml_tensor * learned_pos_embd = nullptr;
// Note: GLM-OCR does not have learned position embeddings
if (model.position_embeddings != nullptr) {
learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC);
learned_pos_embd = ggml_cont_4d(
ctx0, learned_pos_embd,
n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
learned_pos_embd = ggml_reshape_4d(
ctx0, learned_pos_embd,
n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
learned_pos_embd = ggml_cont_3d(
ctx0, learned_pos_embd,
n_embd, n_patches_x * n_patches_y, batch_size);
cb(learned_pos_embd, "learned_pos_embd", -1);
}
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
return ggml_rope_multi(
+89 -65
View File
@@ -347,7 +347,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
int count = 0;
double nll = 0.0;
LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
const int n_seq = std::max(1, n_batch / n_ctx);
LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
for (int i = 0; i < n_chunk; ++i) {
const int start = i * params.ppl_stride;
@@ -1737,11 +1738,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
}
const int n_batch = params.n_batch;
const int num_batches = (n_ctx + n_batch - 1)/n_batch;
const int num_batches = (static_cast<int>(n_ctx) + n_batch - 1) / n_batch;
// Calculate n_seq based on the logits file's n_ctx, but cap it at what the context supports
const int n_seq_max = llama_n_seq_max(ctx);
int n_seq = std::max(1, n_batch / static_cast<int>(n_ctx));
if (n_seq > n_seq_max) {
LOG_WRN("%s: calculated n_seq=%d exceeds context's n_seq_max=%d, capping at %d\n",
__func__, n_seq, n_seq_max, n_seq_max);
n_seq = n_seq_max;
}
const int nv = 2*((n_vocab + 1)/2) + 4;
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
llama_batch batch = llama_batch_init(std::min(n_batch, static_cast<int>(n_ctx)*n_seq), 0, 1);
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
@@ -1750,6 +1761,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
logits.reserve(size_t(n_ctx) * n_vocab);
}
LOG_INF("%s: computing over %d chunks, n_ctx=%u, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
@@ -1774,107 +1787,122 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
auto kld_ptr = kld_values.data();
auto p_diff_ptr = p_diff_values.data();
for (int i = 0; i < n_chunk; ++i) {
const int first = n_ctx/2;
for (int i = 0; i < n_chunk; i += n_seq) {
const int start = i * n_ctx;
const int end = start + n_ctx;
const auto t_start = std::chrono::high_resolution_clock::now();
const int n_seq_batch = std::min(n_seq, n_chunk - i);
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
return;
}
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_memory_clear(llama_get_memory(ctx), true);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
// save original token and restore it after eval
const auto token_org = tokens[batch_start];
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_vocab_bos(vocab);
}
int n_outputs = 0;
common_batch_clear(batch);
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx;
// save original token and restore it after eval
const auto token_org = tokens[seq_start];
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[seq_start] = llama_vocab_bos(vocab);
}
for (int k = 0; k < batch_size; ++k) {
const int pos = j*n_batch + k;
const bool need_logits = pos >= first;
common_batch_add(batch, tokens[seq_start + k], pos, { seq }, need_logits);
n_outputs += need_logits;
}
// restore the original token in case it was set to BOS
tokens[seq_start] = token_org;
}
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
LOG_ERR("%s : failed to decode\n", __func__);
llama_batch_free(batch);
return;
}
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
if (num_batches > 1) {
if (num_batches > 1 && n_outputs > 0) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
}
}
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
llama_synchronize(ctx);
const auto t_end = std::chrono::high_resolution_clock::now();
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
int total_seconds = (int)(t_total * n_chunk / n_seq);
if (total_seconds >= 60*60) {
LOG("%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
LOG("%.2f minutes\n", total_seconds / 60.0);
LOG("\n");
LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
}
LOG("\n");
LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
const int first = n_ctx/2;
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
p_diff_ptr += n_ctx - 1 - first;
kld_ptr += n_ctx - 1 - first;
// Read log probs for each sequence in the batch
for (int seq = 0; seq < n_seq_batch; seq++) {
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i + seq);
llama_batch_free(batch);
return;
}
LOG("%4d", i+1);
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
const double ppl_val = exp(log_ppl.first);
const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
process_logits(n_vocab, all_logits, tokens.data() + start + seq*n_ctx + first, n_ctx - 1 - first,
workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
p_diff_ptr += n_ctx - 1 - first;
kld_ptr += n_ctx - 1 - first;
auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
LOG("%4d", i + seq + 1);
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
const double ppl_val = exp(log_ppl.first);
const double ppl_unc = ppl_val * log_ppl.second;
LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
const double p_diff_rms_val = sqrt(p_diff_mse.first);
const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
double p_top_val = 1.*kld.n_same_top/kld.count;
double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
LOG("\n");
auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
const double p_diff_rms_val = sqrt(p_diff_mse.first);
const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
double p_top_val = 1.*kld.n_same_top/kld.count;
double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
LOG("\n");
}
logits.clear();
}
llama_batch_free(batch);
LOG("\n");
if (kld.count < 100) return; // we do not wish to do statistics on so few values
@@ -1996,7 +2024,7 @@ int main(int argc, char ** argv) {
const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
if (ppl) {
if (ppl || params.kl_divergence) {
const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
const int32_t n_kv = n_seq * n_ctx;
@@ -2006,12 +2034,8 @@ int main(int argc, char ** argv) {
params.n_batch = std::min(params.n_batch, n_kv);
} else {
params.n_batch = std::min(params.n_batch, params.n_ctx);
if (params.kl_divergence) {
params.n_parallel = 1;
} else {
// ensure there's at least enough seq_ids for HellaSwag
params.n_parallel = std::max(4, params.n_parallel);
}
// ensure there's at least enough seq_ids for HellaSwag
params.n_parallel = std::max(4, params.n_parallel);
}
if (params.ppl_stride > 0) {
-4
View File
@@ -59,8 +59,4 @@ target_include_directories(${TARGET} PRIVATE ../mtmd)
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE server-context PUBLIC common cpp-httplib ${CMAKE_THREAD_LIBS_INIT})
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()
target_compile_features(${TARGET} PRIVATE cxx_std_17)
Binary file not shown.
+8 -1
View File
@@ -77,6 +77,7 @@ struct server_slot {
size_t last_nl_pos = 0;
std::string generated_text;
std::string debug_generated_text;
llama_tokens generated_tokens;
// idx of draft tokens in the main batch
@@ -425,7 +426,7 @@ struct server_slot {
if (!only_metrics) {
res["prompt"] = ptask->tokens.detokenize(ctx, true);
res["generated"] = generated_text;
res["generated"] = generated_text.empty() ? debug_generated_text : generated_text;
}
}
@@ -1442,6 +1443,12 @@ private:
res->id_slot = slot.id;
res->index = slot.task->index;
// keep copy of last generated text for debugging purposes
if (slots_debug) {
slot.debug_generated_text = slot.generated_text;
}
// in stream mode, content and tokens are already in last partial chunk
if (slot.task->params.stream) {
res->content = "";
+3 -1
View File
@@ -27,7 +27,9 @@ export default ts.config(
// typescript-eslint strongly recommend that you do not use the no-undef lint rule on TypeScript projects.
// see: https://typescript-eslint.io/troubleshooting/faqs/eslint/#i-get-errors-from-the-no-undef-rule-about-global-variables-not-being-defined-even-though-there-are-no-typescript-errors
'no-undef': 'off',
'svelte/no-at-html-tags': 'off'
'svelte/no-at-html-tags': 'off',
// This app uses hash-based routing (#/) where resolve() from $app/paths does not apply
'svelte/no-navigation-without-resolve': 'off'
}
},
{
+1053 -487
View File
File diff suppressed because it is too large Load Diff
+9 -8
View File
@@ -23,31 +23,32 @@
"cleanup": "rm -rf .svelte-kit build node_modules test-results"
},
"devDependencies": {
"@chromatic-com/storybook": "^4.1.2",
"@chromatic-com/storybook": "^5.0.0",
"@eslint/compat": "^1.2.5",
"@eslint/js": "^9.18.0",
"@internationalized/date": "^3.10.1",
"@lucide/svelte": "^0.515.0",
"@playwright/test": "^1.49.1",
"@storybook/addon-a11y": "^10.0.7",
"@storybook/addon-docs": "^10.0.7",
"@storybook/addon-a11y": "^10.2.4",
"@storybook/addon-docs": "^10.2.4",
"@storybook/addon-svelte-csf": "^5.0.10",
"@storybook/addon-vitest": "^10.0.7",
"@storybook/sveltekit": "^10.0.7",
"@storybook/addon-vitest": "^10.2.4",
"@storybook/sveltekit": "^10.2.4",
"@sveltejs/adapter-static": "^3.0.10",
"@sveltejs/kit": "^2.48.4",
"@sveltejs/vite-plugin-svelte": "^6.2.1",
"@tailwindcss/forms": "^0.5.9",
"@tailwindcss/typography": "^0.5.15",
"@tailwindcss/vite": "^4.0.0",
"@types/node": "^22",
"@types/node": "^24",
"@vitest/browser": "^3.2.3",
"@vitest/coverage-v8": "^3.2.3",
"bits-ui": "^2.14.4",
"clsx": "^2.1.1",
"dexie": "^4.0.11",
"eslint": "^9.18.0",
"eslint-config-prettier": "^10.0.1",
"eslint-plugin-storybook": "^10.0.7",
"eslint-plugin-storybook": "^10.2.4",
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
@@ -61,7 +62,7 @@
"rehype-katex": "^7.0.1",
"remark-math": "^6.0.0",
"sass": "^1.93.3",
"storybook": "^10.0.7",
"storybook": "^10.2.4",
"svelte": "^5.38.2",
"svelte-check": "^4.0.0",
"tailwind-merge": "^3.3.1",
@@ -8,7 +8,8 @@
isImageFile,
isPdfFile,
isAudioFile,
getLanguageFromFilename
getLanguageFromFilename,
createBase64DataUrl
} from '$lib/utils';
import { convertPDFToImage } from '$lib/utils/browser-only';
import { modelsStore } from '$lib/stores/models.svelte';
@@ -255,7 +256,7 @@
<audio
controls
class="mb-4 w-full"
src={`data:${attachment.mimeType};base64,${attachment.base64Data}`}
src={createBase64DataUrl(attachment.mimeType, attachment.base64Data)}
>
Your browser does not support the audio element.
</audio>
@@ -1,5 +1,5 @@
<script lang="ts">
import { RemoveButton } from '$lib/components/app';
import { ActionIconRemove } from '$lib/components/app';
import { formatFileSize, getFileTypeLabel, getPreviewText, isTextFile } from '$lib/utils';
import { AttachmentType } from '$lib/enums';
@@ -104,7 +104,7 @@
onclick={onClick}
>
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<RemoveButton {id} {onRemove} />
<ActionIconRemove {id} {onRemove} />
</div>
<div class="pr-8">
@@ -158,7 +158,7 @@
{#if !readonly}
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<RemoveButton {id} {onRemove} />
<ActionIconRemove {id} {onRemove} />
</div>
{/if}
</button>
@@ -1,5 +1,5 @@
<script lang="ts">
import { RemoveButton } from '$lib/components/app';
import { ActionIconRemove } from '$lib/components/app';
interface Props {
id: string;
@@ -58,7 +58,7 @@
<div
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
>
<RemoveButton {id} {onRemove} class="text-white" />
<ActionIconRemove {id} {onRemove} class="text-white" />
</div>
{/if}
</div>
@@ -1,8 +1,12 @@
<script lang="ts">
import { ChatAttachmentThumbnailImage, ChatAttachmentThumbnailFile } from '$lib/components/app';
import {
ChatAttachmentThumbnailImage,
ChatAttachmentThumbnailFile,
HorizontalScrollCarousel,
DialogChatAttachmentPreview,
DialogChatAttachmentsViewAll
} from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
import { DialogChatAttachmentPreview, DialogChatAttachmentsViewAll } from '$lib/components/app';
import { getAttachmentDisplayItems } from '$lib/utils';
interface Props {
@@ -41,12 +45,10 @@
let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments }));
let canScrollLeft = $state(false);
let canScrollRight = $state(false);
let carouselRef: HorizontalScrollCarousel | undefined = $state();
let isScrollable = $state(false);
let previewDialogOpen = $state(false);
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
let scrollContainer: HTMLDivElement | undefined = $state();
let showViewAll = $derived(limitToSingleRow && displayItems.length > 0 && isScrollable);
let viewAllDialogOpen = $state(false);
@@ -65,41 +67,9 @@
previewDialogOpen = true;
}
function scrollLeft(event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
if (!scrollContainer) return;
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * -0.67, behavior: 'smooth' });
}
function scrollRight(event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
if (!scrollContainer) return;
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * 0.67, behavior: 'smooth' });
}
function updateScrollButtons() {
if (!scrollContainer) return;
const { scrollLeft, scrollWidth, clientWidth } = scrollContainer;
canScrollLeft = scrollLeft > 0;
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1;
isScrollable = scrollWidth > clientWidth;
}
$effect(() => {
if (scrollContainer && displayItems.length) {
scrollContainer.scrollLeft = 0;
setTimeout(() => {
updateScrollButtons();
}, 0);
if (carouselRef && displayItems.length) {
carouselRef.resetScroll();
}
});
</script>
@@ -107,67 +77,40 @@
{#if displayItems.length > 0}
<div class={className} {style}>
{#if limitToSingleRow}
<div class="relative">
<button
class="absolute top-1/2 left-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollLeft
? 'opacity-100'
: 'pointer-events-none opacity-0'}"
onclick={scrollLeft}
aria-label="Scroll left"
>
<ChevronLeft class="h-4 w-4" />
</button>
<div
class="scrollbar-hide flex items-start gap-3 overflow-x-auto"
bind:this={scrollContainer}
onscroll={updateScrollButtons}
>
{#each displayItems as item (item.id)}
{#if item.isImage && item.preview}
<ChatAttachmentThumbnailImage
class="flex-shrink-0 cursor-pointer {limitToSingleRow
? 'first:ml-4 last:mr-4'
: ''}"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{:else}
<ChatAttachmentThumbnailFile
class="flex-shrink-0 cursor-pointer {limitToSingleRow
? 'first:ml-4 last:mr-4'
: ''}"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onClick={(event) => openPreview(item, event)}
/>
{/if}
{/each}
</div>
<button
class="absolute top-1/2 right-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollRight
? 'opacity-100'
: 'pointer-events-none opacity-0'}"
onclick={scrollRight}
aria-label="Scroll right"
>
<ChevronRight class="h-4 w-4" />
</button>
</div>
<HorizontalScrollCarousel
bind:this={carouselRef}
onScrollableChange={(scrollable) => (isScrollable = scrollable)}
>
{#each displayItems as item (item.id)}
{#if item.isImage && item.preview}
<ChatAttachmentThumbnailImage
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{:else}
<ChatAttachmentThumbnailFile
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id}
name={item.name}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
attachment={item.attachment}
uploadedFile={item.uploadedFile}
onClick={(event) => openPreview(item, event)}
/>
{/if}
{/each}
</HorizontalScrollCarousel>
{#if showViewAll}
<div class="mt-2 -mr-2 flex justify-end px-4">
@@ -1,20 +1,19 @@
<script lang="ts">
import { afterNavigate } from '$app/navigation';
import {
ChatAttachmentsList,
ChatFormActions,
ChatFormFileInputInvisible,
ChatFormHelperText,
ChatFormTextarea
} from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { INPUT_CLASSES } from '$lib/constants/css-classes';
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
import { CLIPBOARD_CONTENT_QUOTE_PREFIX } from '$lib/constants/chat-form';
import { KeyboardKey, MimeTypeText } from '$lib/enums';
import { config } from '$lib/stores/settings.svelte';
import { modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { activeMessages } from '$lib/stores/conversations.svelte';
import { MimeTypeText } from '$lib/enums';
import { isIMEComposing, parseClipboardContent } from '$lib/utils';
import {
AudioRecorder,
@@ -25,68 +24,82 @@
import { onMount } from 'svelte';
interface Props {
// Data
attachments?: DatabaseMessageExtra[];
uploadedFiles?: ChatUploadedFile[];
value?: string;
// UI State
class?: string;
disabled?: boolean;
initialMessage?: string;
isLoading?: boolean;
onFileRemove?: (fileId: string) => void;
onFileUpload?: (files: File[]) => void;
onSend?: (message: string, files?: ChatUploadedFile[]) => Promise<boolean>;
placeholder?: string;
// Event Handlers
onAttachmentRemove?: (index: number) => void;
onFilesAdd?: (files: File[]) => void;
onStop?: () => void;
onSystemPromptAdd?: (draft: { message: string; files: ChatUploadedFile[] }) => void;
showHelperText?: boolean;
uploadedFiles?: ChatUploadedFile[];
onSubmit?: () => void;
onSystemPromptClick?: (draft: { message: string; files: ChatUploadedFile[] }) => void;
onUploadedFileRemove?: (fileId: string) => void;
onValueChange?: (value: string) => void;
}
let {
class: className,
attachments = [],
class: className = '',
disabled = false,
initialMessage = '',
isLoading = false,
onFileRemove,
onFileUpload,
onSend,
placeholder = 'Type a message...',
uploadedFiles = $bindable([]),
value = $bindable(''),
onAttachmentRemove,
onFilesAdd,
onStop,
onSystemPromptAdd,
showHelperText = true,
uploadedFiles = $bindable([])
onSubmit,
onSystemPromptClick,
onUploadedFileRemove,
onValueChange
}: Props = $props();
/**
*
*
* STATE
*
*
*/
// Component References
let audioRecorder: AudioRecorder | undefined;
let chatFormActionsRef: ChatFormActions | undefined = $state(undefined);
let currentConfig = $derived(config());
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
// Audio Recording State
let isRecording = $state(false);
let message = $derived(initialMessage);
let recordingSupported = $state(false);
/**
*
*
* DERIVED STATE
*
*
*/
// Configuration
let currentConfig = $derived(config());
let pasteLongTextToFileLength = $derived.by(() => {
const n = Number(currentConfig.pasteLongTextToFileLen);
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
});
let previousIsLoading = $derived(isLoading);
let previousInitialMessage = $derived(initialMessage);
let recordingSupported = $state(false);
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
// Sync message when initialMessage prop changes (e.g., after draft restoration)
$effect(() => {
if (initialMessage !== previousInitialMessage) {
message = initialMessage;
previousInitialMessage = initialMessage;
}
});
function handleSystemPromptClick() {
onSystemPromptAdd?.({ message, files: uploadedFiles });
}
// Check if model is selected (in ROUTER mode)
// Model Selection Logic
let isRouter = $derived(isRouterMode());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let isRouter = $derived(isRouterMode());
let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId());
// Get active model ID for capability detection
let activeModelId = $derived.by(() => {
const options = modelOptions();
@@ -94,14 +107,12 @@
return options.length > 0 ? options[0].model : null;
}
// First try user-selected model
const selectedId = selectedModelId();
if (selectedId) {
const model = options.find((m) => m.id === selectedId);
if (model) return model.model;
}
// Fallback to conversation model
if (conversationModel) {
const model = options.find((m) => m.model === conversationModel);
if (model) return model.model;
@@ -110,46 +121,101 @@
return null;
});
function checkModelSelected(): boolean {
// Form Validation State
let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId());
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
let hasAttachments = $derived(
(attachments && attachments.length > 0) || (uploadedFiles && uploadedFiles.length > 0)
);
let canSubmit = $derived(value.trim().length > 0 || hasAttachments);
/**
*
*
* LIFECYCLE
*
*
*/
onMount(() => {
recordingSupported = isAudioRecordingSupported();
audioRecorder = new AudioRecorder();
});
/**
*
*
* PUBLIC API
*
*
*/
export function focus() {
textareaRef?.focus();
}
export function resetTextareaHeight() {
textareaRef?.resetHeight();
}
export function openModelSelector() {
chatFormActionsRef?.openModelSelector();
}
/**
* Check if a model is selected, open selector if not
* @returns true if model is selected, false otherwise
*/
export function checkModelSelected(): boolean {
if (!hasModelSelected) {
// Open the model selector
chatFormActionsRef?.openModelSelector();
return false;
}
return true;
}
/**
*
*
* EVENT HANDLERS - File Management
*
*
*/
function handleFileSelect(files: File[]) {
onFileUpload?.(files);
onFilesAdd?.(files);
}
function handleFileUpload() {
fileInputRef?.click();
}
async function handleKeydown(event: KeyboardEvent) {
if (event.key === 'Enter' && !event.shiftKey && !isIMEComposing(event)) {
function handleFileRemove(fileId: string) {
if (fileId.startsWith('attachment-')) {
const index = parseInt(fileId.replace('attachment-', ''), 10);
if (!isNaN(index) && index >= 0 && index < attachments.length) {
onAttachmentRemove?.(index);
}
} else {
onUploadedFileRemove?.(fileId);
}
}
/**
*
*
* EVENT HANDLERS - Input & Keyboard
*
*
*/
function handleKeydown(event: KeyboardEvent) {
if (event.key === KeyboardKey.ENTER && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return;
if (!canSubmit || disabled || isLoading || hasLoadingAttachments) return;
if (!checkModelSelected()) return;
const messageToSend = message.trim();
const filesToSend = [...uploadedFiles];
message = '';
uploadedFiles = [];
textareaRef?.resetHeight();
const success = await onSend?.(messageToSend, filesToSend);
if (!success) {
message = messageToSend;
uploadedFiles = filesToSend;
}
onSubmit?.();
}
}
@@ -163,29 +229,30 @@
if (files.length > 0) {
event.preventDefault();
onFileUpload?.(files);
onFilesAdd?.(files);
return;
}
const text = event.clipboardData.getData(MimeTypeText.PLAIN);
if (text.startsWith('"')) {
if (text.startsWith(CLIPBOARD_CONTENT_QUOTE_PREFIX)) {
const parsed = parseClipboardContent(text);
if (parsed.textAttachments.length > 0) {
event.preventDefault();
value = parsed.message;
onValueChange?.(parsed.message);
message = parsed.message;
const attachmentFiles = parsed.textAttachments.map(
(att) =>
new File([att.content], att.name, {
type: MimeTypeText.PLAIN
})
);
onFileUpload?.(attachmentFiles);
// Handle text attachments as files
if (parsed.textAttachments.length > 0) {
const attachmentFiles = parsed.textAttachments.map(
(att) =>
new File([att.content], att.name, {
type: MimeTypeText.PLAIN
})
);
onFilesAdd?.(attachmentFiles);
}
setTimeout(() => {
textareaRef?.focus();
@@ -206,14 +273,21 @@
type: MimeTypeText.PLAIN
});
onFileUpload?.([textFile]);
onFilesAdd?.([textFile]);
}
}
/**
*
*
* EVENT HANDLERS - Audio Recording
*
*
*/
async function handleMicClick() {
if (!audioRecorder || !recordingSupported) {
console.warn('Audio recording not supported');
return;
}
@@ -223,7 +297,7 @@
const wavBlob = await convertToWav(audioBlob);
const audioFile = createAudioFile(wavBlob);
onFileUpload?.([audioFile]);
onFilesAdd?.([audioFile]);
isRecording = false;
} catch (error) {
console.error('Failed to stop recording:', error);
@@ -238,98 +312,64 @@
}
}
}
function handleStop() {
onStop?.();
}
async function handleSubmit(event: SubmitEvent) {
event.preventDefault();
if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return;
// Check if model is selected first
if (!checkModelSelected()) return;
const messageToSend = message.trim();
const filesToSend = [...uploadedFiles];
message = '';
uploadedFiles = [];
textareaRef?.resetHeight();
const success = await onSend?.(messageToSend, filesToSend);
if (!success) {
message = messageToSend;
uploadedFiles = filesToSend;
}
}
onMount(() => {
setTimeout(() => textareaRef?.focus(), 10);
recordingSupported = isAudioRecordingSupported();
audioRecorder = new AudioRecorder();
});
afterNavigate(() => {
setTimeout(() => textareaRef?.focus(), 10);
});
$effect(() => {
if (previousIsLoading && !isLoading) {
setTimeout(() => textareaRef?.focus(), 10);
}
previousIsLoading = isLoading;
});
</script>
<ChatFormFileInputInvisible bind:this={fileInputRef} onFileSelect={handleFileSelect} />
<form
onsubmit={handleSubmit}
class="relative {INPUT_CLASSES} border-radius-bottom-none mx-auto max-w-[48rem] overflow-hidden rounded-3xl backdrop-blur-md {disabled
? 'cursor-not-allowed opacity-60'
: ''} {className}"
data-slot="chat-form"
class="relative {className}"
onsubmit={(e) => {
e.preventDefault();
if (!canSubmit || disabled || isLoading || hasLoadingAttachments) return;
onSubmit?.();
}}
>
<ChatAttachmentsList
bind:uploadedFiles
{onFileRemove}
limitToSingleRow
class="py-5"
style="scroll-padding: 1rem;"
activeModelId={activeModelId ?? undefined}
/>
<div
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
onpaste={handlePaste}
class="{INPUT_CLASSES} overflow-hidden rounded-3xl backdrop-blur-md {disabled
? 'cursor-not-allowed opacity-60'
: ''}"
data-slot="input-area"
>
<ChatFormTextarea
class="px-5 py-1.5 md:pt-0"
bind:this={textareaRef}
bind:value={message}
onKeydown={handleKeydown}
{disabled}
<ChatAttachmentsList
{attachments}
bind:uploadedFiles
onFileRemove={handleFileRemove}
limitToSingleRow
class="py-5"
style="scroll-padding: 1rem;"
activeModelId={activeModelId ?? undefined}
/>
<ChatFormActions
class="px-3"
bind:this={chatFormActionsRef}
canSend={message.trim().length > 0 || uploadedFiles.length > 0}
hasText={message.trim().length > 0}
{disabled}
{isLoading}
{isRecording}
{uploadedFiles}
onFileUpload={handleFileUpload}
onMicClick={handleMicClick}
onStop={handleStop}
onSystemPromptClick={handleSystemPromptClick}
/>
<div
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
onpaste={handlePaste}
>
<ChatFormTextarea
class="px-5 py-1.5 md:pt-0"
bind:this={textareaRef}
bind:value
onKeydown={handleKeydown}
onInput={() => {
onValueChange?.(value);
}}
{disabled}
{placeholder}
/>
<ChatFormActions
class="px-3"
bind:this={chatFormActionsRef}
canSend={canSubmit}
hasText={value.trim().length > 0}
{disabled}
{isLoading}
{isRecording}
{uploadedFiles}
onFileUpload={handleFileUpload}
onMicClick={handleMicClick}
{onStop}
onSystemPromptClick={() => onSystemPromptClick?.({ message: value, files: uploadedFiles })}
/>
</div>
</div>
</form>
<ChatFormHelperText show={showHelperText} />
@@ -1,6 +1,6 @@
<script lang="ts">
import { page } from '$app/state';
import { MessageSquare, Plus } from '@lucide/svelte';
import { Plus, MessageSquare } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
@@ -16,16 +16,6 @@
onSystemPromptClick?: () => void;
}
type AttachmentActionId = 'images' | 'audio' | 'text' | 'pdf' | 'system';
interface AttachmentAction {
id: AttachmentActionId;
label: string;
disabled?: boolean;
disabledReason?: string;
tooltip?: string;
}
let {
class: className = '',
disabled = false,
@@ -36,62 +26,20 @@
}: Props = $props();
let isNewChat = $derived(!page.params.id);
let systemMessageTooltip = $derived(
isNewChat
? 'Add custom system message for a new conversation'
: 'Inject custom system message at the beginning of the conversation'
);
let actions = $derived.by<AttachmentAction[]>(() => [
{
id: 'images',
label: 'Images',
disabled: !hasVisionModality,
disabledReason: !hasVisionModality
? 'Images require vision models to be processed'
: undefined
},
{
id: 'audio',
label: 'Audio Files',
disabled: !hasAudioModality,
disabledReason: !hasAudioModality
? 'Audio files require audio models to be processed'
: undefined
},
{
id: 'text',
label: 'Text Files'
},
{
id: 'pdf',
label: 'PDF Files',
tooltip: !hasVisionModality
? 'PDFs will be converted to text. Image-based PDFs may not work properly.'
: undefined
},
{
id: 'system',
label: 'System Message',
tooltip: systemMessageTooltip
}
]);
let dropdownOpen = $state(false);
function handleActionClick(id: AttachmentActionId) {
if (id === 'system') {
onSystemPromptClick?.();
return;
}
onFileUpload?.();
}
const triggerTooltipText = 'Add files or system message';
const itemClass = 'flex cursor-pointer items-center gap-2';
const fileUploadTooltipText = 'Add files, system prompt or MCP Servers';
</script>
<div class="flex items-center gap-1 {className}">
<DropdownMenu.Root>
<DropdownMenu.Root bind:open={dropdownOpen}>
<DropdownMenu.Trigger name="Attach files" {disabled}>
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
@@ -101,89 +49,125 @@
variant="secondary"
type="button"
>
<span class="sr-only">{triggerTooltipText}</span>
<span class="sr-only">{fileUploadTooltipText}</span>
<Plus class="h-4 w-4" />
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{triggerTooltipText}</p>
<p>{fileUploadTooltipText}</p>
</Tooltip.Content>
</Tooltip.Root>
</DropdownMenu.Trigger>
<DropdownMenu.Content align="start" class="w-56">
{#each actions as item (item.id)}
{@const hasDisabledTooltip = !!item.disabled && !!item.disabledReason}
{@const hasEnabledTooltip = !item.disabled && !!item.tooltip}
<DropdownMenu.Content align="start" class="w-48">
{#if hasVisionModality}
<DropdownMenu.Item
class="images-button flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.image class="h-4 w-4" />
{#if hasDisabledTooltip}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item class={itemClass} disabled>
{#if item.id === 'images'}
<FILE_TYPE_ICONS.image class="h-4 w-4" />
{:else if item.id === 'audio'}
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
{:else if item.id === 'text'}
<FILE_TYPE_ICONS.text class="h-4 w-4" />
{:else if item.id === 'pdf'}
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
{:else}
<MessageSquare class="h-4 w-4" />
{/if}
<span>{item.label}</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>{item.disabledReason}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else if hasEnabledTooltip}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item class={itemClass} onclick={() => handleActionClick(item.id)}>
{#if item.id === 'images'}
<FILE_TYPE_ICONS.image class="h-4 w-4" />
{:else if item.id === 'audio'}
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
{:else if item.id === 'text'}
<FILE_TYPE_ICONS.text class="h-4 w-4" />
{:else if item.id === 'pdf'}
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
{:else}
<MessageSquare class="h-4 w-4" />
{/if}
<span>{item.label}</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>{item.tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
<DropdownMenu.Item class={itemClass} onclick={() => handleActionClick(item.id)}>
{#if item.id === 'images'}
<span>Images</span>
</DropdownMenu.Item>
{:else}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="images-button flex cursor-pointer items-center gap-2"
disabled
>
<FILE_TYPE_ICONS.image class="h-4 w-4" />
{:else if item.id === 'audio'}
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
{:else if item.id === 'text'}
<FILE_TYPE_ICONS.text class="h-4 w-4" />
{:else if item.id === 'pdf'}
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
{:else}
<MessageSquare class="h-4 w-4" />
{/if}
<span>{item.label}</span>
<span>Images</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>Images require vision models to be processed</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
{#if hasAudioModality}
<DropdownMenu.Item
class="audio-button flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
<span>Audio Files</span>
</DropdownMenu.Item>
{:else}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item class="audio-button flex cursor-pointer items-center gap-2" disabled>
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
<span>Audio Files</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>Audio files require audio models to be processed</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.text class="h-4 w-4" />
<span>Text Files</span>
</DropdownMenu.Item>
{#if hasVisionModality}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
<span>PDF Files</span>
</DropdownMenu.Item>
{:else}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
<span>PDF Files</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>PDFs will be converted to text. Image-based PDFs may not work properly.</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onSystemPromptClick?.()}
>
<MessageSquare class="h-4 w-4" />
<span>System Message</span>
</DropdownMenu.Item>
{/if}
{/each}
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>{systemMessageTooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
</DropdownMenu.Content>
</DropdownMenu.Root>
</div>
@@ -1,143 +0,0 @@
<script lang="ts">
import { Paperclip } from '@lucide/svelte';
import { MessageSquare } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import { FILE_TYPE_ICONS } from '$lib/constants/icons';
interface Props {
class?: string;
disabled?: boolean;
hasAudioModality?: boolean;
hasVisionModality?: boolean;
onFileUpload?: () => void;
onSystemPromptClick?: () => void;
}
let {
class: className = '',
disabled = false,
hasAudioModality = false,
hasVisionModality = false,
onFileUpload,
onSystemPromptClick
}: Props = $props();
const fileUploadTooltipText = $derived.by(() => {
return !hasVisionModality
? 'Text files and PDFs supported. Images, audio, and video require vision models.'
: 'Attach files';
});
</script>
<div class="flex items-center gap-1 {className}">
<DropdownMenu.Root>
<DropdownMenu.Trigger name="Attach files" {disabled}>
<Tooltip.Root>
<Tooltip.Trigger>
<Button
class="file-upload-button h-8 w-8 rounded-full bg-transparent p-0 text-muted-foreground hover:bg-foreground/10 hover:text-foreground"
{disabled}
type="button"
>
<span class="sr-only">Attach files</span>
<Paperclip class="h-4 w-4" />
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{fileUploadTooltipText}</p>
</Tooltip.Content>
</Tooltip.Root>
</DropdownMenu.Trigger>
<DropdownMenu.Content align="start" class="w-48">
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="images-button flex cursor-pointer items-center gap-2"
disabled={!hasVisionModality}
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.image class="h-4 w-4" />
<span>Images</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
{#if !hasVisionModality}
<Tooltip.Content>
<p>Images require vision models to be processed</p>
</Tooltip.Content>
{/if}
</Tooltip.Root>
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="audio-button flex cursor-pointer items-center gap-2"
disabled={!hasAudioModality}
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
<span>Audio Files</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
{#if !hasAudioModality}
<Tooltip.Content>
<p>Audio files require audio models to be processed</p>
</Tooltip.Content>
{/if}
</Tooltip.Root>
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.text class="h-4 w-4" />
<span>Text Files</span>
</DropdownMenu.Item>
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
<span>PDF Files</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
{#if !hasVisionModality}
<Tooltip.Content>
<p>PDFs will be converted to text. Image-based PDFs may not work properly.</p>
</Tooltip.Content>
{/if}
</Tooltip.Root>
<DropdownMenu.Separator />
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onSystemPromptClick?.()}
>
<MessageSquare class="h-4 w-4" />
<span>System Prompt</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content>
<p>Add a custom system message for this conversation</p>
</Tooltip.Content>
</Tooltip.Root>
</DropdownMenu.Content>
</DropdownMenu.Root>
</div>
@@ -13,8 +13,7 @@
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { activeMessages, usedModalities } from '$lib/stores/conversations.svelte';
import { useModelChangeValidation } from '$lib/hooks/use-model-change-validation.svelte';
import { activeMessages } from '$lib/stores/conversations.svelte';
interface Props {
canSend?: boolean;
@@ -154,15 +153,6 @@
export function openModelSelector() {
selectorModelRef?.open();
}
const { handleModelChange } = useModelChangeValidation({
getRequiredModalities: () => usedModalities(),
onValidationFailure: async (previousModelId: string | null) => {
if (previousModelId) {
await modelsStore.selectModelById(previousModelId);
}
}
});
</script>
<div class="flex w-full items-center gap-3 {className}" style="container-type: inline-size">
@@ -183,7 +173,6 @@
currentModel={conversationModel}
forceForegroundText={true}
useGlobalSelection={true}
onModelChange={handleModelChange}
/>
</div>
@@ -5,6 +5,7 @@
interface Props {
class?: string;
disabled?: boolean;
onInput?: () => void;
onKeydown?: (event: KeyboardEvent) => void;
onPaste?: (event: ClipboardEvent) => void;
placeholder?: string;
@@ -14,6 +15,7 @@
let {
class: className = '',
disabled = false,
onInput,
onKeydown,
onPaste,
placeholder = 'Ask anything...',
@@ -52,7 +54,10 @@
class:cursor-not-allowed={disabled}
{disabled}
onkeydown={onKeydown}
oninput={(event) => autoResizeTextarea(event.currentTarget)}
oninput={(event) => {
autoResizeTextarea(event.currentTarget);
onInput?.();
}}
onpaste={onPaste}
{placeholder}
></textarea>
@@ -1,61 +1,35 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { base } from '$app/paths';
import {
chatStore,
pendingEditMessageId,
clearPendingEditMessageId,
removeSystemPromptPlaceholder
} from '$lib/stores/chat.svelte';
import { getChatActionsContext, setMessageEditContext } from '$lib/contexts';
import { chatStore, pendingEditMessageId } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { DatabaseService } from '$lib/services';
import { config } from '$lib/stores/settings.svelte';
import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants/ui';
import { copyToClipboard, isIMEComposing, formatMessageForClipboard } from '$lib/utils';
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
import ChatMessageUser from './ChatMessageUser.svelte';
import ChatMessageSystem from './ChatMessageSystem.svelte';
import { MessageRole } from '$lib/enums';
import {
ChatMessageAssistant,
ChatMessageUser,
ChatMessageSystem
} from '$lib/components/app/chat';
import { parseFilesToMessageExtras } from '$lib/utils/browser-only';
interface Props {
class?: string;
message: DatabaseMessage;
onCopy?: (message: DatabaseMessage) => void;
onContinueAssistantMessage?: (message: DatabaseMessage) => void;
onDelete?: (message: DatabaseMessage) => void;
onEditWithBranching?: (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => void;
onEditWithReplacement?: (
message: DatabaseMessage,
newContent: string,
shouldBranch: boolean
) => void;
onEditUserMessagePreserveResponses?: (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => void;
onNavigateToSibling?: (siblingId: string) => void;
onRegenerateWithBranching?: (message: DatabaseMessage, modelOverride?: string) => void;
isLastAssistantMessage?: boolean;
siblingInfo?: ChatMessageSiblingInfo | null;
}
let {
class: className = '',
message,
onCopy,
onContinueAssistantMessage,
onDelete,
onEditWithBranching,
onEditWithReplacement,
onEditUserMessagePreserveResponses,
onNavigateToSibling,
onRegenerateWithBranching,
isLastAssistantMessage = false,
siblingInfo = null
}: Props = $props();
const chatActions = getChatActionsContext();
let deletionInfo = $state<{
totalCount: number;
userMessages: number;
@@ -70,45 +44,51 @@
let shouldBranchAfterEdit = $state(false);
let textareaElement: HTMLTextAreaElement | undefined = $state();
let thinkingContent = $derived.by(() => {
if (message.role === 'assistant') {
const trimmedThinking = message.thinking?.trim();
let showSaveOnlyOption = $derived(message.role === MessageRole.USER);
return trimmedThinking ? trimmedThinking : null;
}
return null;
setMessageEditContext({
get isEditing() {
return isEditing;
},
get editedContent() {
return editedContent;
},
get editedExtras() {
return editedExtras;
},
get editedUploadedFiles() {
return editedUploadedFiles;
},
get originalContent() {
return message.content;
},
get originalExtras() {
return message.extra || [];
},
get showSaveOnlyOption() {
return showSaveOnlyOption;
},
setContent: (content: string) => {
editedContent = content;
},
setExtras: (extras: DatabaseMessageExtra[]) => {
editedExtras = extras;
},
setUploadedFiles: (files: ChatUploadedFile[]) => {
editedUploadedFiles = files;
},
save: handleSaveEdit,
saveOnly: handleSaveEditOnly,
cancel: handleCancelEdit,
startEdit: handleEdit
});
let toolCallContent = $derived.by((): ApiChatCompletionToolCall[] | string | null => {
if (message.role === 'assistant') {
const trimmedToolCalls = message.toolCalls?.trim();
if (!trimmedToolCalls) {
return null;
}
try {
const parsed = JSON.parse(trimmedToolCalls);
if (Array.isArray(parsed)) {
return parsed as ApiChatCompletionToolCall[];
}
} catch {
// Harmony-only path: fall back to the raw string so issues surface visibly.
}
return trimmedToolCalls;
}
return null;
});
// Auto-start edit mode if this message is the pending edit target
$effect(() => {
const pendingId = pendingEditMessageId();
if (pendingId && pendingId === message.id && !isEditing) {
handleEdit();
clearPendingEditMessageId();
chatStore.clearPendingEditMessageId();
}
});
@@ -116,8 +96,8 @@
isEditing = false;
// If canceling a new system message with placeholder content, remove it without deleting children
if (message.role === 'system') {
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
if (message.role === MessageRole.SYSTEM) {
const conversationDeleted = await chatStore.removeSystemPromptPlaceholder(message.id);
if (conversationDeleted) {
goto(`${base}/`);
@@ -131,30 +111,19 @@
editedUploadedFiles = [];
}
function handleEditedExtrasChange(extras: DatabaseMessageExtra[]) {
editedExtras = extras;
}
function handleEditedUploadedFilesChange(files: ChatUploadedFile[]) {
editedUploadedFiles = files;
}
async function handleCopy() {
const asPlainText = Boolean(config().copyTextAttachmentsAsPlainText);
const clipboardContent = formatMessageForClipboard(message.content, message.extra, asPlainText);
await copyToClipboard(clipboardContent, 'Message copied to clipboard');
onCopy?.(message);
function handleCopy() {
chatActions.copy(message);
}
async function handleConfirmDelete() {
if (message.role === 'system') {
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
if (message.role === MessageRole.SYSTEM) {
const conversationDeleted = await chatStore.removeSystemPromptPlaceholder(message.id);
if (conversationDeleted) {
goto('/');
goto(`${base}/`);
}
} else {
onDelete?.(message);
chatActions.delete(message);
}
showDeleteDialog = false;
@@ -167,9 +136,9 @@
function handleEdit() {
isEditing = true;
// Clear placeholder content for system messages
// Clear temporary placeholder content for system messages
editedContent =
message.role === 'system' && message.content === SYSTEM_MESSAGE_PLACEHOLDER
message.role === MessageRole.SYSTEM && message.content === SYSTEM_MESSAGE_PLACEHOLDER
? ''
: message.content;
textareaElement?.focus();
@@ -187,38 +156,26 @@
}, 0);
}
function handleEditedContentChange(content: string) {
editedContent = content;
}
function handleEditKeydown(event: KeyboardEvent) {
// Check for IME composition using isComposing property and keyCode 229 (specifically for IME composition on Safari)
// This prevents saving edit when confirming IME word selection (e.g., Japanese/Chinese input)
if (event.key === 'Enter' && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
handleSaveEdit();
} else if (event.key === 'Escape') {
event.preventDefault();
handleCancelEdit();
}
}
function handleRegenerate(modelOverride?: string) {
onRegenerateWithBranching?.(message, modelOverride);
chatActions.regenerateWithBranching(message, modelOverride);
}
function handleContinue() {
onContinueAssistantMessage?.(message);
chatActions.continueAssistantMessage(message);
}
function handleNavigateToSibling(siblingId: string) {
chatActions.navigateToSibling(siblingId);
}
async function handleSaveEdit() {
if (message.role === 'system') {
if (message.role === MessageRole.SYSTEM) {
// System messages: update in place without branching
const newContent = editedContent.trim();
// If content is empty or still the placeholder, remove without deleting children
// If content is empty, remove without deleting children
if (!newContent) {
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
const conversationDeleted = await chatStore.removeSystemPromptPlaceholder(message.id);
isEditing = false;
if (conversationDeleted) {
goto(`${base}/`);
@@ -231,13 +188,13 @@
if (index !== -1) {
conversationsStore.updateMessageAtIndex(index, { content: newContent });
}
} else if (message.role === 'user') {
} else if (message.role === MessageRole.USER) {
const finalExtras = await getMergedExtras();
onEditWithBranching?.(message, editedContent.trim(), finalExtras);
chatActions.editWithBranching(message, editedContent.trim(), finalExtras);
} else {
// For assistant messages, preserve exact content including trailing whitespace
// This is important for the Continue feature to work properly
onEditWithReplacement?.(message, editedContent, shouldBranchAfterEdit);
chatActions.editWithReplacement(message, editedContent, shouldBranchAfterEdit);
}
isEditing = false;
@@ -246,10 +203,10 @@
}
async function handleSaveEditOnly() {
if (message.role === 'user') {
if (message.role === MessageRole.USER) {
// For user messages, trim to avoid accidental whitespace
const finalExtras = await getMergedExtras();
onEditUserMessagePreserveResponses?.(message, editedContent.trim(), finalExtras);
chatActions.editUserMessagePreserveResponses(message, editedContent.trim(), finalExtras);
}
isEditing = false;
@@ -261,8 +218,8 @@
return editedExtras;
}
const { parseFilesToMessageExtras } = await import('$lib/utils/browser-only');
const result = await parseFilesToMessageExtras(editedUploadedFiles);
const plainFiles = $state.snapshot(editedUploadedFiles);
const result = await parseFilesToMessageExtras(plainFiles);
const newExtras = result?.extras || [];
return [...editedExtras, ...newExtras];
@@ -273,49 +230,31 @@
}
</script>
{#if message.role === 'system'}
{#if message.role === MessageRole.SYSTEM}
<ChatMessageSystem
bind:textareaElement
class={className}
{deletionInfo}
{editedContent}
{isEditing}
{message}
onCancelEdit={handleCancelEdit}
onConfirmDelete={handleConfirmDelete}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onEditKeydown={handleEditKeydown}
onEditedContentChange={handleEditedContentChange}
{onNavigateToSibling}
onSaveEdit={handleSaveEdit}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{:else if message.role === 'user'}
{:else if message.role === MessageRole.USER}
<ChatMessageUser
bind:textareaElement
class={className}
{deletionInfo}
{editedContent}
{editedExtras}
{editedUploadedFiles}
{isEditing}
{message}
onCancelEdit={handleCancelEdit}
onConfirmDelete={handleConfirmDelete}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onEditKeydown={handleEditKeydown}
onEditedContentChange={handleEditedContentChange}
onEditedExtrasChange={handleEditedExtrasChange}
onEditedUploadedFilesChange={handleEditedUploadedFilesChange}
{onNavigateToSibling}
onSaveEdit={handleSaveEdit}
onSaveEditOnly={handleSaveEditOnly}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
@@ -325,27 +264,18 @@
bind:textareaElement
class={className}
{deletionInfo}
{editedContent}
{isEditing}
{isLastAssistantMessage}
{message}
messageContent={message.content}
onCancelEdit={handleCancelEdit}
onConfirmDelete={handleConfirmDelete}
onContinue={handleContinue}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onEditKeydown={handleEditKeydown}
onEditedContentChange={handleEditedContentChange}
{onNavigateToSibling}
onNavigateToSibling={handleNavigateToSibling}
onRegenerate={handleRegenerate}
onSaveEdit={handleSaveEdit}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{shouldBranchAfterEdit}
onShouldBranchAfterEditChange={(value) => (shouldBranchAfterEdit = value)}
{showDeleteDialog}
{siblingInfo}
{thinkingContent}
{toolCallContent}
/>
{/if}
@@ -1,14 +1,15 @@
<script lang="ts">
import { Edit, Copy, RefreshCw, Trash2, ArrowRight } from '@lucide/svelte';
import {
ActionButton,
ActionIcon,
ChatMessageBranchingControls,
DialogConfirmation
} from '$lib/components/app';
import { Switch } from '$lib/components/ui/switch';
import { MessageRole } from '$lib/enums';
interface Props {
role: 'user' | 'assistant';
role: MessageRole.USER | MessageRole.ASSISTANT;
justify: 'start' | 'end';
actionsPosition: 'left' | 'right';
siblingInfo?: ChatMessageSiblingInfo | null;
@@ -71,21 +72,21 @@
<div
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-100 transition-all duration-150"
>
<ActionButton icon={Copy} tooltip="Copy" onclick={onCopy} />
<ActionIcon icon={Copy} tooltip="Copy" onclick={onCopy} />
{#if onEdit}
<ActionButton icon={Edit} tooltip="Edit" onclick={onEdit} />
<ActionIcon icon={Edit} tooltip="Edit" onclick={onEdit} />
{/if}
{#if role === 'assistant' && onRegenerate}
<ActionButton icon={RefreshCw} tooltip="Regenerate" onclick={() => onRegenerate()} />
{#if role === MessageRole.ASSISTANT && onRegenerate}
<ActionIcon icon={RefreshCw} tooltip="Regenerate" onclick={() => onRegenerate()} />
{/if}
{#if role === 'assistant' && onContinue}
<ActionButton icon={ArrowRight} tooltip="Continue" onclick={onContinue} />
{#if role === MessageRole.ASSISTANT && onContinue}
<ActionIcon icon={ArrowRight} tooltip="Continue" onclick={onContinue} />
{/if}
<ActionButton icon={Trash2} tooltip="Delete" onclick={onDelete} />
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
</div>
</div>
@@ -1,26 +1,29 @@
<script lang="ts">
import {
ModelBadge,
ChatMessageActions,
ChatMessageStatistics,
ChatMessageThinkingBlock,
CopyToClipboardIcon,
MarkdownContent,
ModelBadge,
ModelsSelector
} from '$lib/components/app';
import ChatMessageThinkingBlock from './ChatMessageThinkingBlock.svelte';
import { getMessageEditContext } from '$lib/contexts';
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
import { useModelChangeValidation } from '$lib/hooks/use-model-change-validation.svelte';
import { isLoading } from '$lib/stores/chat.svelte';
import { autoResizeTextarea, copyToClipboard } from '$lib/utils';
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
import { autoResizeTextarea, copyToClipboard, isIMEComposing } from '$lib/utils';
import { tick } from 'svelte';
import { fade } from 'svelte/transition';
import { Check, X, Wrench } from '@lucide/svelte';
import { Check, X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Checkbox } from '$lib/components/ui/checkbox';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { INPUT_CLASSES } from '$lib/constants/css-classes';
import { MessageRole, KeyboardKey } from '$lib/enums';
import Label from '$lib/components/ui/label/label.svelte';
import { config } from '$lib/stores/settings.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import { modelsStore } from '$lib/stores/models.svelte';
import { ServerModelStatus } from '$lib/enums';
import { REASONING_TAGS } from '$lib/constants/agentic';
interface Props {
class?: string;
@@ -30,153 +33,198 @@
assistantMessages: number;
messageTypes: string[];
} | null;
editedContent?: string;
isEditing?: boolean;
isLastAssistantMessage?: boolean;
message: DatabaseMessage;
messageContent: string | undefined;
onCancelEdit?: () => void;
onCopy: () => void;
onConfirmDelete: () => void;
onContinue?: () => void;
onDelete: () => void;
onEdit?: () => void;
onEditKeydown?: (event: KeyboardEvent) => void;
onEditedContentChange?: (content: string) => void;
onNavigateToSibling?: (siblingId: string) => void;
onRegenerate: (modelOverride?: string) => void;
onSaveEdit?: () => void;
onShowDeleteDialogChange: (show: boolean) => void;
onShouldBranchAfterEditChange?: (value: boolean) => void;
showDeleteDialog: boolean;
shouldBranchAfterEdit?: boolean;
siblingInfo?: ChatMessageSiblingInfo | null;
textareaElement?: HTMLTextAreaElement;
thinkingContent: string | null;
toolCallContent: ApiChatCompletionToolCall[] | string | null;
}
interface ParsedReasoningContent {
content: string;
reasoningContent: string | null;
hasReasoningMarkers: boolean;
}
function parseReasoningContent(content: string | undefined): ParsedReasoningContent {
if (!content) {
return {
content: '',
reasoningContent: null,
hasReasoningMarkers: false
};
}
const plainParts: string[] = [];
const reasoningParts: string[] = [];
const { START, END } = REASONING_TAGS;
let cursor = 0;
let hasReasoningMarkers = false;
while (cursor < content.length) {
const startIndex = content.indexOf(START, cursor);
if (startIndex === -1) {
plainParts.push(content.slice(cursor));
break;
}
hasReasoningMarkers = true;
plainParts.push(content.slice(cursor, startIndex));
const reasoningStart = startIndex + START.length;
const endIndex = content.indexOf(END, reasoningStart);
if (endIndex === -1) {
reasoningParts.push(content.slice(reasoningStart));
cursor = content.length;
break;
}
reasoningParts.push(content.slice(reasoningStart, endIndex));
cursor = endIndex + END.length;
}
return {
content: plainParts.join(''),
reasoningContent: reasoningParts.length > 0 ? reasoningParts.join('\n\n') : null,
hasReasoningMarkers
};
}
let {
class: className = '',
deletionInfo,
editedContent = '',
isEditing = false,
isLastAssistantMessage = false,
message,
messageContent,
onCancelEdit,
onConfirmDelete,
onContinue,
onCopy,
onDelete,
onEdit,
onEditKeydown,
onEditedContentChange,
onNavigateToSibling,
onRegenerate,
onSaveEdit,
onShowDeleteDialogChange,
onShouldBranchAfterEditChange,
showDeleteDialog,
shouldBranchAfterEdit = false,
siblingInfo = null,
textareaElement = $bindable(),
thinkingContent,
toolCallContent = null
textareaElement = $bindable()
}: Props = $props();
const toolCalls = $derived(
Array.isArray(toolCallContent) ? (toolCallContent as ApiChatCompletionToolCall[]) : null
);
const fallbackToolCalls = $derived(typeof toolCallContent === 'string' ? toolCallContent : null);
// Get edit context
const editCtx = getMessageEditContext();
// Local state for assistant-specific editing
let shouldBranchAfterEdit = $state(false);
function handleEditKeydown(event: KeyboardEvent) {
if (event.key === KeyboardKey.ENTER && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
editCtx.save();
} else if (event.key === KeyboardKey.ESCAPE) {
event.preventDefault();
editCtx.cancel();
}
}
const parsedMessageContent = $derived.by(() => parseReasoningContent(messageContent));
const visibleMessageContent = $derived(parsedMessageContent.content);
const thinkingContent = $derived(parsedMessageContent.reasoningContent);
const hasReasoningMarkers = $derived(parsedMessageContent.hasReasoningMarkers);
const processingState = useProcessingState();
// Local state for raw output toggle (per message)
let showRawOutput = $state(false);
let currentConfig = $derived(config());
let isRouter = $derived(isRouterMode());
let displayedModel = $derived((): string | null => {
if (message.model) {
return message.model;
let showRawOutput = $state(false);
let statsContainerEl: HTMLDivElement | undefined = $state();
function getScrollParent(el: HTMLElement): HTMLElement | null {
let parent = el.parentElement;
while (parent) {
const style = getComputedStyle(parent);
if (/(auto|scroll)/.test(style.overflowY)) {
return parent;
}
parent = parent.parentElement;
}
return null;
}
async function handleStatsViewChange() {
const el = statsContainerEl;
if (!el) {
return;
}
return null;
});
const scrollParent = getScrollParent(el);
if (!scrollParent) {
return;
}
const { handleModelChange } = useModelChangeValidation({
getRequiredModalities: () => conversationsStore.getModalitiesUpToMessage(message.id),
onSuccess: (modelName: string) => onRegenerate(modelName)
});
const yBefore = el.getBoundingClientRect().top;
await tick();
const delta = el.getBoundingClientRect().top - yBefore;
if (delta !== 0) {
scrollParent.scrollTop += delta;
}
// Correct any drift after browser paint
requestAnimationFrame(() => {
const drift = el.getBoundingClientRect().top - yBefore;
if (Math.abs(drift) > 1) {
scrollParent.scrollTop += drift;
}
});
}
let displayedModel = $derived(message.model ?? null);
let isCurrentlyLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
let hasNoContent = $derived(!visibleMessageContent?.trim());
let isActivelyProcessing = $derived(isCurrentlyLoading || isStreaming);
let showProcessingInfoTop = $derived(
message?.role === MessageRole.ASSISTANT &&
isActivelyProcessing &&
hasNoContent &&
isLastAssistantMessage
);
let showProcessingInfoBottom = $derived(
message?.role === MessageRole.ASSISTANT &&
isActivelyProcessing &&
!hasNoContent &&
isLastAssistantMessage
);
function handleCopyModel() {
const model = displayedModel();
void copyToClipboard(model ?? '');
void copyToClipboard(displayedModel ?? '');
}
$effect(() => {
if (isEditing && textareaElement) {
if (editCtx.isEditing && textareaElement) {
autoResizeTextarea(textareaElement);
}
});
$effect(() => {
if (isLoading() && !message?.content?.trim()) {
if (showProcessingInfoTop || showProcessingInfoBottom) {
processingState.startMonitoring();
}
});
function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
const callNumber = index + 1;
const functionName = toolCall.function?.name?.trim();
const label = functionName || `Call #${callNumber}`;
const payload: Record<string, unknown> = {};
const id = toolCall.id?.trim();
if (id) {
payload.id = id;
}
const type = toolCall.type?.trim();
if (type) {
payload.type = type;
}
if (toolCall.function) {
const fnPayload: Record<string, unknown> = {};
const name = toolCall.function.name?.trim();
if (name) {
fnPayload.name = name;
}
const rawArguments = toolCall.function.arguments?.trim();
if (rawArguments) {
try {
fnPayload.arguments = JSON.parse(rawArguments);
} catch {
fnPayload.arguments = rawArguments;
}
}
if (Object.keys(fnPayload).length > 0) {
payload.function = fnPayload;
}
}
const formattedPayload = JSON.stringify(payload, null, 2);
return {
label,
tooltip: formattedPayload,
copyValue: formattedPayload
};
}
function handleCopyToolCall(payload: string) {
void copyToClipboard(payload, 'Tool call copied to clipboard');
}
</script>
<div
@@ -184,34 +232,36 @@
role="group"
aria-label="Assistant message with actions"
>
{#if thinkingContent}
{#if !editCtx.isEditing && thinkingContent}
<ChatMessageThinkingBlock
reasoningContent={thinkingContent}
isStreaming={!message.timestamp}
hasRegularContent={!!messageContent?.trim()}
hasRegularContent={!!visibleMessageContent?.trim()}
/>
{/if}
{#if message?.role === 'assistant' && isLoading() && !message?.content?.trim()}
{#if showProcessingInfoTop}
<div class="mt-6 w-full max-w-[48rem]" in:fade>
<div class="processing-container">
<span class="processing-text">
{processingState.getPromptProgressText() ?? processingState.getProcessingMessage()}
{processingState.getPromptProgressText() ??
processingState.getProcessingMessage() ??
'Processing...'}
</span>
</div>
</div>
{/if}
{#if isEditing}
{#if editCtx.isEditing}
<div class="w-full">
<textarea
bind:this={textareaElement}
bind:value={editedContent}
value={editCtx.editedContent}
class="min-h-[50vh] w-full resize-y rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
onkeydown={onEditKeydown}
onkeydown={handleEditKeydown}
oninput={(e) => {
autoResizeTextarea(e.currentTarget);
onEditedContentChange?.(e.currentTarget.value);
editCtx.setContent(e.currentTarget.value);
}}
placeholder="Edit assistant message..."
></textarea>
@@ -221,30 +271,35 @@
<Checkbox
id="branch-after-edit"
bind:checked={shouldBranchAfterEdit}
onCheckedChange={(checked) => onShouldBranchAfterEditChange?.(checked === true)}
onCheckedChange={(checked) => (shouldBranchAfterEdit = checked === true)}
/>
<Label for="branch-after-edit" class="cursor-pointer text-sm text-muted-foreground">
Branch conversation after edit
</Label>
</div>
<div class="flex gap-2">
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline">
<Button class="h-8 px-3" onclick={editCtx.cancel} size="sm" variant="outline">
<X class="mr-1 h-3 w-3" />
Cancel
</Button>
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent?.trim()} size="sm">
<Button
class="h-8 px-3"
onclick={editCtx.save}
disabled={!editCtx.editedContent?.trim()}
size="sm"
>
<Check class="mr-1 h-3 w-3" />
Save
</Button>
</div>
</div>
</div>
{:else if message.role === 'assistant'}
{:else if message.role === MessageRole.ASSISTANT}
{#if showRawOutput}
<pre class="raw-output">{messageContent || ''}</pre>
{:else}
<MarkdownContent content={messageContent || ''} />
<MarkdownContent content={visibleMessageContent || ''} attachments={message.extra} />
{/if}
{:else}
<div class="text-sm whitespace-pre-wrap">
@@ -252,18 +307,41 @@
</div>
{/if}
{#if showProcessingInfoBottom}
<div class="mt-4 w-full max-w-[48rem]" in:fade>
<div class="processing-container">
<span class="processing-text">
{processingState.getPromptProgressText() ??
processingState.getProcessingMessage() ??
'Processing...'}
</span>
</div>
</div>
{/if}
<div class="info my-6 grid gap-4 tabular-nums">
{#if displayedModel()}
<div class="inline-flex flex-wrap items-start gap-2 text-xs text-muted-foreground">
{#if displayedModel}
<div
bind:this={statsContainerEl}
class="inline-flex flex-wrap items-start gap-2 text-xs text-muted-foreground"
>
{#if isRouter}
<ModelsSelector
currentModel={displayedModel()}
onModelChange={handleModelChange}
currentModel={displayedModel}
disabled={isLoading()}
upToMessageId={message.id}
onModelChange={async (modelId, modelName) => {
const status = modelsStore.getModelStatus(modelId);
if (status !== ServerModelStatus.LOADED) {
await modelsStore.loadModel(modelId);
}
onRegenerate(modelName);
return true;
}}
/>
{:else}
<ModelBadge model={displayedModel() || undefined} onclick={handleCopyModel} />
<ModelBadge model={displayedModel || undefined} onclick={handleCopyModel} />
{/if}
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
@@ -272,6 +350,7 @@
promptMs={message.timings.prompt_ms}
predictedTokens={message.timings.predicted_n}
predictedMs={message.timings.predicted_ms}
onActiveViewChange={handleStatsViewChange}
/>
{:else if isLoading() && currentConfig.showMessageStats}
{@const liveStats = processingState.getLiveProcessingStats()}
@@ -293,53 +372,11 @@
{/if}
</div>
{/if}
{#if config().showToolCalls}
{#if (toolCalls && toolCalls.length > 0) || fallbackToolCalls}
<span class="inline-flex flex-wrap items-center gap-2 text-xs text-muted-foreground">
<span class="inline-flex items-center gap-1">
<Wrench class="h-3.5 w-3.5" />
<span>Tool calls:</span>
</span>
{#if toolCalls && toolCalls.length > 0}
{#each toolCalls as toolCall, index (toolCall.id ?? `${index}`)}
{@const badge = formatToolCallBadge(toolCall, index)}
<button
type="button"
class="tool-call-badge inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
title={badge.tooltip}
aria-label={`Copy tool call ${badge.label}`}
onclick={() => handleCopyToolCall(badge.copyValue)}
>
{badge.label}
<CopyToClipboardIcon
text={badge.copyValue}
ariaLabel={`Copy tool call ${badge.label}`}
/>
</button>
{/each}
{:else if fallbackToolCalls}
<button
type="button"
class="tool-call-badge tool-call-badge--fallback inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
title={fallbackToolCalls}
aria-label="Copy tool call payload"
onclick={() => handleCopyToolCall(fallbackToolCalls)}
>
{fallbackToolCalls}
<CopyToClipboardIcon text={fallbackToolCalls} ariaLabel="Copy tool call payload" />
</button>
{/if}
</span>
{/if}
{/if}
</div>
{#if message.timestamp && !isEditing}
{#if message.timestamp && !editCtx.isEditing}
<ChatMessageActions
role="assistant"
role={MessageRole.ASSISTANT}
justify="start"
actionsPosition="left"
{siblingInfo}
@@ -348,7 +385,7 @@
{onCopy}
{onEdit}
{onRegenerate}
onContinue={currentConfig.enableContinueGeneration && !thinkingContent
onContinue={currentConfig.enableContinueGeneration && !hasReasoningMarkers
? onContinue
: undefined}
{onDelete}
@@ -408,17 +445,4 @@
white-space: pre-wrap;
word-break: break-word;
}
.tool-call-badge {
max-width: 12rem;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.tool-call-badge--fallback {
max-width: 20rem;
white-space: normal;
word-break: break-word;
}
</style>
@@ -1,79 +1,26 @@
<script lang="ts">
import { X, ArrowUp, Paperclip, AlertTriangle } from '@lucide/svelte';
import { X, AlertTriangle } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Switch } from '$lib/components/ui/switch';
import { ChatAttachmentsList, DialogConfirmation, ModelsSelector } from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
import { AttachmentType, FileTypeCategory, MimeTypeText } from '$lib/enums';
import { config } from '$lib/stores/settings.svelte';
import { useModelChangeValidation } from '$lib/hooks/use-model-change-validation.svelte';
import { setEditModeActive, clearEditMode } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { modelsStore } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import {
autoResizeTextarea,
getFileTypeCategory,
getFileTypeCategoryByExtension,
parseClipboardContent
} from '$lib/utils';
import { ChatForm, DialogConfirmation } from '$lib/components/app';
import { getMessageEditContext } from '$lib/contexts';
import { KeyboardKey } from '$lib/enums';
import { chatStore } from '$lib/stores/chat.svelte';
import { processFilesToChatUploaded } from '$lib/utils/browser-only';
interface Props {
messageId: string;
editedContent: string;
editedExtras?: DatabaseMessageExtra[];
editedUploadedFiles?: ChatUploadedFile[];
originalContent: string;
originalExtras?: DatabaseMessageExtra[];
showSaveOnlyOption?: boolean;
onCancelEdit: () => void;
onSaveEdit: () => void;
onSaveEditOnly?: () => void;
onEditKeydown: (event: KeyboardEvent) => void;
onEditedContentChange: (content: string) => void;
onEditedExtrasChange?: (extras: DatabaseMessageExtra[]) => void;
onEditedUploadedFilesChange?: (files: ChatUploadedFile[]) => void;
textareaElement?: HTMLTextAreaElement;
}
const editCtx = getMessageEditContext();
let {
messageId,
editedContent,
editedExtras = [],
editedUploadedFiles = [],
originalContent,
originalExtras = [],
showSaveOnlyOption = false,
onCancelEdit,
onSaveEdit,
onSaveEditOnly,
onEditKeydown,
onEditedContentChange,
onEditedExtrasChange,
onEditedUploadedFilesChange,
textareaElement = $bindable()
}: Props = $props();
let fileInputElement: HTMLInputElement | undefined = $state();
let inputAreaRef: ChatForm | undefined = $state(undefined);
let saveWithoutRegenerate = $state(false);
let showDiscardDialog = $state(false);
let isRouter = $derived(isRouterMode());
let currentConfig = $derived(config());
let pasteLongTextToFileLength = $derived.by(() => {
const n = Number(currentConfig.pasteLongTextToFileLen);
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
});
let hasUnsavedChanges = $derived.by(() => {
if (editedContent !== originalContent) return true;
if (editedUploadedFiles.length > 0) return true;
if (editCtx.editedContent !== editCtx.originalContent) return true;
if (editCtx.editedUploadedFiles.length > 0) return true;
const extrasChanged =
editedExtras.length !== originalExtras.length ||
editedExtras.some((extra, i) => extra !== originalExtras[i]);
editCtx.editedExtras.length !== editCtx.originalExtras.length ||
editCtx.editedExtras.some((extra, i) => extra !== editCtx.originalExtras[i]);
if (extrasChanged) return true;
@@ -81,77 +28,14 @@
});
let hasAttachments = $derived(
(editedExtras && editedExtras.length > 0) ||
(editedUploadedFiles && editedUploadedFiles.length > 0)
(editCtx.editedExtras && editCtx.editedExtras.length > 0) ||
(editCtx.editedUploadedFiles && editCtx.editedUploadedFiles.length > 0)
);
let canSubmit = $derived(editedContent.trim().length > 0 || hasAttachments);
function getEditedAttachmentsModalities(): ModelModalities {
const modalities: ModelModalities = { vision: false, audio: false };
for (const extra of editedExtras) {
if (extra.type === AttachmentType.IMAGE) {
modalities.vision = true;
}
if (
extra.type === AttachmentType.PDF &&
'processedAsImages' in extra &&
extra.processedAsImages
) {
modalities.vision = true;
}
if (extra.type === AttachmentType.AUDIO) {
modalities.audio = true;
}
}
for (const file of editedUploadedFiles) {
const category = getFileTypeCategory(file.type) || getFileTypeCategoryByExtension(file.name);
if (category === FileTypeCategory.IMAGE) {
modalities.vision = true;
}
if (category === FileTypeCategory.AUDIO) {
modalities.audio = true;
}
}
return modalities;
}
function getRequiredModalities(): ModelModalities {
const beforeModalities = conversationsStore.getModalitiesUpToMessage(messageId);
const editedModalities = getEditedAttachmentsModalities();
return {
vision: beforeModalities.vision || editedModalities.vision,
audio: beforeModalities.audio || editedModalities.audio
};
}
const { handleModelChange } = useModelChangeValidation({
getRequiredModalities,
onValidationFailure: async (previousModelId: string | null) => {
if (previousModelId) {
await modelsStore.selectModelById(previousModelId);
}
}
});
function handleFileInputChange(event: Event) {
const input = event.target as HTMLInputElement;
if (!input.files || input.files.length === 0) return;
const files = Array.from(input.files);
processNewFiles(files);
input.value = '';
}
let canSubmit = $derived(editCtx.editedContent.trim().length > 0 || hasAttachments);
function handleGlobalKeydown(event: KeyboardEvent) {
if (event.key === 'Escape') {
if (event.key === KeyboardKey.ESCAPE) {
event.preventDefault();
attemptCancel();
}
@@ -161,205 +45,66 @@
if (hasUnsavedChanges) {
showDiscardDialog = true;
} else {
onCancelEdit();
editCtx.cancel();
}
}
function handleRemoveExistingAttachment(index: number) {
if (!onEditedExtrasChange) return;
const newExtras = [...editedExtras];
newExtras.splice(index, 1);
onEditedExtrasChange(newExtras);
}
function handleRemoveUploadedFile(fileId: string) {
if (!onEditedUploadedFilesChange) return;
const newFiles = editedUploadedFiles.filter((f) => f.id !== fileId);
onEditedUploadedFilesChange(newFiles);
}
function handleSubmit() {
if (!canSubmit) return;
if (saveWithoutRegenerate && onSaveEditOnly) {
onSaveEditOnly();
if (saveWithoutRegenerate && editCtx.showSaveOnlyOption) {
editCtx.saveOnly();
} else {
onSaveEdit();
editCtx.save();
}
saveWithoutRegenerate = false;
}
async function processNewFiles(files: File[]) {
if (!onEditedUploadedFilesChange) return;
function handleAttachmentRemove(index: number) {
const newExtras = [...editCtx.editedExtras];
newExtras.splice(index, 1);
editCtx.setExtras(newExtras);
}
const { processFilesToChatUploaded } = await import('$lib/utils/browser-only');
function handleUploadedFileRemove(fileId: string) {
const newFiles = editCtx.editedUploadedFiles.filter((f) => f.id !== fileId);
editCtx.setUploadedFiles(newFiles);
}
async function handleFilesAdd(files: File[]) {
const processed = await processFilesToChatUploaded(files);
onEditedUploadedFilesChange([...editedUploadedFiles, ...processed]);
}
function handlePaste(event: ClipboardEvent) {
if (!event.clipboardData) return;
const files = Array.from(event.clipboardData.items)
.filter((item) => item.kind === 'file')
.map((item) => item.getAsFile())
.filter((file): file is File => file !== null);
if (files.length > 0) {
event.preventDefault();
processNewFiles(files);
return;
}
const text = event.clipboardData.getData(MimeTypeText.PLAIN);
if (text.startsWith('"')) {
const parsed = parseClipboardContent(text);
if (parsed.textAttachments.length > 0) {
event.preventDefault();
onEditedContentChange(parsed.message);
const attachmentFiles = parsed.textAttachments.map(
(att) =>
new File([att.content], att.name, {
type: MimeTypeText.PLAIN
})
);
processNewFiles(attachmentFiles);
setTimeout(() => {
textareaElement?.focus();
}, 10);
return;
}
}
if (
text.length > 0 &&
pasteLongTextToFileLength > 0 &&
text.length > pasteLongTextToFileLength
) {
event.preventDefault();
const textFile = new File([text], 'Pasted', {
type: MimeTypeText.PLAIN
});
processNewFiles([textFile]);
}
editCtx.setUploadedFiles([...editCtx.editedUploadedFiles, ...processed]);
}
$effect(() => {
if (textareaElement) {
autoResizeTextarea(textareaElement);
}
});
$effect(() => {
setEditModeActive(processNewFiles);
chatStore.setEditModeActive(handleFilesAdd);
return () => {
clearEditMode();
chatStore.clearEditMode();
};
});
</script>
<svelte:window onkeydown={handleGlobalKeydown} />
<input
bind:this={fileInputElement}
type="file"
multiple
class="hidden"
onchange={handleFileInputChange}
/>
<div
class="{INPUT_CLASSES} w-full max-w-[80%] overflow-hidden rounded-3xl backdrop-blur-md"
data-slot="edit-form"
>
<ChatAttachmentsList
attachments={editedExtras}
uploadedFiles={editedUploadedFiles}
readonly={false}
onFileRemove={(fileId) => {
if (fileId.startsWith('attachment-')) {
const index = parseInt(fileId.replace('attachment-', ''), 10);
if (!isNaN(index) && index >= 0 && index < editedExtras.length) {
handleRemoveExistingAttachment(index);
}
} else {
handleRemoveUploadedFile(fileId);
}
}}
limitToSingleRow
class="py-5"
style="scroll-padding: 1rem;"
<div class="relative w-full max-w-[80%]">
<ChatForm
bind:this={inputAreaRef}
value={editCtx.editedContent}
attachments={editCtx.editedExtras}
uploadedFiles={editCtx.editedUploadedFiles}
placeholder="Edit your message..."
onValueChange={editCtx.setContent}
onAttachmentRemove={handleAttachmentRemove}
onUploadedFileRemove={handleUploadedFileRemove}
onFilesAdd={handleFilesAdd}
onSubmit={handleSubmit}
/>
<div class="relative min-h-[48px] px-5 py-3">
<textarea
bind:this={textareaElement}
bind:value={editedContent}
class="field-sizing-content max-h-80 min-h-10 w-full resize-none bg-transparent text-sm outline-none"
onkeydown={onEditKeydown}
oninput={(e) => {
autoResizeTextarea(e.currentTarget);
onEditedContentChange(e.currentTarget.value);
}}
onpaste={handlePaste}
placeholder="Edit your message..."
></textarea>
<div class="flex w-full items-center gap-3" style="container-type: inline-size">
<Button
class="h-8 w-8 shrink-0 rounded-full bg-transparent p-0 text-muted-foreground hover:bg-foreground/10 hover:text-foreground"
onclick={() => fileInputElement?.click()}
type="button"
title="Add attachment"
>
<span class="sr-only">Attach files</span>
<Paperclip class="h-4 w-4" />
</Button>
<div class="flex-1"></div>
{#if isRouter}
<ModelsSelector
forceForegroundText={true}
useGlobalSelection={true}
onModelChange={handleModelChange}
/>
{/if}
<Button
class="h-8 w-8 shrink-0 rounded-full p-0"
onclick={handleSubmit}
disabled={!canSubmit}
type="button"
title={saveWithoutRegenerate ? 'Save changes' : 'Send and regenerate'}
>
<span class="sr-only">{saveWithoutRegenerate ? 'Save' : 'Send'}</span>
<ArrowUp class="h-5 w-5" />
</Button>
</div>
</div>
</div>
<div class="mt-2 flex w-full max-w-[80%] items-center justify-between">
{#if showSaveOnlyOption && onSaveEditOnly}
{#if editCtx.showSaveOnlyOption}
<div class="flex items-center gap-2">
<Switch id="save-only-switch" bind:checked={saveWithoutRegenerate} class="scale-75" />
@@ -386,6 +131,6 @@
cancelText="Keep editing"
variant="destructive"
icon={AlertTriangle}
onConfirm={onCancelEdit}
onConfirm={editCtx.cancel}
onCancel={() => (showDiscardDialog = false)}
/>
@@ -3,19 +3,18 @@
import { BadgeChatStatistic } from '$lib/components/app';
import * as Tooltip from '$lib/components/ui/tooltip';
import { ChatMessageStatsView } from '$lib/enums';
import { formatPerformanceTime } from '$lib/utils/formatters';
import { formatPerformanceTime } from '$lib/utils';
import { MS_PER_SECOND, DEFAULT_PERFORMANCE_TIME } from '$lib/constants/formatters';
interface Props {
predictedTokens?: number;
predictedMs?: number;
promptTokens?: number;
promptMs?: number;
// Live mode: when true, shows stats during streaming
isLive?: boolean;
// Whether prompt processing is still in progress
isProcessingPrompt?: boolean;
// Initial view to show (defaults to READING in live mode)
initialView?: ChatMessageStatsView;
onActiveViewChange?: (view: ChatMessageStatsView) => void;
}
let {
@@ -25,12 +24,17 @@
promptMs,
isLive = false,
isProcessingPrompt = false,
initialView = ChatMessageStatsView.GENERATION
initialView = ChatMessageStatsView.GENERATION,
onActiveViewChange
}: Props = $props();
let activeView: ChatMessageStatsView = $derived(initialView);
let hasAutoSwitchedToGeneration = $state(false);
$effect(() => {
onActiveViewChange?.(activeView);
});
// In live mode: auto-switch to GENERATION tab when prompt processing completes
$effect(() => {
if (isLive) {
@@ -57,14 +61,16 @@
predictedMs > 0
);
let tokensPerSecond = $derived(hasGenerationStats ? (predictedTokens! / predictedMs!) * 1000 : 0);
let tokensPerSecond = $derived(
hasGenerationStats ? (predictedTokens! / predictedMs!) * MS_PER_SECOND : 0
);
let formattedTime = $derived(
predictedMs !== undefined ? formatPerformanceTime(predictedMs) : '0s'
predictedMs !== undefined ? formatPerformanceTime(predictedMs) : DEFAULT_PERFORMANCE_TIME
);
let promptTokensPerSecond = $derived(
promptTokens !== undefined && promptMs !== undefined && promptMs > 0
? (promptTokens / promptMs) * 1000
? (promptTokens / promptMs) * MS_PER_SECOND
: undefined
);
@@ -97,9 +103,11 @@
onclick={() => (activeView = ChatMessageStatsView.READING)}
>
<BookOpenText class="h-3 w-3" />
<span class="sr-only">Reading</span>
</button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>Reading (prompt processing)</p>
</Tooltip.Content>
@@ -119,9 +127,11 @@
disabled={isGenerationDisabled}
>
<Sparkles class="h-3 w-3" />
<span class="sr-only">Generation</span>
</button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>
{isGenerationDisabled
@@ -140,16 +150,18 @@
value="{predictedTokens?.toLocaleString()} tokens"
tooltipLabel="Generated tokens"
/>
<BadgeChatStatistic
class="bg-transparent"
icon={Clock}
value={formattedTime}
tooltipLabel="Generation time"
/>
<BadgeChatStatistic
class="bg-transparent"
icon={Gauge}
value="{tokensPerSecond.toFixed(2)} tokens/s"
value="{tokensPerSecond.toFixed(2)} t/s"
tooltipLabel="Generation speed"
/>
{:else if hasPromptStats}
@@ -159,12 +171,14 @@
value="{promptTokens} tokens"
tooltipLabel="Prompt tokens"
/>
<BadgeChatStatistic
class="bg-transparent"
icon={Clock}
value={formattedPromptTime ?? '0s'}
tooltipLabel="Prompt processing time"
/>
<BadgeChatStatistic
class="bg-transparent"
icon={Gauge}
@@ -3,15 +3,16 @@
import { Card } from '$lib/components/ui/card';
import { Button } from '$lib/components/ui/button';
import { MarkdownContent } from '$lib/components/app';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { getMessageEditContext } from '$lib/contexts';
import { INPUT_CLASSES } from '$lib/constants/css-classes';
import { config } from '$lib/stores/settings.svelte';
import { isIMEComposing } from '$lib/utils';
import ChatMessageActions from './ChatMessageActions.svelte';
import { KeyboardKey, MessageRole } from '$lib/enums';
interface Props {
class?: string;
message: DatabaseMessage;
isEditing: boolean;
editedContent: string;
siblingInfo?: ChatMessageSiblingInfo | null;
showDeleteDialog: boolean;
deletionInfo: {
@@ -20,10 +21,6 @@
assistantMessages: number;
messageTypes: string[];
} | null;
onCancelEdit: () => void;
onSaveEdit: () => void;
onEditKeydown: (event: KeyboardEvent) => void;
onEditedContentChange: (content: string) => void;
onCopy: () => void;
onEdit: () => void;
onDelete: () => void;
@@ -36,15 +33,9 @@
let {
class: className = '',
message,
isEditing,
editedContent,
siblingInfo = null,
showDeleteDialog,
deletionInfo,
onCancelEdit,
onSaveEdit,
onEditKeydown,
onEditedContentChange,
onCopy,
onEdit,
onDelete,
@@ -54,10 +45,25 @@
textareaElement = $bindable()
}: Props = $props();
const editCtx = getMessageEditContext();
function handleEditKeydown(event: KeyboardEvent) {
if (event.key === KeyboardKey.ENTER && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
editCtx.save();
} else if (event.key === KeyboardKey.ESCAPE) {
event.preventDefault();
editCtx.cancel();
}
}
let isMultiline = $state(false);
let messageElement: HTMLElement | undefined = $state();
let isExpanded = $state(false);
let contentHeight = $state(0);
const MAX_HEIGHT = 200; // pixels
const currentConfig = config();
@@ -97,25 +103,32 @@
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
role="group"
>
{#if isEditing}
{#if editCtx.isEditing}
<div class="w-full max-w-[80%]">
<textarea
bind:this={textareaElement}
bind:value={editedContent}
value={editCtx.editedContent}
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
onkeydown={onEditKeydown}
oninput={(e) => onEditedContentChange(e.currentTarget.value)}
onkeydown={handleEditKeydown}
oninput={(e) => editCtx.setContent(e.currentTarget.value)}
placeholder="Edit system message..."
></textarea>
<div class="mt-2 flex justify-end gap-2">
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline">
<Button class="h-8 px-3" onclick={editCtx.cancel} size="sm" variant="outline">
<X class="mr-1 h-3 w-3" />
Cancel
</Button>
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
<Button
class="h-8 px-3"
onclick={editCtx.save}
disabled={!editCtx.editedContent.trim()}
size="sm"
>
<Check class="mr-1 h-3 w-3" />
Save
</Button>
</div>
@@ -131,12 +144,12 @@
type="button"
>
<Card
class="rounded-[1.125rem] !border-2 !border-dashed !border-border/50 bg-muted px-3.75 py-1.5 data-[multiline]:py-2.5"
class="overflow-y-auto rounded-[1.125rem] !border-2 !border-dashed !border-border/50 bg-muted px-3.75 py-1.5 data-[multiline]:py-2.5"
data-multiline={isMultiline ? '' : undefined}
style="border: 2px dashed hsl(var(--border));"
style="border: 2px dashed hsl(var(--border)); max-height: var(--max-message-height);"
>
<div
class="relative overflow-hidden transition-all duration-300 {isExpanded
class="relative transition-all duration-300 {isExpanded
? 'cursor-text select-text'
: 'select-none'}"
style={!isExpanded && showExpandButton
@@ -145,7 +158,10 @@
>
{#if currentConfig.renderUserContentAsMarkdown}
<div bind:this={messageElement} class="text-md {isExpanded ? 'cursor-text' : ''}">
<MarkdownContent class="markdown-system-content" content={message.content} />
<MarkdownContent
class="markdown-system-content overflow-auto"
content={message.content}
/>
</div>
{:else}
<span
@@ -160,6 +176,7 @@
<div
class="pointer-events-none absolute right-0 bottom-0 left-0 h-48 bg-gradient-to-t from-muted to-transparent"
></div>
<div
class="pointer-events-none absolute right-0 bottom-4 left-0 flex justify-center opacity-0 transition-opacity group-hover/expand:opacity-100"
>
@@ -208,7 +225,7 @@
{onShowDeleteDialogChange}
{siblingInfo}
{showDeleteDialog}
role="user"
role={MessageRole.USER}
/>
</div>
{/if}
@@ -1,67 +1,48 @@
<script lang="ts">
import { Card } from '$lib/components/ui/card';
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
import { getMessageEditContext } from '$lib/contexts';
import { config } from '$lib/stores/settings.svelte';
import ChatMessageActions from './ChatMessageActions.svelte';
import ChatMessageEditForm from './ChatMessageEditForm.svelte';
import { MessageRole } from '$lib/enums';
interface Props {
class?: string;
message: DatabaseMessage;
isEditing: boolean;
editedContent: string;
editedExtras?: DatabaseMessageExtra[];
editedUploadedFiles?: ChatUploadedFile[];
siblingInfo?: ChatMessageSiblingInfo | null;
showDeleteDialog: boolean;
deletionInfo: {
totalCount: number;
userMessages: number;
assistantMessages: number;
messageTypes: string[];
} | null;
onCancelEdit: () => void;
onSaveEdit: () => void;
onSaveEditOnly?: () => void;
onEditKeydown: (event: KeyboardEvent) => void;
onEditedContentChange: (content: string) => void;
onEditedExtrasChange?: (extras: DatabaseMessageExtra[]) => void;
onEditedUploadedFilesChange?: (files: ChatUploadedFile[]) => void;
onCopy: () => void;
showDeleteDialog: boolean;
onEdit: () => void;
onDelete: () => void;
onConfirmDelete: () => void;
onNavigateToSibling?: (siblingId: string) => void;
onShowDeleteDialogChange: (show: boolean) => void;
textareaElement?: HTMLTextAreaElement;
onNavigateToSibling?: (siblingId: string) => void;
onCopy: () => void;
}
let {
class: className = '',
message,
isEditing,
editedContent,
editedExtras = [],
editedUploadedFiles = [],
siblingInfo = null,
showDeleteDialog,
deletionInfo,
onCancelEdit,
onSaveEdit,
onSaveEditOnly,
onEditKeydown,
onEditedContentChange,
onEditedExtrasChange,
onEditedUploadedFilesChange,
onCopy,
showDeleteDialog,
onEdit,
onDelete,
onConfirmDelete,
onNavigateToSibling,
onShowDeleteDialogChange,
textareaElement = $bindable()
onNavigateToSibling,
onCopy
}: Props = $props();
// Get contexts
const editCtx = getMessageEditContext();
let isMultiline = $state(false);
let messageElement: HTMLElement | undefined = $state();
const currentConfig = config();
@@ -96,24 +77,8 @@
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
role="group"
>
{#if isEditing}
<ChatMessageEditForm
bind:textareaElement
messageId={message.id}
{editedContent}
{editedExtras}
{editedUploadedFiles}
originalContent={message.content}
originalExtras={message.extra}
showSaveOnlyOption={!!onSaveEditOnly}
{onCancelEdit}
{onSaveEdit}
{onSaveEditOnly}
{onEditKeydown}
{onEditedContentChange}
{onEditedExtrasChange}
{onEditedUploadedFilesChange}
/>
{#if editCtx.isEditing}
<ChatMessageEditForm />
{:else}
{#if message.extra && message.extra.length > 0}
<div class="mb-2 max-w-[80%]">
@@ -123,15 +88,13 @@
{#if message.content.trim()}
<Card
class="max-w-[80%] rounded-[1.125rem] border-none bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 text-foreground backdrop-blur-md data-[multiline]:py-2.5 dark:bg-primary/15"
data-multiline={isMultiline ? '' : undefined}
style="max-height: var(--max-message-height); overflow-wrap: anywhere; word-break: break-word;"
>
{#if currentConfig.renderUserContentAsMarkdown}
<div bind:this={messageElement} class="text-md">
<MarkdownContent
class="markdown-user-content text-primary-foreground"
content={message.content}
/>
<div bind:this={messageElement}>
<MarkdownContent class="markdown-user-content -my-4" content={message.content} />
</div>
{:else}
<span bind:this={messageElement} class="text-md whitespace-pre-wrap">
@@ -155,7 +118,7 @@
{onShowDeleteDialogChange}
{siblingInfo}
{showDeleteDialog}
role="user"
role={MessageRole.USER}
/>
</div>
{/if}
@@ -1,9 +1,11 @@
<script lang="ts">
import { ChatMessage } from '$lib/components/app';
import { setChatActionsContext } from '$lib/contexts';
import { MessageRole } from '$lib/enums';
import { chatStore } from '$lib/stores/chat.svelte';
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { getMessageSiblings } from '$lib/utils';
import { copyToClipboard, formatMessageForClipboard, getMessageSiblings } from '$lib/utils';
interface Props {
class?: string;
@@ -16,6 +18,69 @@
let allConversationMessages = $state<DatabaseMessage[]>([]);
const currentConfig = config();
setChatActionsContext({
copy: async (message: DatabaseMessage) => {
const asPlainText = Boolean(currentConfig.copyTextAttachmentsAsPlainText);
const clipboardContent = formatMessageForClipboard(
message.content,
message.extra,
asPlainText
);
await copyToClipboard(clipboardContent, 'Message copied to clipboard');
},
delete: async (message: DatabaseMessage) => {
await chatStore.deleteMessage(message.id);
refreshAllMessages();
},
navigateToSibling: async (siblingId: string) => {
await conversationsStore.navigateToSibling(siblingId);
},
editWithBranching: async (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => {
onUserAction?.();
await chatStore.editMessageWithBranching(message.id, newContent, newExtras);
refreshAllMessages();
},
editWithReplacement: async (
message: DatabaseMessage,
newContent: string,
shouldBranch: boolean
) => {
onUserAction?.();
await chatStore.editAssistantMessage(message.id, newContent, shouldBranch);
refreshAllMessages();
},
editUserMessagePreserveResponses: async (
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) => {
onUserAction?.();
await chatStore.editUserMessagePreserveResponses(message.id, newContent, newExtras);
refreshAllMessages();
},
regenerateWithBranching: async (message: DatabaseMessage, modelOverride?: string) => {
onUserAction?.();
await chatStore.regenerateMessageWithBranching(message.id, modelOverride);
refreshAllMessages();
},
continueAssistantMessage: async (message: DatabaseMessage) => {
onUserAction?.();
await chatStore.continueAssistantMessage(message.id);
refreshAllMessages();
}
});
function refreshAllMessages() {
const conversation = activeConversation();
@@ -42,16 +107,28 @@
return [];
}
// Filter out system messages if showSystemMessage is false
const filteredMessages = currentConfig.showSystemMessage
? messages
: messages.filter((msg) => msg.type !== 'system');
: messages.filter((msg) => msg.type !== MessageRole.SYSTEM);
return filteredMessages.map((message) => {
let lastAssistantIndex = -1;
for (let i = filteredMessages.length - 1; i >= 0; i--) {
if (filteredMessages[i].role === MessageRole.ASSISTANT) {
lastAssistantIndex = i;
break;
}
}
return filteredMessages.map((message, index) => {
const siblingInfo = getMessageSiblings(allConversationMessages, message.id);
const isLastAssistantMessage =
message.role === MessageRole.ASSISTANT && index === lastAssistantIndex;
return {
message,
isLastAssistantMessage,
siblingInfo: siblingInfo || {
message,
siblingIds: [message.id],
@@ -61,83 +138,15 @@
};
});
});
async function handleNavigateToSibling(siblingId: string) {
await conversationsStore.navigateToSibling(siblingId);
}
async function handleEditWithBranching(
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) {
onUserAction?.();
await chatStore.editMessageWithBranching(message.id, newContent, newExtras);
refreshAllMessages();
}
async function handleEditWithReplacement(
message: DatabaseMessage,
newContent: string,
shouldBranch: boolean
) {
onUserAction?.();
await chatStore.editAssistantMessage(message.id, newContent, shouldBranch);
refreshAllMessages();
}
async function handleRegenerateWithBranching(message: DatabaseMessage, modelOverride?: string) {
onUserAction?.();
await chatStore.regenerateMessageWithBranching(message.id, modelOverride);
refreshAllMessages();
}
async function handleContinueAssistantMessage(message: DatabaseMessage) {
onUserAction?.();
await chatStore.continueAssistantMessage(message.id);
refreshAllMessages();
}
async function handleEditUserMessagePreserveResponses(
message: DatabaseMessage,
newContent: string,
newExtras?: DatabaseMessageExtra[]
) {
onUserAction?.();
await chatStore.editUserMessagePreserveResponses(message.id, newContent, newExtras);
refreshAllMessages();
}
async function handleDeleteMessage(message: DatabaseMessage) {
await chatStore.deleteMessage(message.id);
refreshAllMessages();
}
</script>
<div class="flex h-full flex-col space-y-10 pt-16 md:pt-24 {className}" style="height: auto; ">
{#each displayMessages as { message, siblingInfo } (message.id)}
<div class="flex h-full flex-col space-y-10 pt-24 {className}" style="height: auto; ">
{#each displayMessages as { message, isLastAssistantMessage, siblingInfo } (message.id)}
<ChatMessage
class="mx-auto w-full max-w-[48rem]"
{message}
{isLastAssistantMessage}
{siblingInfo}
onDelete={handleDeleteMessage}
onNavigateToSibling={handleNavigateToSibling}
onEditWithBranching={handleEditWithBranching}
onEditWithReplacement={handleEditWithReplacement}
onEditUserMessagePreserveResponses={handleEditUserMessagePreserveResponses}
onRegenerateWithBranching={handleRegenerateWithBranching}
onContinueAssistantMessage={handleContinueAssistantMessage}
/>
{/each}
</div>
@@ -1,7 +1,7 @@
<script lang="ts">
import { afterNavigate } from '$app/navigation';
import {
ChatForm,
ChatScreenForm,
ChatScreenHeader,
ChatMessages,
ChatScreenProcessingInfo,
@@ -12,11 +12,9 @@
} from '$lib/components/app';
import * as Alert from '$lib/components/ui/alert';
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import {
AUTO_SCROLL_AT_BOTTOM_THRESHOLD,
AUTO_SCROLL_INTERVAL,
INITIAL_SCROLL_DELAY
} from '$lib/constants/auto-scroll';
import { INITIAL_SCROLL_DELAY } from '$lib/constants/auto-scroll';
import { KeyboardKey } from '$lib/enums';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import {
chatStore,
errorDialog,
@@ -44,16 +42,13 @@
let { showCenteredEmpty = false } = $props();
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll));
let autoScrollEnabled = $state(true);
let chatScrollContainer: HTMLDivElement | undefined = $state();
let dragCounter = $state(0);
let isDragOver = $state(false);
let lastScrollTop = $state(0);
let scrollInterval: ReturnType<typeof setInterval> | undefined;
let scrollTimeout: ReturnType<typeof setTimeout> | undefined;
let showFileErrorDialog = $state(false);
let uploadedFiles = $state<ChatUploadedFile[]>([]);
let userScrolledUp = $state(false);
const autoScroll = createAutoScrollController();
let fileErrorData = $state<{
generallyUnsupported: File[];
@@ -217,7 +212,11 @@
function handleKeydown(event: KeyboardEvent) {
const isCtrlOrCmd = event.ctrlKey || event.metaKey;
if (isCtrlOrCmd && event.shiftKey && (event.key === 'd' || event.key === 'D')) {
if (
isCtrlOrCmd &&
event.shiftKey &&
(event.key === KeyboardKey.D_LOWER || event.key === KeyboardKey.D_UPPER)
) {
event.preventDefault();
if (activeConversation()) {
showDeleteDialog = true;
@@ -234,37 +233,13 @@
}
function handleScroll() {
if (disableAutoScroll || !chatScrollContainer) return;
const { scrollTop, scrollHeight, clientHeight } = chatScrollContainer;
const distanceFromBottom = scrollHeight - scrollTop - clientHeight;
const isAtBottom = distanceFromBottom < AUTO_SCROLL_AT_BOTTOM_THRESHOLD;
if (scrollTop < lastScrollTop && !isAtBottom) {
userScrolledUp = true;
autoScrollEnabled = false;
} else if (isAtBottom && userScrolledUp) {
userScrolledUp = false;
autoScrollEnabled = true;
}
if (scrollTimeout) {
clearTimeout(scrollTimeout);
}
scrollTimeout = setTimeout(() => {
if (isAtBottom) {
userScrolledUp = false;
autoScrollEnabled = true;
}
}, AUTO_SCROLL_INTERVAL);
lastScrollTop = scrollTop;
autoScroll.handleScroll();
}
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
const result = files
? await parseFilesToMessageExtras(files, activeModelId ?? undefined)
const plainFiles = files ? $state.snapshot(files) : undefined;
const result = plainFiles
? await parseFilesToMessageExtras(plainFiles, activeModelId ?? undefined)
: undefined;
if (result?.emptyFiles && result.emptyFiles.length > 0) {
@@ -281,12 +256,9 @@
const extras = result?.extras;
// Enable autoscroll for user-initiated message sending
if (!disableAutoScroll) {
userScrolledUp = false;
autoScrollEnabled = true;
}
autoScroll.enable();
await chatStore.sendMessage(message, extras);
scrollChatToBottom();
autoScroll.scrollToBottom();
return true;
}
@@ -336,24 +308,15 @@
}
}
function scrollChatToBottom(behavior: ScrollBehavior = 'smooth') {
if (disableAutoScroll) return;
chatScrollContainer?.scrollTo({
top: chatScrollContainer?.scrollHeight,
behavior
});
}
afterNavigate(() => {
if (!disableAutoScroll) {
setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY);
setTimeout(() => autoScroll.scrollToBottom('instant'), INITIAL_SCROLL_DELAY);
}
});
onMount(() => {
if (!disableAutoScroll) {
setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY);
setTimeout(() => autoScroll.scrollToBottom('instant'), INITIAL_SCROLL_DELAY);
}
const pendingDraft = chatStore.consumePendingDraft();
@@ -364,21 +327,15 @@
});
$effect(() => {
if (disableAutoScroll) {
autoScrollEnabled = false;
if (scrollInterval) {
clearInterval(scrollInterval);
scrollInterval = undefined;
}
return;
}
autoScroll.setContainer(chatScrollContainer);
});
if (isCurrentConversationLoading && autoScrollEnabled) {
scrollInterval = setInterval(scrollChatToBottom, AUTO_SCROLL_INTERVAL);
} else if (scrollInterval) {
clearInterval(scrollInterval);
scrollInterval = undefined;
}
$effect(() => {
autoScroll.setDisabled(disableAutoScroll);
});
$effect(() => {
autoScroll.updateInterval(isCurrentConversationLoading);
});
</script>
@@ -406,11 +363,8 @@
class="mb-16 md:mb-24"
messages={activeMessages()}
onUserAction={() => {
if (!disableAutoScroll) {
userScrolledUp = false;
autoScrollEnabled = true;
scrollChatToBottom();
}
autoScroll.enable();
autoScroll.scrollToBottom();
}}
/>
@@ -444,7 +398,7 @@
{/if}
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl pb-4">
<ChatForm
<ChatScreenForm
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
@@ -474,7 +428,7 @@
>
<div class="w-full max-w-[48rem] px-4">
<div class="mb-10 text-center" in:fade={{ duration: 300 }}>
<h1 class="mb-4 text-3xl font-semibold tracking-tight">llama.cpp</h1>
<h1 class="mb-2 text-3xl font-semibold tracking-tight">llama.cpp</h1>
<p class="text-lg text-muted-foreground">
{serverStore.props?.modalities?.audio
@@ -504,7 +458,7 @@
{/if}
<div in:fly={{ y: 10, duration: 250, delay: hasPropsError ? 0 : 300 }}>
<ChatForm
<ChatScreenForm
disabled={hasPropsError}
{initialMessage}
isLoading={isCurrentConversationLoading}
@@ -617,7 +571,7 @@
contextInfo={activeErrorDialog?.contextInfo}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={(activeErrorDialog?.type as ErrorDialogType) ?? ErrorDialogType.SERVER}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
/>
<style>
@@ -1,5 +1,7 @@
<script lang="ts">
import ChatForm from '$lib/components/app/chat/ChatForm/ChatForm.svelte';
import { afterNavigate } from '$app/navigation';
import { ChatFormHelperText, ChatForm } from '$lib/components/app';
import { onMount } from 'svelte';
interface Props {
class?: string;
@@ -28,20 +30,92 @@
showHelperText = true,
uploadedFiles = $bindable([])
}: Props = $props();
let chatFormRef: ChatForm | undefined = $state(undefined);
let message = $derived(initialMessage);
let previousIsLoading = $derived(isLoading);
let previousInitialMessage = $derived(initialMessage);
// Sync message when initialMessage prop changes (e.g., after draft restoration)
$effect(() => {
if (initialMessage !== previousInitialMessage) {
message = initialMessage;
previousInitialMessage = initialMessage;
}
});
function handleSystemPromptClick() {
onSystemPromptAdd?.({ message, files: uploadedFiles });
}
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
async function handleSubmit() {
if (
(!message.trim() && uploadedFiles.length === 0) ||
disabled ||
isLoading ||
hasLoadingAttachments
)
return;
if (!chatFormRef?.checkModelSelected()) return;
const messageToSend = message.trim();
const filesToSend = [...uploadedFiles];
message = '';
uploadedFiles = [];
chatFormRef?.resetTextareaHeight();
const success = await onSend?.(messageToSend, filesToSend);
if (!success) {
message = messageToSend;
uploadedFiles = filesToSend;
}
}
function handleFilesAdd(files: File[]) {
onFileUpload?.(files);
}
function handleUploadedFileRemove(fileId: string) {
onFileRemove?.(fileId);
}
onMount(() => {
setTimeout(() => chatFormRef?.focus(), 10);
});
afterNavigate(() => {
setTimeout(() => chatFormRef?.focus(), 10);
});
$effect(() => {
if (previousIsLoading && !isLoading) {
setTimeout(() => chatFormRef?.focus(), 10);
}
previousIsLoading = isLoading;
});
</script>
<div class="relative mx-auto max-w-[48rem]">
<ChatForm
bind:this={chatFormRef}
bind:value={message}
bind:uploadedFiles
class={className}
{disabled}
{initialMessage}
{isLoading}
{onFileRemove}
{onFileUpload}
{onSend}
onFilesAdd={handleFilesAdd}
{onStop}
{onSystemPromptAdd}
{showHelperText}
bind:uploadedFiles
onSubmit={handleSubmit}
onSystemPromptClick={handleSystemPromptClick}
onUploadedFileRemove={handleUploadedFileRemove}
/>
</div>
<ChatFormHelperText show={showHelperText} />
@@ -14,12 +14,17 @@
</script>
<header
class="md:background-transparent pointer-events-none fixed top-0 right-0 left-0 z-50 flex items-center justify-end bg-background/40 p-4 backdrop-blur-xl duration-200 ease-linear {sidebar.open
class="pointer-events-none fixed top-0 right-0 left-0 z-50 flex items-center justify-end p-4 duration-200 ease-linear {sidebar.open
? 'md:left-[var(--sidebar-width)]'
: ''}"
>
<div class="pointer-events-auto flex items-center space-x-2">
<Button variant="ghost" size="sm" onclick={toggleSettings}>
<Button
variant="ghost"
size="icon"
onclick={toggleSettings}
class="rounded-full backdrop-blur-lg"
>
<Settings class="h-4 w-4" />
</Button>
</div>
@@ -11,7 +11,7 @@
let isCurrentConversationLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
let hasProcessingData = $derived(processingState.processingState !== null);
let processingDetails = $derived(processingState.getProcessingDetails());
let processingDetails = $derived(processingState.getTechnicalDetails());
let showProcessingInfo = $derived(
isCurrentConversationLoading || isStreaming || config().keepStatsVisible || hasProcessingData
@@ -63,7 +63,7 @@
<div class="chat-processing-info-container pointer-events-none" class:visible={showProcessingInfo}>
<div class="chat-processing-info-content">
{#each processingDetails as detail (detail)}
<span class="chat-processing-info-detail pointer-events-auto">{detail}</span>
<span class="chat-processing-info-detail pointer-events-auto backdrop-blur-sm">{detail}</span>
{/each}
</div>
</div>
@@ -73,7 +73,7 @@
position: sticky;
top: 0;
z-index: 10;
padding: 1.5rem 1rem;
padding: 0 1rem 0.75rem;
opacity: 0;
transform: translateY(50%);
transition:
@@ -100,7 +100,6 @@
color: var(--muted-foreground);
font-size: 0.75rem;
padding: 0.25rem 0.75rem;
background: var(--muted);
border-radius: 0.375rem;
font-family:
ui-monospace, SFMono-Regular, 'SF Mono', Consolas, 'Liberation Mono', Menlo, monospace;
@@ -5,8 +5,6 @@
AlertTriangle,
Code,
Monitor,
Sun,
Moon,
ChevronLeft,
ChevronRight,
Database
@@ -23,7 +21,12 @@
type SettingsSectionTitle
} from '$lib/constants/settings-sections';
import { setMode } from 'mode-watcher';
import { ColorMode } from '$lib/enums/ui';
import { SettingsFieldType } from '$lib/enums/settings';
import type { Component } from 'svelte';
import { NUMERIC_FIELDS, POSITIVE_INTEGER_FIELDS } from '$lib/constants/settings-fields';
import { SETTINGS_COLOR_MODES_CONFIG } from '$lib/constants/settings-config';
import { SETTINGS_KEYS } from '$lib/constants/settings-keys';
interface Props {
onSave?: () => void;
@@ -38,240 +41,231 @@
title: SettingsSectionTitle;
}> = [
{
title: 'General',
title: SETTINGS_SECTION_TITLES.GENERAL,
icon: Settings,
fields: [
{
key: 'theme',
key: SETTINGS_KEYS.THEME,
label: 'Theme',
type: 'select',
options: [
{ value: 'system', label: 'System', icon: Monitor },
{ value: 'light', label: 'Light', icon: Sun },
{ value: 'dark', label: 'Dark', icon: Moon }
]
type: SettingsFieldType.SELECT,
options: SETTINGS_COLOR_MODES_CONFIG
},
{ key: 'apiKey', label: 'API Key', type: 'input' },
{ key: SETTINGS_KEYS.API_KEY, label: 'API Key', type: SettingsFieldType.INPUT },
{
key: 'systemMessage',
key: SETTINGS_KEYS.SYSTEM_MESSAGE,
label: 'System Message',
type: 'textarea'
type: SettingsFieldType.TEXTAREA
},
{
key: 'pasteLongTextToFileLen',
key: SETTINGS_KEYS.PASTE_LONG_TEXT_TO_FILE_LEN,
label: 'Paste long text to file length',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'copyTextAttachmentsAsPlainText',
key: SETTINGS_KEYS.COPY_TEXT_ATTACHMENTS_AS_PLAIN_TEXT,
label: 'Copy text attachments as plain text',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'enableContinueGeneration',
key: SETTINGS_KEYS.ENABLE_CONTINUE_GENERATION,
label: 'Enable "Continue" button',
type: 'checkbox',
type: SettingsFieldType.CHECKBOX,
isExperimental: true
},
{
key: 'pdfAsImage',
key: SETTINGS_KEYS.PDF_AS_IMAGE,
label: 'Parse PDF as image',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'askForTitleConfirmation',
key: SETTINGS_KEYS.ASK_FOR_TITLE_CONFIRMATION,
label: 'Ask for confirmation before changing conversation title',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
}
]
},
{
title: 'Display',
title: SETTINGS_SECTION_TITLES.DISPLAY,
icon: Monitor,
fields: [
{
key: 'showMessageStats',
key: SETTINGS_KEYS.SHOW_MESSAGE_STATS,
label: 'Show message generation statistics',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'showThoughtInProgress',
key: SETTINGS_KEYS.SHOW_THOUGHT_IN_PROGRESS,
label: 'Show thought in progress',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'keepStatsVisible',
key: SETTINGS_KEYS.KEEP_STATS_VISIBLE,
label: 'Keep stats visible after generation',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'autoMicOnEmpty',
key: SETTINGS_KEYS.AUTO_MIC_ON_EMPTY,
label: 'Show microphone on empty input',
type: 'checkbox',
type: SettingsFieldType.CHECKBOX,
isExperimental: true
},
{
key: 'renderUserContentAsMarkdown',
key: SETTINGS_KEYS.RENDER_USER_CONTENT_AS_MARKDOWN,
label: 'Render user content as Markdown',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'disableAutoScroll',
key: SETTINGS_KEYS.DISABLE_AUTO_SCROLL,
label: 'Disable automatic scroll',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'alwaysShowSidebarOnDesktop',
key: SETTINGS_KEYS.ALWAYS_SHOW_SIDEBAR_ON_DESKTOP,
label: 'Always show sidebar on desktop',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'autoShowSidebarOnNewChat',
key: SETTINGS_KEYS.AUTO_SHOW_SIDEBAR_ON_NEW_CHAT,
label: 'Auto-show sidebar on new chat',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
}
]
},
{
title: 'Sampling',
title: SETTINGS_SECTION_TITLES.SAMPLING,
icon: Funnel,
fields: [
{
key: 'temperature',
key: SETTINGS_KEYS.TEMPERATURE,
label: 'Temperature',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'dynatemp_range',
key: SETTINGS_KEYS.DYNATEMP_RANGE,
label: 'Dynamic temperature range',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'dynatemp_exponent',
key: SETTINGS_KEYS.DYNATEMP_EXPONENT,
label: 'Dynamic temperature exponent',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'top_k',
key: SETTINGS_KEYS.TOP_K,
label: 'Top K',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'top_p',
key: SETTINGS_KEYS.TOP_P,
label: 'Top P',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'min_p',
key: SETTINGS_KEYS.MIN_P,
label: 'Min P',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'xtc_probability',
key: SETTINGS_KEYS.XTC_PROBABILITY,
label: 'XTC probability',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'xtc_threshold',
key: SETTINGS_KEYS.XTC_THRESHOLD,
label: 'XTC threshold',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'typ_p',
key: SETTINGS_KEYS.TYP_P,
label: 'Typical P',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'max_tokens',
key: SETTINGS_KEYS.MAX_TOKENS,
label: 'Max tokens',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'samplers',
key: SETTINGS_KEYS.SAMPLERS,
label: 'Samplers',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'backend_sampling',
key: SETTINGS_KEYS.BACKEND_SAMPLING,
label: 'Backend sampling',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
}
]
},
{
title: 'Penalties',
title: SETTINGS_SECTION_TITLES.PENALTIES,
icon: AlertTriangle,
fields: [
{
key: 'repeat_last_n',
key: SETTINGS_KEYS.REPEAT_LAST_N,
label: 'Repeat last N',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'repeat_penalty',
key: SETTINGS_KEYS.REPEAT_PENALTY,
label: 'Repeat penalty',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'presence_penalty',
key: SETTINGS_KEYS.PRESENCE_PENALTY,
label: 'Presence penalty',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'frequency_penalty',
key: SETTINGS_KEYS.FREQUENCY_PENALTY,
label: 'Frequency penalty',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'dry_multiplier',
key: SETTINGS_KEYS.DRY_MULTIPLIER,
label: 'DRY multiplier',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'dry_base',
key: SETTINGS_KEYS.DRY_BASE,
label: 'DRY base',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'dry_allowed_length',
key: SETTINGS_KEYS.DRY_ALLOWED_LENGTH,
label: 'DRY allowed length',
type: 'input'
type: SettingsFieldType.INPUT
},
{
key: 'dry_penalty_last_n',
key: SETTINGS_KEYS.DRY_PENALTY_LAST_N,
label: 'DRY penalty last N',
type: 'input'
type: SettingsFieldType.INPUT
}
]
},
{
title: 'Import/Export',
title: SETTINGS_SECTION_TITLES.IMPORT_EXPORT,
icon: Database,
fields: []
},
{
title: 'Developer',
title: SETTINGS_SECTION_TITLES.DEVELOPER,
icon: Code,
fields: [
{
key: 'showToolCalls',
label: 'Show tool call labels',
type: 'checkbox'
},
{
key: 'disableReasoningParsing',
key: SETTINGS_KEYS.DISABLE_REASONING_PARSING,
label: 'Disable reasoning content parsing',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'showRawOutputSwitch',
key: SETTINGS_KEYS.SHOW_RAW_OUTPUT_SWITCH,
label: 'Enable raw output toggle',
type: 'checkbox'
type: SettingsFieldType.CHECKBOX
},
{
key: 'custom',
key: SETTINGS_KEYS.CUSTOM,
label: 'Custom JSON',
type: 'textarea'
type: SettingsFieldType.TEXTAREA
}
]
}
@@ -303,11 +297,7 @@
let scrollContainer: HTMLDivElement | undefined = $state();
$effect(() => {
if (!initialSection) {
return;
}
if (settingSections.some((section) => section.title === initialSection)) {
if (initialSection) {
activeSection = initialSection;
}
});
@@ -315,7 +305,7 @@
function handleThemeChange(newTheme: string) {
localConfig.theme = newTheme;
setMode(newTheme as 'light' | 'dark' | 'system');
setMode(newTheme as ColorMode);
}
function handleConfigChange(key: string, value: string | boolean) {
@@ -325,7 +315,7 @@
function handleReset() {
localConfig = { ...config() };
setMode(localConfig.theme as 'light' | 'dark' | 'system');
setMode(localConfig.theme as ColorMode);
}
function handleSave() {
@@ -341,33 +331,16 @@
// Convert numeric strings to numbers for numeric fields
const processedConfig = { ...localConfig };
const numericFields = [
'temperature',
'top_k',
'top_p',
'min_p',
'max_tokens',
'pasteLongTextToFileLen',
'dynatemp_range',
'dynatemp_exponent',
'typ_p',
'xtc_probability',
'xtc_threshold',
'repeat_last_n',
'repeat_penalty',
'presence_penalty',
'frequency_penalty',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_penalty_last_n'
];
for (const field of numericFields) {
for (const field of NUMERIC_FIELDS) {
if (processedConfig[field] !== undefined && processedConfig[field] !== '') {
const numValue = Number(processedConfig[field]);
if (!isNaN(numValue)) {
processedConfig[field] = numValue;
if ((POSITIVE_INTEGER_FIELDS as readonly string[]).includes(field)) {
processedConfig[field] = Math.max(1, Math.round(numValue));
} else {
processedConfig[field] = numValue;
}
} else {
alert(`Invalid numeric value for ${field}. Please enter a valid number.`);
return;
@@ -506,7 +479,7 @@
<h3 class="text-lg font-semibold">{currentSection.title}</h3>
</div>
{#if currentSection.title === 'Import/Export'}
{#if currentSection.title === SETTINGS_SECTION_TITLES.IMPORT_EXPORT}
<ChatSettingsImportExportTab />
{:else}
<div class="space-y-6">
@@ -6,6 +6,8 @@
import * as Select from '$lib/components/ui/select';
import { Textarea } from '$lib/components/ui/textarea';
import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
import { SETTINGS_KEYS } from '$lib/constants/settings-keys';
import { SettingsFieldType } from '$lib/enums/settings';
import { settingsStore } from '$lib/stores/settings.svelte';
import { ChatSettingsParameterSourceIndicator } from '$lib/components/app';
import type { Component } from 'svelte';
@@ -31,7 +33,7 @@
{#each fields as field (field.key)}
<div class="space-y-2">
{#if field.type === 'input'}
{#if field.type === SettingsFieldType.INPUT}
{@const paramInfo = getParameterSourceInfo(field.key)}
{@const currentValue = String(localConfig[field.key] ?? '')}
{@const propsDefault = paramInfo?.serverDefault}
@@ -98,7 +100,7 @@
{@html field.help || SETTING_CONFIG_INFO[field.key]}
</p>
{/if}
{:else if field.type === 'textarea'}
{:else if field.type === SettingsFieldType.TEXTAREA}
<Label for={field.key} class="block flex items-center gap-1.5 text-sm font-medium">
{field.label}
@@ -121,7 +123,7 @@
</p>
{/if}
{#if field.key === 'systemMessage'}
{#if field.key === SETTINGS_KEYS.SYSTEM_MESSAGE}
<div class="mt-3 flex items-center gap-2">
<Checkbox
id="showSystemMessage"
@@ -134,7 +136,7 @@
</Label>
</div>
{/if}
{:else if field.type === 'select'}
{:else if field.type === SettingsFieldType.SELECT}
{@const selectedOption = field.options?.find(
(opt: { value: string; label: string; icon?: Component }) =>
opt.value === localConfig[field.key]
@@ -166,7 +168,7 @@
type="single"
value={currentValue}
onValueChange={(value) => {
if (field.key === 'theme' && value && onThemeChange) {
if (field.key === SETTINGS_KEYS.THEME && value && onThemeChange) {
onThemeChange(value);
} else {
onConfigChange(field.key, value);
@@ -222,7 +224,7 @@
{field.help || SETTING_CONFIG_INFO[field.key]}
</p>
{/if}
{:else if field.type === 'checkbox'}
{:else if field.type === SettingsFieldType.CHECKBOX}
<div class="flex items-start space-x-3">
<Checkbox
id={field.key}
@@ -1,11 +1,10 @@
<script lang="ts">
import { Download, Upload, Trash2 } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { DialogConversationSelection } from '$lib/components/app';
import { DialogConversationSelection, DialogConfirmation } from '$lib/components/app';
import { createMessageCountMap } from '$lib/utils';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import { toast } from 'svelte-sonner';
import DialogConfirmation from '$lib/components/app/dialogs/DialogConfirmation.svelte';
let exportedConversations = $state<DatabaseConversation[]>([]);
let importedConversations = $state<DatabaseConversation[]>([]);
@@ -9,7 +9,7 @@
import Input from '$lib/components/ui/input/input.svelte';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { getPreviewText } from '$lib/utils/text';
import { getPreviewText } from '$lib/utils';
import ChatSidebarActions from './ChatSidebarActions.svelte';
const sidebar = Sidebar.useSidebar();
@@ -1,6 +1,6 @@
<script lang="ts">
import { Trash2, Pencil, MoreHorizontal, Download, Loader2, Square } from '@lucide/svelte';
import { ActionDropdown } from '$lib/components/app';
import { DropdownMenuActions } from '$lib/components/app';
import * as Tooltip from '$lib/components/ui/tooltip';
import { getAllLoadingChats } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
@@ -128,7 +128,7 @@
{#if renderActionsDropdown}
<div class="actions flex items-center">
<ActionDropdown
<DropdownMenuActions
triggerIcon={MoreHorizontal}
triggerTooltip="More actions"
bind:open={dropdownOpen}
@@ -0,0 +1,597 @@
/**
*
* ATTACHMENTS
*
* Components for displaying and managing different attachment types in chat messages.
* Supports two operational modes:
* - **Readonly mode**: For displaying stored attachments in sent messages (DatabaseMessageExtra[])
* - **Editable mode**: For managing pending uploads in the input form (ChatUploadedFile[])
*
* The attachment system uses `getAttachmentDisplayItems()` utility to normalize both
* data sources into a unified display format, enabling consistent rendering regardless
* of the attachment origin.
*
*/
/**
* **ChatAttachmentsList** - Unified display for file attachments in chat
*
* Central component for rendering file attachments in both ChatMessage (readonly)
* and ChatForm (editable) contexts.
*
* **Architecture:**
* - Delegates rendering to specialized thumbnail components based on attachment type
* - Manages scroll state and navigation arrows for horizontal overflow
* - Integrates with DialogChatAttachmentPreview for full-size viewing
* - Validates vision modality support via `activeModelId` prop
*
* **Features:**
* - Horizontal scroll with smooth navigation arrows
* - Image thumbnails with lazy loading and error fallback
* - File type icons for non-image files (PDF, text, audio, etc.)
* - Click-to-preview with full-size dialog and download option
* - "View All" button when `limitToSingleRow` is enabled and content overflows
* - Vision modality validation to warn about unsupported image uploads
* - Customizable thumbnail dimensions via `imageHeight`/`imageWidth` props
*
* @example
* ```svelte
* <!-- Readonly mode (in ChatMessage) -->
* <ChatAttachmentsList attachments={message.extra} readonly />
*
* <!-- Editable mode (in ChatForm) -->
* <ChatAttachmentsList
* bind:uploadedFiles
* onFileRemove={(id) => removeFile(id)}
* limitToSingleRow
* activeModelId={selectedModel}
* />
* ```
*/
export { default as ChatAttachmentsList } from './ChatAttachments/ChatAttachmentsList.svelte';
/**
* Thumbnail for non-image file attachments. Displays file type icon based on extension,
* file name (truncated), and file size.
* Handles text files, PDFs, audio, and other document types.
*/
export { default as ChatAttachmentThumbnailFile } from './ChatAttachments/ChatAttachmentThumbnailFile.svelte';
/**
* Thumbnail for image attachments with lazy loading and error fallback.
* Displays image preview with configurable dimensions. Falls back to placeholder
* on load error.
*/
export { default as ChatAttachmentThumbnailImage } from './ChatAttachments/ChatAttachmentThumbnailImage.svelte';
/**
* Grid view of all attachments for "View All" dialog. Displays all attachments
* in a responsive grid layout when there are too many to show inline.
* Triggered by "+X more" button in ChatAttachmentsList.
*/
export { default as ChatAttachmentsViewAll } from './ChatAttachments/ChatAttachmentsViewAll.svelte';
/**
* Full-size preview dialog for attachments. Opens when clicking on any attachment
* thumbnail. Shows the attachment in full size with options to download or close.
* Handles both image and non-image attachments with appropriate rendering.
*/
export { default as ChatAttachmentPreview } from './ChatAttachments/ChatAttachmentPreview.svelte';
/**
*
* FORM
*
* Components for the chat input area. The form handles user input, file attachments,
* audio recording. It integrates with multiple stores:
* - `chatStore` for message submission and generation control
* - `modelsStore` for model selection and validation
*
* The form exposes a public API for programmatic control from parent components
* (focus, height reset, model selector, validation).
*
*/
/**
* **ChatForm** - Main chat input component with rich features
*
* The primary input interface for composing and sending chat messages.
* Orchestrates text input, file attachments, audio recording.
* Used by ChatScreenForm and ChatMessageEditForm for both new conversations and message editing.
*
* **Architecture:**
* - Composes ChatFormTextarea, ChatFormActions, and ChatFormPromptPicker
* - Manages file upload state via `uploadedFiles` bindable prop
* - Integrates with ModelsSelector for model selection in router mode
* - Communicates with parent via callbacks (onSubmit, onFilesAdd, onStop, etc.)
*
* **Input Handling:**
* - IME-safe Enter key handling (waits for composition end)
* - Shift+Enter for newline, Enter for submit
* - Paste handler for files and long text (> {pasteLongTextToFileLen} chars file conversion)
*
* **Features:**
* - Auto-resizing textarea with placeholder
* - File upload via button dropdown (images/text/PDF), drag-drop, or paste
* - Audio recording with WAV conversion (when model supports audio)
* - Model selector integration (router mode)
* - Loading state with stop button, disabled state for errors
*
* **Exported API:**
* - `focus()` - Focus the textarea programmatically
* - `resetTextareaHeight()` - Reset textarea to default height after submit
* - `openModelSelector()` - Open model selection dropdown
* - `checkModelSelected(): boolean` - Validate model selection, show error if none
*
* @example
* ```svelte
* <ChatForm
* bind:this={chatFormRef}
* bind:value={message}
* bind:uploadedFiles
* {isLoading}
* onSubmit={handleSubmit}
* onFilesAdd={processFiles}
* onStop={handleStop}
* />
* ```
*/
export { default as ChatForm } from './ChatForm/ChatForm.svelte';
/**
* Dropdown button for file attachment selection. Opens a menu with options for
* Images, Text Files, and PDF Files. Each option filters the file picker to
* appropriate types. Images option is disabled when model lacks vision modality.
*/
export { default as ChatFormActionAttachmentsDropdown } from './ChatForm/ChatFormActions/ChatFormActionAttachmentsDropdown.svelte';
/**
* Audio recording button with real-time recording indicator. Records audio
* and converts to WAV format for upload. Only visible when the active model
* supports audio modality and setting for automatic audio input is enabled. Shows recording duration while active.
*/
export { default as ChatFormActionRecord } from './ChatForm/ChatFormActions/ChatFormActionRecord.svelte';
/**
* Container for chat form action buttons. Arranges file attachment, audio record,
* and submit/stop buttons in a horizontal layout. Handles conditional visibility
* based on model capabilities and loading state.
*/
export { default as ChatFormActions } from './ChatForm/ChatFormActions/ChatFormActions.svelte';
/**
* Submit/stop button with loading state. Shows send icon normally, transforms
* to stop icon during generation. Disabled when input is empty or form is disabled.
* Triggers onSubmit or onStop callbacks based on current state.
*/
export { default as ChatFormActionSubmit } from './ChatForm/ChatFormActions/ChatFormActionSubmit.svelte';
/**
* Hidden file input element for programmatic file selection.
*/
export { default as ChatFormFileInputInvisible } from './ChatForm/ChatFormFileInputInvisible.svelte';
/**
* Helper text display below chat.
*/
export { default as ChatFormHelperText } from './ChatForm/ChatFormHelperText.svelte';
/**
* Auto-resizing textarea with IME composition support. Automatically adjusts
* height based on content. Handles IME input correctly (waits for composition
* end before processing Enter key). Exposes focus() and resetHeight() methods.
*/
export { default as ChatFormTextarea } from './ChatForm/ChatFormTextarea.svelte';
/**
*
* MESSAGES
*
* Components for displaying chat messages. The message system supports:
* - **Conversation branching**: Messages can have siblings (alternative versions)
* created by editing or regenerating. Users can navigate between branches.
* - **Role-based rendering**: Different layouts for user, assistant, and system messages
* - **Streaming support**: Real-time display of assistant responses as they generate
* - **Agentic workflows**: Special rendering for tool calls and reasoning blocks
*
* The branching system uses `getMessageSiblings()` utility to compute sibling info
* for each message based on the full conversation tree stored in the database.
*
*/
/**
* **ChatMessages** - Message list container with branching support
*
* Container component that renders the list of messages in a conversation.
* Computes sibling information for each message to enable branch navigation.
* Integrates with conversationsStore for message operations.
*
* **Architecture:**
* - Fetches all conversation messages to compute sibling relationships
* - Filters system messages based on user config (`showSystemMessage`)
* - Delegates rendering to ChatMessage for each message
* - Propagates all message operations to chatStore via callbacks
*
* **Branching Logic:**
* - Uses `getMessageSiblings()` to find all messages with same parent
* - Computes `siblingInfo: { currentIndex, totalSiblings, siblingIds }`
* - Enables navigation between alternative message versions
*
* **Message Operations (delegated to chatStore):**
* - Edit with branching: Creates new message branch, preserves original
* - Edit with replacement: Modifies message in place
* - Regenerate: Creates new assistant response as sibling
* - Delete: Removes message and all descendants (cascade)
* - Continue: Appends to incomplete assistant message
*
* @example
* ```svelte
* <ChatMessages
* messages={activeMessages()}
* onUserAction={resetAutoScroll}
* />
* ```
*/
export { default as ChatMessages } from './ChatMessages/ChatMessages.svelte';
/**
* **ChatMessage** - Single message display with actions
*
* Renders a single chat message with role-specific styling and full action
* support. Delegates to specialized components based on message role:
* ChatMessageUser, ChatMessageAssistant, or ChatMessageSystem.
*
* **Architecture:**
* - Routes to role-specific component based on `message.type`
* - Manages edit mode state and inline editing UI
* - Handles action callbacks (copy, edit, delete, regenerate)
* - Displays branching controls when message has siblings
*
* **User Messages:**
* - Shows attachments via ChatAttachmentsList
* - Edit creates new branch or preserves responses
*
* **Assistant Messages:**
* - Renders content via MarkdownContent or ChatMessageAgenticContent
* - Shows model info badge (when enabled)
* - Regenerate creates sibling with optional model override
* - Continue action for incomplete responses
*
* **Features:**
* - Inline editing with file attachments support
* - Copy formatted content to clipboard
* - Delete with confirmation (shows cascade delete count)
* - Branching controls for sibling navigation
* - Statistics display (tokens, timing)
*
* @example
* ```svelte
* <ChatMessage
* {message}
* {siblingInfo}
* onEditWithBranching={handleEdit}
* onRegenerateWithBranching={handleRegenerate}
* onNavigateToSibling={handleNavigate}
* />
* ```
*/
export { default as ChatMessage } from './ChatMessages/ChatMessage.svelte';
/**
* Action buttons toolbar for messages. Displays copy, edit, delete, and regenerate
* buttons based on message role. Includes branching controls when message has siblings.
* Shows delete confirmation dialog with cascade delete count. Handles raw output toggle
* for assistant messages.
*/
export { default as ChatMessageActions } from './ChatMessages/ChatMessageActions.svelte';
/**
* Navigation controls for message siblings (conversation branches). Displays
* prev/next arrows with current position counter (e.g., "2/5"). Enables users
* to navigate between alternative versions of a message created by editing
* or regenerating. Uses `conversationsStore.navigateToSibling()` for navigation.
*/
export { default as ChatMessageBranchingControls } from './ChatMessages/ChatMessageBranchingControls.svelte';
/**
* Statistics display for assistant messages. Shows token counts (prompt/completion),
* generation timing, tokens per second, and model name (when enabled in settings).
* Data sourced from message.timings stored during generation.
*/
export { default as ChatMessageStatistics } from './ChatMessages/ChatMessageStatistics.svelte';
/**
* System message display component. Renders system messages with distinct styling.
* Visibility controlled by `showSystemMessage` config setting.
*/
export { default as ChatMessageSystem } from './ChatMessages/ChatMessageSystem.svelte';
/**
* User message display component. Renders user messages with right-aligned bubble styling.
* Shows message content, attachments via ChatAttachmentsList.
* Supports inline editing mode with ChatMessageEditForm integration.
*/
export { default as ChatMessageUser } from './ChatMessages/ChatMessageUser.svelte';
/**
* Assistant message display component. Renders assistant responses with left-aligned styling.
* Supports both plain markdown content (via MarkdownContent) and agentic content with tool calls
* (via ChatMessageAgenticContent). Shows model info badge, statistics, and action buttons.
* Handles streaming state with real-time content updates.
*/
export { default as ChatMessageAssistant } from './ChatMessages/ChatMessageAssistant.svelte';
/**
* Inline message editing form. Provides textarea for editing message content with
* attachment management. Shows save/cancel buttons and optional "Save only" button
* for editing without regenerating responses. Used within ChatMessage components
* when user enters edit mode.
*/
export { default as ChatMessageEditForm } from './ChatMessages/ChatMessageEditForm.svelte';
/**
*
* SCREEN
*
* Top-level chat interface components. ChatScreen is the main container that
* orchestrates all chat functionality. It integrates with multiple stores:
* - `chatStore` for message operations and generation control
* - `conversationsStore` for conversation management
* - `serverStore` for server connection state
* - `modelsStore` for model capabilities (vision, audio modalities)
*
* The screen handles the complete chat lifecycle from empty state to active
* conversation with streaming responses.
*
*/
/**
* **ChatScreen** - Main chat interface container
*
* Top-level component that orchestrates the entire chat interface. Manages
* messages display, input form, file handling, auto-scroll, error dialogs,
* and server state. Used as the main content area in chat routes.
*
* **Architecture:**
* - Composes ChatMessages, ChatScreenForm, ChatScreenHeader, and dialogs
* - Manages auto-scroll via `createAutoScrollController()` hook
* - Handles file upload pipeline (validation processing state update)
* - Integrates with serverStore for loading/error/warning states
* - Tracks active model for modality validation (vision, audio)
*
* **File Upload Pipeline:**
* 1. Files received via drag-drop, paste, or file picker
* 2. Validated against supported types (`isFileTypeSupported()`)
* 3. Filtered by model modalities (`filterFilesByModalities()`)
* 4. Empty files detected and reported via DialogEmptyFileAlert
* 5. Valid files processed to ChatUploadedFile[] format
* 6. Unsupported files shown in error dialog with reasons
*
* **State Management:**
* - `isEmpty`: Shows centered welcome UI when no conversation active
* - `isCurrentConversationLoading`: Tracks generation state for current chat
* - `activeModelId`: Determines available modalities for file validation
* - `uploadedFiles`: Pending file attachments for next message
*
* **Features:**
* - Messages display with smart auto-scroll (pauses on user scroll up)
* - File drag-drop with visual overlay indicator
* - File validation with detailed error messages
* - Error dialog management (chat errors, model unavailable)
* - Server loading/error/warning states with appropriate UI
* - Conversation deletion with confirmation dialog
* - Processing info display (tokens/sec, timing) during generation
* - Keyboard shortcuts (Ctrl+Shift+Backspace to delete conversation)
*
* @example
* ```svelte
* <!-- In chat route -->
* <ChatScreen showCenteredEmpty={true} />
*
* <!-- In conversation route -->
* <ChatScreen showCenteredEmpty={false} />
* ```
*/
export { default as ChatScreen } from './ChatScreen/ChatScreen.svelte';
/**
* Visual overlay displayed when user drags files over the chat screen.
* Shows drop zone indicator to guide users where to release files.
* Integrated with ChatScreen's drag-drop file upload handling.
*/
export { default as ChatScreenDragOverlay } from './ChatScreen/ChatScreenDragOverlay.svelte';
/**
* Chat form wrapper within ChatScreen. Positions the ChatForm component at the
* bottom of the screen with proper padding and max-width constraints. Handles
* the visual container styling for the input area.
*/
export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
/**
* Header bar for chat screen. Displays conversation title (or "New Chat"),
* model selector (in router mode), and action buttons (delete conversation).
* Sticky positioned at the top of the chat area.
*/
export { default as ChatScreenHeader } from './ChatScreen/ChatScreenHeader.svelte';
/**
* Processing info display during generation. Shows real-time statistics:
* tokens per second, prompt/completion token counts, and elapsed time.
* Data sourced from slotsService polling during active generation.
* Only visible when `isCurrentConversationLoading` is true.
*/
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
/**
*
* SETTINGS
*
* Application settings components. Settings are persisted to localStorage via
* the config store and synchronized with server `/props` endpoint for sampling
* parameters. The settings panel uses a tabbed interface with mobile-responsive
* horizontal scrolling tabs.
*
* **Parameter Sync System:**
* Sampling parameters (temperature, top_p, etc.) can come from three sources:
* 1. **Server Props**: Default values from `/props` endpoint
* 2. **User Custom**: Values explicitly set by user (overrides server)
* 3. **App Default**: Fallback when server props unavailable
*
* The `ChatSettingsParameterSourceIndicator` badge shows which source is active.
*
*/
/**
* **ChatSettings** - Application settings panel
*
* Comprehensive settings interface with categorized sections. Manages all
* user preferences and sampling parameters. Integrates with config store
* for persistence and ParameterSyncService for server synchronization.
*
* **Architecture:**
* - Uses tabbed navigation with category sections
* - Maintains local form state, commits on save
* - Tracks user overrides vs server defaults for sampling params
* - Exposes reset() method for dialog close without save
*
* **Categories:**
* - **General**: API key, system message, show system messages toggle
* - **Display**: Theme selection, message actions visibility, model info badge
* - **Sampling**: Temperature, top_p, top_k, min_p, repeat_penalty, etc.
* - **Penalties**: Frequency penalty, presence penalty, repeat last N
* - **Import/Export**: Conversation backup and restore
* - **Developer**: Debug options, disable auto-scroll
*
* **Parameter Sync:**
* - Fetches defaults from server `/props` endpoint
* - Shows source indicator badge (Custom/Server Props/Default)
* - Real-time badge updates as user types
* - Tracks which parameters user has explicitly overridden
*
* **Features:**
* - Mobile-responsive layout with horizontal scrolling tabs
* - Form validation with error messages
* - Secure API key storage (masked input)
* - Import/export conversations as JSON
* - Reset to defaults option per parameter
*
* **Exported API:**
* - `reset()` - Reset form fields to currently saved values (for cancel action)
*
* @example
* ```svelte
* <ChatSettings
* bind:this={settingsRef}
* onSave={() => dialogOpen = false}
* onCancel={() => { settingsRef.reset(); dialogOpen = false; }}
* />
* ```
*/
export { default as ChatSettings } from './ChatSettings/ChatSettings.svelte';
/**
* Footer with save/cancel buttons for settings panel. Positioned at bottom
* of settings dialog. Save button commits form state to config store,
* cancel button triggers reset and close.
*/
export { default as ChatSettingsFooter } from './ChatSettings/ChatSettingsFooter.svelte';
/**
* Form fields renderer for individual settings. Generates appropriate input
* components based on field type (text, number, select, checkbox, textarea).
* Handles validation, help text display, and parameter source indicators.
*/
export { default as ChatSettingsFields } from './ChatSettings/ChatSettingsFields.svelte';
/**
* Import/export tab content for conversation data management. Provides buttons
* to export all conversations as JSON file and import from JSON file.
* Handles file download/upload and data validation.
*/
export { default as ChatSettingsImportExportTab } from './ChatSettings/ChatSettingsImportExportTab.svelte';
/**
* Badge indicating parameter source for sampling settings. Shows one of:
* - **Custom**: User has explicitly set this value (orange badge)
* - **Server Props**: Using default from `/props` endpoint (blue badge)
* - **Default**: Using app default, server props unavailable (gray badge)
* Updates in real-time as user types to show immediate feedback.
*/
export { default as ChatSettingsParameterSourceIndicator } from './ChatSettings/ChatSettingsParameterSourceIndicator.svelte';
/**
*
* SIDEBAR
*
* The sidebar integrates with ShadCN's sidebar component system
* for consistent styling and mobile responsiveness.
* Conversations are loaded from conversationsStore and displayed in reverse
* chronological order (most recent first).
*
*/
/**
* **ChatSidebar** - Chat Sidebar with actions menu and conversation list
*
* Collapsible sidebar displaying conversation history with search and
* management actions. Integrates with ShadCN sidebar component for
* consistent styling and mobile responsiveness.
*
* **Architecture:**
* - Uses ShadCN Sidebar.* components for structure
* - Fetches conversations from conversationsStore
* - Manages search state and filtered results locally
* - Handles conversation CRUD operations via conversationsStore
*
* **Navigation:**
* - Click conversation to navigate to `/chat/[id]`
* - New chat button navigates to `/` (root)
* - Active conversation highlighted based on route params
*
* **Conversation Management:**
* - Right-click or menu button for context menu
* - Rename: Opens inline edit dialog
* - Delete: Shows confirmation with conversation preview
* - Delete All: Removes all conversations with confirmation
*
* **Features:**
* - Search/filter conversations by title
* - Conversation list with message previews (first message truncated)
* - Active conversation highlighting
* - Mobile-responsive collapse/expand via ShadCN sidebar
* - New chat button in header
* - Settings button opens DialogChatSettings
*
* **Exported API:**
* - `handleMobileSidebarItemClick()` - Close sidebar on mobile after item selection
* - `activateSearchMode()` - Focus search input programmatically
* - `editActiveConversation()` - Open rename dialog for current conversation
*
* @example
* ```svelte
* <ChatSidebar bind:this={sidebarRef} />
* ```
*/
export { default as ChatSidebar } from './ChatSidebar/ChatSidebar.svelte';
/**
* Action buttons for sidebar header. Contains new chat button, settings button,
* and delete all conversations button. Manages dialog states for settings and
* delete confirmation.
*/
export { default as ChatSidebarActions } from './ChatSidebar/ChatSidebarActions.svelte';
/**
* Single conversation item in sidebar. Displays conversation title (truncated),
* last message preview, and timestamp. Shows context menu on right-click with
* rename and delete options. Highlights when active (matches current route).
* Handles click to navigate and keyboard accessibility.
*/
export { default as ChatSidebarConversationItem } from './ChatSidebar/ChatSidebarConversationItem.svelte';
/**
* Search input for filtering conversations in sidebar. Filters conversation
* list by title as user types. Shows clear button when query is not empty.
* Integrated into sidebar header with proper styling.
*/
export { default as ChatSidebarSearch } from './ChatSidebar/ChatSidebarSearch.svelte';
@@ -616,7 +616,7 @@
code={incompleteCodeBlock.code}
language={incompleteCodeBlock.language || 'text'}
disabled={true}
onPreview={(code: string, lang: string) => {
onPreview={(code, lang) => {
previewCode = code;
previewLanguage = lang;
previewDialogOpen = true;
@@ -0,0 +1,416 @@
/**
*
* DIALOGS
*
* Modal dialog components for the chat application.
*
* All dialogs use ShadCN Dialog or AlertDialog components for consistent
* styling, accessibility, and animation. They integrate with application
* stores for state management and data access.
*
*/
/**
*
* SETTINGS DIALOGS
*
* Dialogs for application and server configuration.
*
*/
/**
* **DialogChatSettings** - Settings dialog wrapper
*
* Modal dialog containing ChatSettings component with proper
* open/close state management and automatic form reset on open.
*
* **Architecture:**
* - Wraps ChatSettings component in ShadCN Dialog
* - Manages open/close state via bindable `open` prop
* - Resets form state when dialog opens to discard unsaved changes
*
* @example
* ```svelte
* <DialogChatSettings bind:open={showSettings} />
* ```
*/
export { default as DialogChatSettings } from './DialogChatSettings.svelte';
/**
*
* CONFIRMATION DIALOGS
*
* Dialogs for user action confirmations. Use AlertDialog for blocking
* confirmations that require explicit user decision before proceeding.
*
*/
/**
* **DialogConfirmation** - Generic confirmation dialog
*
* Reusable confirmation dialog with customizable title, description,
* and action buttons. Supports destructive action styling and custom icons.
* Used for delete confirmations, irreversible actions, and important decisions.
*
* **Architecture:**
* - Uses ShadCN AlertDialog
* - Supports variant styling (default, destructive)
* - Customizable button labels and callbacks
*
* **Features:**
* - Customizable title and description text
* - Destructive variant with red styling for dangerous actions
* - Custom icon support in header
* - Cancel and confirm button callbacks
* - Keyboard accessible (Escape to cancel, Enter to confirm)
*
* @example
* ```svelte
* <DialogConfirmation
* bind:open={showDelete}
* title="Delete conversation?"
* description="This action cannot be undone."
* variant="destructive"
* onConfirm={handleDelete}
* onCancel={() => showDelete = false}
* />
* ```
*/
export { default as DialogConfirmation } from './DialogConfirmation.svelte';
/**
* **DialogConversationTitleUpdate** - Conversation rename confirmation
*
* Confirmation dialog shown when editing the first user message in a conversation.
* Asks user whether to update the conversation title to match the new message content.
*
* **Architecture:**
* - Uses ShadCN AlertDialog
* - Shows current vs proposed title comparison
* - Triggered by ChatMessages when first message is edited
*
* **Features:**
* - Side-by-side display of current and new title
* - "Keep Current Title" and "Update Title" action buttons
* - Styled title previews in muted background boxes
*
* @example
* ```svelte
* <DialogConversationTitleUpdate
* bind:open={showTitleUpdate}
* currentTitle={conversation.name}
* newTitle={truncatedMessageContent}
* onConfirm={updateTitle}
* onCancel={() => showTitleUpdate = false}
* />
* ```
*/
export { default as DialogConversationTitleUpdate } from './DialogConversationTitleUpdate.svelte';
/**
*
* CONTENT PREVIEW DIALOGS
*
* Dialogs for previewing and displaying content in full-screen or modal views.
*
*/
/**
* **DialogCodePreview** - Full-screen code/HTML preview
*
* Full-screen dialog for previewing HTML or code in an isolated iframe.
* Used by MarkdownContent component for previewing rendered HTML blocks
* from code blocks in chat messages.
*
* **Architecture:**
* - Uses ShadCN Dialog with full viewport layout
* - Sandboxed iframe execution (allow-scripts only)
* - Clears content when closed for security
*
* **Features:**
* - Full viewport iframe preview
* - Sandboxed execution environment
* - Close button with mix-blend-difference for visibility over any content
* - Automatic content cleanup on close
* - Supports HTML preview with proper isolation
*
* @example
* ```svelte
* <DialogCodePreview
* bind:open={showPreview}
* code={htmlContent}
* language="html"
* />
* ```
*/
export { default as DialogCodePreview } from './DialogCodePreview.svelte';
/**
*
* ATTACHMENT DIALOGS
*
* Dialogs for viewing and managing file attachments. Support both
* uploaded files (pending) and stored attachments (in messages).
*
*/
/**
* **DialogChatAttachmentPreview** - Full-size attachment preview
*
* Modal dialog for viewing file attachments at full size. Supports different
* file types with appropriate preview modes: images, text files, PDFs, and audio.
*
* **Architecture:**
* - Wraps ChatAttachmentPreview component in ShadCN Dialog
* - Accepts either uploaded file or stored attachment as data source
* - Resets preview state when dialog opens
*
* **Features:**
* - Full-size image display with proper scaling
* - Text file content with syntax highlighting
* - PDF preview with text/image view toggle
* - Audio file placeholder with download option
* - File name and size display in header
* - Download button for all file types
* - Vision modality check for image attachments
*
* @example
* ```svelte
* <!-- Preview uploaded file -->
* <DialogChatAttachmentPreview
* bind:open={showPreview}
* uploadedFile={selectedFile}
* activeModelId={currentModel}
* />
*
* <!-- Preview stored attachment -->
* <DialogChatAttachmentPreview
* bind:open={showPreview}
* attachment={selectedAttachment}
* />
* ```
*/
export { default as DialogChatAttachmentPreview } from './DialogChatAttachmentPreview.svelte';
/**
* **DialogChatAttachmentsViewAll** - Grid view of all attachments
*
* Dialog showing all attachments in a responsive grid layout. Triggered by
* "+X more" button in ChatAttachmentsList when there are too many attachments
* to display inline.
*
* **Architecture:**
* - Wraps ChatAttachmentsViewAll component in ShadCN Dialog
* - Supports both readonly (message view) and editable (form) modes
* - Displays total attachment count in header
*
* **Features:**
* - Responsive grid layout for all attachments
* - Thumbnail previews with click-to-expand
* - Remove button in editable mode
* - Configurable thumbnail dimensions
* - Vision modality validation for images
*
* @example
* ```svelte
* <DialogChatAttachmentsViewAll
* bind:open={showAllAttachments}
* attachments={message.extra}
* readonly
* />
* ```
*/
export { default as DialogChatAttachmentsViewAll } from './DialogChatAttachmentsViewAll.svelte';
/**
*
* ERROR & ALERT DIALOGS
*
* Dialogs for displaying errors, warnings, and alerts to users.
* Provide context about what went wrong and recovery options.
*
*/
/**
* **DialogChatError** - Chat/generation error display
*
* Alert dialog for displaying chat and generation errors with context
* information. Supports different error types with appropriate styling
* and messaging.
*
* **Architecture:**
* - Uses ShadCN AlertDialog for modal display
* - Differentiates between timeout and server errors
* - Shows context info when available (token counts)
*
* **Error Types:**
* - **timeout**: TCP timeout with timer icon, red destructive styling
* - **server**: Server error with warning icon, amber warning styling
*
* **Features:**
* - Type-specific icons (TimerOff for timeout, AlertTriangle for server)
* - Error message display in styled badge
* - Context info showing prompt tokens and context size
* - Close button to dismiss
*
* @example
* ```svelte
* <DialogChatError
* bind:open={showError}
* type="server"
* message={errorMessage}
* contextInfo={{ n_prompt_tokens: 1024, n_ctx: 4096 }}
* />
* ```
*/
export { default as DialogChatError } from './DialogChatError.svelte';
/**
* **DialogEmptyFileAlert** - Empty file upload warning
*
* Alert dialog shown when user attempts to upload empty files. Lists the
* empty files that were detected and removed from attachments, with
* explanation of why empty files cannot be processed.
*
* **Architecture:**
* - Uses ShadCN AlertDialog for modal display
* - Receives list of empty file names from ChatScreen
* - Triggered during file upload validation
*
* **Features:**
* - FileX icon indicating file error
* - List of empty file names in monospace font
* - Explanation of what happened and why
* - Single "Got it" dismiss button
*
* @example
* ```svelte
* <DialogEmptyFileAlert
* bind:open={showEmptyAlert}
* emptyFiles={['empty.txt', 'blank.md']}
* />
* ```
*/
export { default as DialogEmptyFileAlert } from './DialogEmptyFileAlert.svelte';
/**
* **DialogModelNotAvailable** - Model unavailable error
*
* Alert dialog shown when the requested model (from URL params or selection)
* is not available on the server. Displays the requested model name and
* offers selection from available models.
*
* **Architecture:**
* - Uses ShadCN AlertDialog for modal display
* - Integrates with SvelteKit navigation for model switching
* - Receives available models list from modelsStore
*
* **Features:**
* - Warning icon with amber styling
* - Requested model name display in styled badge
* - Scrollable list of available models
* - Click model to navigate with updated URL params
* - Cancel button to dismiss without selection
*
* @example
* ```svelte
* <DialogModelNotAvailable
* bind:open={showModelError}
* modelName={requestedModel}
* availableModels={modelsList}
* />
* ```
*/
export { default as DialogModelNotAvailable } from './DialogModelNotAvailable.svelte';
/**
*
* DATA MANAGEMENT DIALOGS
*
* Dialogs for managing conversation data, including import/export
* and selection operations.
*
*/
/**
* **DialogConversationSelection** - Conversation picker for import/export
*
* Dialog for selecting conversations during import or export operations.
* Displays list of conversations with checkboxes for multi-selection.
* Used by ChatSettingsImportExportTab for data management.
*
* **Architecture:**
* - Wraps ConversationSelection component in ShadCN Dialog
* - Supports export mode (select from local) and import mode (select from file)
* - Resets selection state when dialog opens
* - High z-index to appear above settings dialog
*
* **Features:**
* - Multi-select with checkboxes
* - Conversation title and message count display
* - Select all / deselect all controls
* - Mode-specific descriptions (export vs import)
* - Cancel and confirm callbacks with selected conversations
*
* @example
* ```svelte
* <DialogConversationSelection
* bind:open={showExportSelection}
* conversations={allConversations}
* messageCountMap={messageCounts}
* mode="export"
* onConfirm={handleExport}
* onCancel={() => showExportSelection = false}
* />
* ```
*/
export { default as DialogConversationSelection } from './DialogConversationSelection.svelte';
/**
*
* MODEL INFORMATION DIALOGS
*
* Dialogs for displaying model and server information.
*
*/
/**
* **DialogModelInformation** - Model details display
*
* Dialog showing comprehensive information about the currently loaded model
* and server configuration. Displays model metadata, capabilities, and
* server settings in a structured table format.
*
* **Architecture:**
* - Uses ShadCN Dialog with wide layout for table display
* - Fetches data from serverStore (props) and modelsStore (metadata)
* - Auto-fetches models when dialog opens if not loaded
*
* **Information Displayed:**
* - **Model**: Name with copy button
* - **File Path**: Full path to model file with copy button
* - **Context Size**: Current context window size
* - **Training Context**: Original training context (if available)
* - **Model Size**: File size in human-readable format
* - **Parameters**: Parameter count (e.g., "7B", "70B")
* - **Embedding Size**: Embedding dimension
* - **Vocabulary Size**: Token vocabulary size
* - **Vocabulary Type**: Tokenizer type (BPE, etc.)
* - **Parallel Slots**: Number of concurrent request slots
* - **Modalities**: Supported input types (text, vision, audio)
* - **Build Info**: Server build information
* - **Chat Template**: Full Jinja template in scrollable code block
*
* **Features:**
* - Copy buttons for model name and path
* - Modality badges with icons
* - Responsive table layout with container queries
* - Loading state while fetching model info
* - Scrollable chat template display
*
* @example
* ```svelte
* <DialogModelInformation bind:open={showModelInfo} />
* ```
*/
export { default as DialogModelInformation } from './DialogModelInformation.svelte';
@@ -1,68 +1,10 @@
export * from './actions';
export * from './badges';
export * from './chat';
export * from './content';
export * from './dialogs';
export * from './forms';
export * from './misc';
export * from './models';
export * from './navigation';
export * from './server';
// Chat
export { default as ChatAttachmentPreview } from './chat/ChatAttachments/ChatAttachmentPreview.svelte';
export { default as ChatAttachmentThumbnailFile } from './chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte';
export { default as ChatAttachmentThumbnailImage } from './chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte';
export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttachmentsList.svelte';
export { default as ChatAttachmentsViewAll } from './chat/ChatAttachments/ChatAttachmentsViewAll.svelte';
export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte';
export { default as ChatFormActionAttachmentsDropdown } from './chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsDropdown.svelte';
export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte';
export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActions/ChatFormActionRecord.svelte';
export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions/ChatFormActions.svelte';
export { default as ChatFormActionSubmit } from './chat/ChatForm/ChatFormActions/ChatFormActionSubmit.svelte';
export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte';
export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte';
export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte';
export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte';
export { default as ChatMessageActions } from './chat/ChatMessages/ChatMessageActions.svelte';
export { default as ChatMessageAssistant } from './chat/ChatMessages/ChatMessageAssistant.svelte';
export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
export { default as ChatMessageEditForm } from './chat/ChatMessages/ChatMessageEditForm.svelte';
export { default as ChatMessageStatistics } from './chat/ChatMessages/ChatMessageStatistics.svelte';
export { default as ChatMessageSystem } from './chat/ChatMessages/ChatMessageSystem.svelte';
export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte';
export { default as ChatMessageUser } from './chat/ChatMessages/ChatMessageUser.svelte';
export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte';
export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
export { default as ChatScreenDragOverlay } from './chat/ChatScreen/ChatScreenDragOverlay.svelte';
export { default as ChatScreenForm } from './chat/ChatScreen/ChatScreenForm.svelte';
export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
export { default as ChatScreenProcessingInfo } from './chat/ChatScreen/ChatScreenProcessingInfo.svelte';
export { default as ChatSettings } from './chat/ChatSettings/ChatSettings.svelte';
export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte';
export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte';
export { default as ChatSettingsImportExportTab } from './chat/ChatSettings/ChatSettingsImportExportTab.svelte';
export { default as ChatSettingsParameterSourceIndicator } from './chat/ChatSettings/ChatSettingsParameterSourceIndicator.svelte';
export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
export { default as ChatSidebarActions } from './chat/ChatSidebar/ChatSidebarActions.svelte';
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
// Dialogs
export { default as DialogChatAttachmentPreview } from './dialogs/DialogChatAttachmentPreview.svelte';
export { default as DialogChatAttachmentsViewAll } from './dialogs/DialogChatAttachmentsViewAll.svelte';
export { default as DialogChatError } from './dialogs/DialogChatError.svelte';
export { default as DialogChatSettings } from './dialogs/DialogChatSettings.svelte';
export { default as DialogCodePreview } from './dialogs/DialogCodePreview.svelte';
export { default as DialogConfirmation } from './dialogs/DialogConfirmation.svelte';
export { default as DialogConversationSelection } from './dialogs/DialogConversationSelection.svelte';
export { default as DialogConversationTitleUpdate } from './dialogs/DialogConversationTitleUpdate.svelte';
export { default as DialogEmptyFileAlert } from './dialogs/DialogEmptyFileAlert.svelte';
export { default as DialogModelInformation } from './dialogs/DialogModelInformation.svelte';
export { default as DialogModelNotAvailable } from './dialogs/DialogModelNotAvailable.svelte';
// Compatibility aliases
export { default as ActionButton } from './actions/ActionIcon.svelte';
export { default as ActionDropdown } from './navigation/DropdownMenuActions.svelte';
export { default as CopyToClipboardIcon } from './actions/ActionIconCopyToClipboard.svelte';
export { default as RemoveButton } from './actions/ActionIconRemove.svelte';
@@ -31,8 +31,6 @@
forceForegroundText?: boolean;
/** When true, user's global selection takes priority over currentModel (for form selector) */
useGlobalSelection?: boolean;
/** Optional compatibility prop for context-aware selectors. */
upToMessageId?: string;
}
let {
@@ -41,9 +39,7 @@
onModelChange,
disabled = false,
forceForegroundText = false,
useGlobalSelection = false,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
upToMessageId: _upToMessageId = undefined
useGlobalSelection = false
}: Props = $props();
let options = $derived(modelOptions());
@@ -0,0 +1,37 @@
// Agentic tool call tag markers
export const AGENTIC_TAGS = {
TOOL_CALL_START: '<<<AGENTIC_TOOL_CALL_START>>>',
TOOL_CALL_END: '<<<AGENTIC_TOOL_CALL_END>>>',
TOOL_NAME_PREFIX: '<<<TOOL_NAME:',
TOOL_ARGS_START: '<<<TOOL_ARGS_START>>>',
TOOL_ARGS_END: '<<<TOOL_ARGS_END>>>',
TAG_SUFFIX: '>>>'
} as const;
export const REASONING_TAGS = {
START: '<<<reasoning_content_start>>>',
END: '<<<reasoning_content_end>>>'
} as const;
// Regex patterns for parsing agentic content
export const AGENTIC_REGEX = {
// Matches completed tool calls (with END marker)
COMPLETED_TOOL_CALL:
/<<<AGENTIC_TOOL_CALL_START>>>\n<<<TOOL_NAME:(.+?)>>>\n<<<TOOL_ARGS_START>>>([\s\S]*?)<<<TOOL_ARGS_END>>>([\s\S]*?)<<<AGENTIC_TOOL_CALL_END>>>/g,
// Matches pending tool call (has NAME and ARGS but no END)
PENDING_TOOL_CALL:
/<<<AGENTIC_TOOL_CALL_START>>>\n<<<TOOL_NAME:(.+?)>>>\n<<<TOOL_ARGS_START>>>([\s\S]*?)<<<TOOL_ARGS_END>>>([\s\S]*)$/,
// Matches partial tool call (has START and NAME, ARGS still streaming)
PARTIAL_WITH_NAME:
/<<<AGENTIC_TOOL_CALL_START>>>\n<<<TOOL_NAME:(.+?)>>>\n<<<TOOL_ARGS_START>>>([\s\S]*)$/,
// Matches early tool call (just START marker)
EARLY_MATCH: /<<<AGENTIC_TOOL_CALL_START>>>([\s\S]*)$/,
// Matches partial marker at end of content
PARTIAL_MARKER: /<<<[A-Za-z_]*$/,
// Matches reasoning content blocks (including tags)
REASONING_BLOCK: /<<<reasoning_content_start>>>[\s\S]*?<<<reasoning_content_end>>>/g,
// Matches an opening reasoning tag and any remaining content (unterminated)
REASONING_OPEN: /<<<reasoning_content_start>>>[\s\S]*$/,
// Matches tool name inside content
TOOL_NAME_EXTRACT: /<<<TOOL_NAME:([^>]+)>>>/
} as const;
@@ -0,0 +1,2 @@
export const ATTACHMENT_LABEL_FILE = 'File';
export const ATTACHMENT_LABEL_PDF_FILE = 'PDF File';
+15 -6
View File
@@ -3,31 +3,40 @@
*/
/**
* Default TTL (Time-To-Live) for cache entries in milliseconds.
* Default TTL (Time-To-Live) for cache entries in milliseconds
* @default 5 minutes
*/
export const DEFAULT_CACHE_TTL_MS = 5 * 60 * 1000;
/**
* Default maximum number of entries in a cache.
* Default maximum number of entries in a cache
* @default 100
*/
export const DEFAULT_CACHE_MAX_ENTRIES = 100;
/**
* TTL for model props cache in milliseconds.
* TTL for model props cache in milliseconds
* Props don't change frequently, so we can cache them longer
* @default 10 minutes
*/
export const MODEL_PROPS_CACHE_TTL_MS = 10 * 60 * 1000;
/**
* Maximum number of model props to cache.
* Maximum number of model props to cache
* @default 50
*/
export const MODEL_PROPS_CACHE_MAX_ENTRIES = 50;
/**
* Maximum number of inactive conversation states to keep in memory.
* Maximum number of inactive conversation states to keep in memory
* States for conversations beyond this limit will be cleaned up
* @default 10
*/
export const MAX_INACTIVE_CONVERSATION_STATES = 10;
/**
* Maximum age (in ms) for inactive conversation states before cleanup.
* Maximum age (in ms) for inactive conversation states before cleanup
* States older than this will be removed during cleanup
* @default 30 minutes
*/
export const INACTIVE_CONVERSATION_STATE_MAX_AGE_MS = 30 * 60 * 1000;
@@ -1 +0,0 @@
export const DEFAULT_CONTEXT = 4096;
@@ -1 +0,0 @@
export { INPUT_CLASSES } from './css-classes';
@@ -1,12 +1,14 @@
import { ColorMode } from '$lib/enums/ui';
import { Monitor, Moon, Sun } from '@lucide/svelte';
export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> = {
// Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value.
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
apiKey: '',
systemMessage: '',
showSystemMessage: true,
theme: 'system',
theme: ColorMode.SYSTEM,
showThoughtInProgress: false,
showToolCalls: false,
disableReasoningParsing: false,
showRawOutputSwitch: false,
keepStatsVisible: false,
@@ -91,8 +93,6 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
max_tokens: 'The maximum number of token per output. Use -1 for infinite (no limit).',
custom: 'Custom JSON parameters to send to the API. Must be valid JSON format.',
showThoughtInProgress: 'Expand thought process by default when generating messages.',
showToolCalls:
'Display tool call labels and payloads from Harmony-compatible delta.tool_calls data below assistant messages.',
disableReasoningParsing:
'Send reasoning_format=none to prevent server-side extraction of reasoning tokens into separate field',
showRawOutputSwitch:
@@ -118,3 +118,9 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
enableContinueGeneration:
'Enable "Continue" button for assistant messages. Currently works only with non-reasoning models.'
};
export const SETTINGS_COLOR_MODES_CONFIG = [
{ value: ColorMode.SYSTEM, label: 'System', icon: Monitor },
{ value: ColorMode.LIGHT, label: 'Light', icon: Sun },
{ value: ColorMode.DARK, label: 'Dark', icon: Moon }
];
@@ -0,0 +1,52 @@
/**
* Settings key constants for ChatSettings configuration.
*
* These keys correspond to properties in SettingsConfigType and are used
* in settings field configurations to ensure consistency.
*/
export const SETTINGS_KEYS = {
// General
THEME: 'theme',
API_KEY: 'apiKey',
SYSTEM_MESSAGE: 'systemMessage',
PASTE_LONG_TEXT_TO_FILE_LEN: 'pasteLongTextToFileLen',
COPY_TEXT_ATTACHMENTS_AS_PLAIN_TEXT: 'copyTextAttachmentsAsPlainText',
ENABLE_CONTINUE_GENERATION: 'enableContinueGeneration',
PDF_AS_IMAGE: 'pdfAsImage',
ASK_FOR_TITLE_CONFIRMATION: 'askForTitleConfirmation',
// Display
SHOW_MESSAGE_STATS: 'showMessageStats',
SHOW_THOUGHT_IN_PROGRESS: 'showThoughtInProgress',
KEEP_STATS_VISIBLE: 'keepStatsVisible',
AUTO_MIC_ON_EMPTY: 'autoMicOnEmpty',
RENDER_USER_CONTENT_AS_MARKDOWN: 'renderUserContentAsMarkdown',
DISABLE_AUTO_SCROLL: 'disableAutoScroll',
ALWAYS_SHOW_SIDEBAR_ON_DESKTOP: 'alwaysShowSidebarOnDesktop',
AUTO_SHOW_SIDEBAR_ON_NEW_CHAT: 'autoShowSidebarOnNewChat',
// Sampling
TEMPERATURE: 'temperature',
DYNATEMP_RANGE: 'dynatemp_range',
DYNATEMP_EXPONENT: 'dynatemp_exponent',
TOP_K: 'top_k',
TOP_P: 'top_p',
MIN_P: 'min_p',
XTC_PROBABILITY: 'xtc_probability',
XTC_THRESHOLD: 'xtc_threshold',
TYP_P: 'typ_p',
MAX_TOKENS: 'max_tokens',
SAMPLERS: 'samplers',
BACKEND_SAMPLING: 'backend_sampling',
// Penalties
REPEAT_LAST_N: 'repeat_last_n',
REPEAT_PENALTY: 'repeat_penalty',
PRESENCE_PENALTY: 'presence_penalty',
FREQUENCY_PENALTY: 'frequency_penalty',
DRY_MULTIPLIER: 'dry_multiplier',
DRY_BASE: 'dry_base',
DRY_ALLOWED_LENGTH: 'dry_allowed_length',
DRY_PENALTY_LAST_N: 'dry_penalty_last_n',
// Developer
DISABLE_REASONING_PARSING: 'disableReasoningParsing',
SHOW_RAW_OUTPUT_SWITCH: 'showRawOutputSwitch',
CUSTOM: 'custom'
} as const;
+21 -1
View File
@@ -136,9 +136,28 @@ export enum FileExtensionText {
CS = '.cs'
}
// MIME type prefixes and includes for content detection
export enum MimeTypePrefix {
IMAGE = 'image/',
TEXT = 'text'
}
export enum MimeTypeIncludes {
JSON = 'json',
JAVASCRIPT = 'javascript',
TYPESCRIPT = 'typescript'
}
// URI patterns for content detection
export enum UriPattern {
DATABASE_KEYWORD = 'database',
DATABASE_SCHEME = 'db://'
}
// MIME type enums
export enum MimeTypeApplication {
PDF = 'application/pdf'
PDF = 'application/pdf',
OCTET_STREAM = 'application/octet-stream'
}
export enum MimeTypeAudio {
@@ -152,6 +171,7 @@ export enum MimeTypeAudio {
export enum MimeTypeImage {
JPEG = 'image/jpeg',
JPG = 'image/jpg',
PNG = 'image/png',
GIF = 'image/gif',
WEBP = 'image/webp',
+8 -5
View File
@@ -2,11 +2,11 @@ export { AttachmentType } from './attachment';
export {
ChatMessageStatsView,
ReasoningFormat,
ContentPartType,
ErrorDialogType,
MessageRole,
MessageType,
ContentPartType,
ErrorDialogType
ReasoningFormat
} from './chat';
export {
@@ -19,6 +19,9 @@ export {
FileExtensionAudio,
FileExtensionPdf,
FileExtensionText,
MimeTypePrefix,
MimeTypeIncludes,
UriPattern,
MimeTypeApplication,
MimeTypeAudio,
MimeTypeImage,
@@ -31,6 +34,6 @@ export { ServerRole, ServerModelStatus } from './server';
export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings';
export { KeyboardKey } from './keyboard';
export { ColorMode, UrlPrefix } from './ui';
export { UrlPrefix } from './ui';
export { KeyboardKey } from './keyboard';
+7 -1
View File
@@ -1,5 +1,11 @@
export enum ColorMode {
LIGHT = 'light',
DARK = 'dark',
SYSTEM = 'system'
}
/**
* URL prefixes for protocol detection.
* URL prefixes for protocol detection
*/
export enum UrlPrefix {
DATA = 'data:',
@@ -1,104 +0,0 @@
import { modelsStore } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import { toast } from 'svelte-sonner';
import type { ModelModalities } from '$lib/types';
interface UseModelChangeValidationOptions {
/**
* Function to get required modalities for validation.
*/
getRequiredModalities: () => ModelModalities;
/**
* Optional callback to execute after successful validation.
*/
onSuccess?: (modelName: string) => void;
/**
* Optional callback for rollback on validation failure.
*/
onValidationFailure?: (previousModelId: string | null) => Promise<void>;
}
export function useModelChangeValidation(options: UseModelChangeValidationOptions) {
const { getRequiredModalities, onSuccess, onValidationFailure } = options;
let previousSelectedModelId: string | null = null;
const isRouter = $derived(isRouterMode());
async function handleModelChange(modelId: string, modelName: string): Promise<boolean> {
try {
if (onValidationFailure) {
previousSelectedModelId = modelsStore.selectedModelId;
}
let hasLoadedModel = false;
const isModelLoadedBefore = modelsStore.isModelLoaded(modelName);
if (isRouter && !isModelLoadedBefore) {
try {
await modelsStore.loadModel(modelName);
hasLoadedModel = true;
} catch {
toast.error(`Failed to load model "${modelName}"`);
return false;
}
}
const props = await modelsStore.fetchModelProps(modelName);
if (props?.modalities) {
const requiredModalities = getRequiredModalities();
const missingModalities: string[] = [];
if (requiredModalities.vision && !props.modalities.vision) {
missingModalities.push('vision');
}
if (requiredModalities.audio && !props.modalities.audio) {
missingModalities.push('audio');
}
if (missingModalities.length > 0) {
toast.error(
`Model "${modelName}" doesn't support required modalities: ${missingModalities.join(', ')}. Please select a different model.`
);
if (isRouter && hasLoadedModel) {
try {
await modelsStore.unloadModel(modelName);
} catch (error) {
console.error('Failed to unload incompatible model:', error);
}
}
if (onValidationFailure && previousSelectedModelId) {
await onValidationFailure(previousSelectedModelId);
}
return false;
}
}
await modelsStore.selectModelById(modelId);
if (onSuccess) {
onSuccess(modelName);
}
return true;
} catch (error) {
console.error('Failed to change model:', error);
toast.error('Failed to validate model capabilities');
if (onValidationFailure && previousSelectedModelId) {
await onValidationFailure(previousSelectedModelId);
}
return false;
}
}
return {
handleModelChange
};
}
@@ -1,42 +1,52 @@
import { getJsonHeaders } from '$lib/utils';
import { AttachmentType } from '$lib/enums';
import { getJsonHeaders, formatAttachmentText, isAbortError } from '$lib/utils';
import { ATTACHMENT_LABEL_PDF_FILE } from '$lib/constants/attachment-labels';
import {
AttachmentType,
ContentPartType,
MessageRole,
ReasoningFormat,
UrlPrefix
} from '$lib/enums';
import type { ApiChatMessageContentPart, ApiChatCompletionToolCall } from '$lib/types/api';
import { modelsStore } from '$lib/stores/models.svelte';
import { AGENTIC_REGEX } from '$lib/constants/agentic';
/**
* ChatService - Low-level API communication layer for Chat Completions
*
* **Terminology - Chat vs Conversation:**
* - **Chat**: The active interaction space with the Chat Completions API. This service
* handles the real-time communication with the AI backend - sending messages, receiving
* streaming responses, and managing request lifecycles. "Chat" is ephemeral and runtime-focused.
* - **Conversation**: The persistent database entity storing all messages and metadata.
* Managed by ConversationsService/Store, conversations persist across sessions.
*
* This service handles direct communication with the llama-server's Chat Completions API.
* It provides the network layer abstraction for AI model interactions while remaining
* stateless and focused purely on API communication.
*
* **Architecture & Relationships:**
* - **ChatService** (this class): Stateless API communication layer
* - Handles HTTP requests/responses with the llama-server
* - Manages streaming and non-streaming response parsing
* - Provides per-conversation request abortion capabilities
* - Converts database messages to API format
* - Handles error translation for server responses
*
* - **chatStore**: Uses ChatService for all AI model communication
* - **conversationsStore**: Provides message context for API requests
*
* **Key Responsibilities:**
* - Message format conversion (DatabaseMessage API format)
* - Streaming response handling with real-time callbacks
* - Reasoning content extraction and processing
* - File attachment processing (images, PDFs, audio, text)
* - Request lifecycle management (abort via AbortSignal)
*/
export class ChatService {
// ─────────────────────────────────────────────────────────────────────────────
// Messaging
// ─────────────────────────────────────────────────────────────────────────────
private static stripReasoningContent(
content: ApiChatMessageData['content'] | null | undefined
): ApiChatMessageData['content'] | null | undefined {
if (!content) {
return content;
}
if (typeof content === 'string') {
return content
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
.replace(AGENTIC_REGEX.REASONING_OPEN, '');
}
if (!Array.isArray(content)) {
return content;
}
return content.map((part: ApiChatMessageContentPart) => {
if (part.type !== ContentPartType.TEXT || !part.text) return part;
return {
...part,
text: part.text
.replace(AGENTIC_REGEX.REASONING_BLOCK, '')
.replace(AGENTIC_REGEX.REASONING_OPEN, '')
};
});
}
/**
*
*
* Messaging
*
*
*/
/**
* Sends a chat completion request to the llama.cpp server.
@@ -63,6 +73,8 @@ export class ChatService {
onToolCallChunk,
onModel,
onTimings,
// Tools for function calling
tools,
// Generation parameters
temperature,
max_tokens,
@@ -97,6 +109,7 @@ export class ChatService {
.map((msg) => {
if ('id' in msg && 'convId' in msg && 'timestamp' in msg) {
const dbMsg = msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] };
return ChatService.convertDbMessageToApiChatMessageData(dbMsg);
} else {
return msg as ApiChatMessageData;
@@ -104,7 +117,7 @@ export class ChatService {
})
.filter((msg) => {
// Filter out empty system messages
if (msg.role === 'system') {
if (msg.role === MessageRole.SYSTEM) {
const content = typeof msg.content === 'string' ? msg.content : '';
return content.trim().length > 0;
@@ -113,13 +126,41 @@ export class ChatService {
return true;
});
// Filter out image attachments if the model doesn't support vision
if (options.model && !modelsStore.modelSupportsVision(options.model)) {
normalizedMessages.forEach((msg) => {
if (Array.isArray(msg.content)) {
msg.content = msg.content.filter((part: ApiChatMessageContentPart) => {
if (part.type === ContentPartType.IMAGE_URL) {
console.info(
`[ChatService] Skipping image attachment in message history (model "${options.model}" does not support vision)`
);
return false;
}
return true;
});
// If only text remains and it's a single part, simplify to string
if (msg.content.length === 1 && msg.content[0].type === ContentPartType.TEXT) {
msg.content = msg.content[0].text;
}
}
});
}
const requestBody: ApiChatCompletionRequest = {
messages: normalizedMessages.map((msg: ApiChatMessageData) => ({
role: msg.role,
content: msg.content
// Strip reasoning tags/content from the prompt to avoid polluting KV cache.
// TODO: investigate backend expectations for reasoning tags and add a toggle if needed.
content: ChatService.stripReasoningContent(msg.content),
tool_calls: msg.tool_calls,
tool_call_id: msg.tool_call_id
})),
stream,
return_progress: stream ? true : undefined
return_progress: stream ? true : undefined,
tools: tools && tools.length > 0 ? tools : undefined
};
// Include model in request if provided (required in ROUTER mode)
@@ -127,7 +168,9 @@ export class ChatService {
requestBody.model = options.model;
}
requestBody.reasoning_format = disableReasoningParsing ? 'none' : 'auto';
requestBody.reasoning_format = disableReasoningParsing
? ReasoningFormat.NONE
: ReasoningFormat.AUTO;
if (temperature !== undefined) requestBody.temperature = temperature;
if (max_tokens !== undefined) {
@@ -183,9 +226,11 @@ export class ChatService {
if (!response.ok) {
const error = await ChatService.parseErrorResponse(response);
if (onError) {
onError(error);
}
throw error;
}
@@ -202,6 +247,7 @@ export class ChatService {
conversationId,
signal
);
return;
} else {
return ChatService.handleNonStreamResponse(
@@ -213,7 +259,7 @@ export class ChatService {
);
}
} catch (error) {
if (error instanceof Error && error.name === 'AbortError') {
if (isAbortError(error)) {
console.log('Chat completion request was aborted');
return;
}
@@ -240,16 +286,22 @@ export class ChatService {
}
console.error('Error in sendMessage:', error);
if (onError) {
onError(userFriendlyError);
}
throw userFriendlyError;
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Streaming
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Streaming
*
*
*/
/**
* Handles streaming response from the chat completion API
@@ -323,6 +375,10 @@ export class ChatService {
const serializedToolCalls = JSON.stringify(aggregatedToolCalls);
if (import.meta.env.DEV) {
console.log('[ChatService] Aggregated tool calls:', serializedToolCalls);
}
if (!serializedToolCalls) {
return;
}
@@ -349,10 +405,11 @@ export class ChatService {
for (const line of lines) {
if (abortSignal?.aborted) break;
if (line.startsWith('data: ')) {
if (line.startsWith(UrlPrefix.DATA)) {
const data = line.slice(6);
if (data === '[DONE]') {
streamFinished = true;
continue;
}
@@ -458,6 +515,7 @@ export class ChatService {
if (!responseText.trim()) {
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
@@ -472,10 +530,6 @@ export class ChatService {
const reasoningContent = data.choices[0]?.message?.reasoning_content;
const toolCalls = data.choices[0]?.message?.tool_calls;
if (reasoningContent) {
console.log('Full reasoning content:', reasoningContent);
}
let serializedToolCalls: string | undefined;
if (toolCalls && toolCalls.length > 0) {
@@ -491,6 +545,7 @@ export class ChatService {
if (!content.trim() && !serializedToolCalls) {
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
@@ -563,9 +618,13 @@ export class ChatService {
return result;
}
// ─────────────────────────────────────────────────────────────────────────────
// Conversion
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Conversion
*
*
*/
/**
* Converts a database message with attachments to API chat message format.
@@ -582,22 +641,48 @@ export class ChatService {
static convertDbMessageToApiChatMessageData(
message: DatabaseMessage & { extra?: DatabaseMessageExtra[] }
): ApiChatMessageData {
if (!message.extra || message.extra.length === 0) {
// Handle tool result messages (role: 'tool')
if (message.role === MessageRole.TOOL && message.toolCallId) {
return {
role: message.role as 'user' | 'assistant' | 'system',
role: MessageRole.TOOL,
content: message.content,
tool_call_id: message.toolCallId
};
}
// Parse tool calls for assistant messages
let toolCalls: ApiChatCompletionToolCall[] | undefined;
if (message.toolCalls) {
try {
toolCalls = JSON.parse(message.toolCalls);
} catch {
// Ignore parse errors for malformed tool calls
}
}
if (!message.extra || message.extra.length === 0) {
const result: ApiChatMessageData = {
role: message.role as MessageRole,
content: message.content
};
if (toolCalls && toolCalls.length > 0) {
result.tool_calls = toolCalls;
}
return result;
}
const contentParts: ApiChatMessageContentPart[] = [];
if (message.content) {
contentParts.push({
type: 'text',
type: ContentPartType.TEXT,
text: message.content
});
}
// Include images from all messages
const imageFiles = message.extra.filter(
(extra: DatabaseMessageExtra): extra is DatabaseMessageExtraImageFile =>
extra.type === AttachmentType.IMAGE
@@ -605,7 +690,7 @@ export class ChatService {
for (const image of imageFiles) {
contentParts.push({
type: 'image_url',
type: ContentPartType.IMAGE_URL,
image_url: { url: image.base64Url }
});
}
@@ -617,8 +702,8 @@ export class ChatService {
for (const textFile of textFiles) {
contentParts.push({
type: 'text',
text: `\n\n--- File: ${textFile.name} ---\n${textFile.content}`
type: ContentPartType.TEXT,
text: formatAttachmentText('File', textFile.name, textFile.content)
});
}
@@ -630,8 +715,8 @@ export class ChatService {
for (const legacyContextFile of legacyContextFiles) {
contentParts.push({
type: 'text',
text: `\n\n--- File: ${legacyContextFile.name} ---\n${legacyContextFile.content}`
type: ContentPartType.TEXT,
text: formatAttachmentText('File', legacyContextFile.name, legacyContextFile.content)
});
}
@@ -642,7 +727,7 @@ export class ChatService {
for (const audio of audioFiles) {
contentParts.push({
type: 'input_audio',
type: ContentPartType.INPUT_AUDIO,
input_audio: {
data: audio.base64Data,
format: audio.mimeType.includes('wav') ? 'wav' : 'mp3'
@@ -659,27 +744,33 @@ export class ChatService {
if (pdfFile.processedAsImages && pdfFile.images) {
for (let i = 0; i < pdfFile.images.length; i++) {
contentParts.push({
type: 'image_url',
type: ContentPartType.IMAGE_URL,
image_url: { url: pdfFile.images[i] }
});
}
} else {
contentParts.push({
type: 'text',
text: `\n\n--- PDF File: ${pdfFile.name} ---\n${pdfFile.content}`
type: ContentPartType.TEXT,
text: formatAttachmentText(ATTACHMENT_LABEL_PDF_FILE, pdfFile.name, pdfFile.content)
});
}
}
return {
role: message.role as 'user' | 'assistant' | 'system',
const result: ApiChatMessageData = {
role: message.role as MessageRole,
content: contentParts
};
return result;
}
// ─────────────────────────────────────────────────────────────────────────────
// Utilities
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Utilities
*
*
*/
/**
* Parses error response and creates appropriate error with context information
@@ -714,6 +805,7 @@ export class ChatService {
contextInfo?: { n_prompt_tokens: number; n_ctx: number };
};
fallback.name = 'HttpError';
return fallback;
}
}
@@ -745,18 +837,26 @@ export class ChatService {
// 1) root (some implementations provide `model` at the top level)
const rootModel = getTrimmedString(root.model);
if (rootModel) return rootModel;
if (rootModel) {
return rootModel;
}
// 2) streaming choice (delta) or final response (message)
const firstChoice = Array.isArray(root.choices) ? asRecord(root.choices[0]) : undefined;
if (!firstChoice) return undefined;
if (!firstChoice) {
return undefined;
}
// priority: delta.model (first chunk) else message.model (final response)
const deltaModel = getTrimmedString(asRecord(firstChoice.delta)?.model);
if (deltaModel) return deltaModel;
if (deltaModel) {
return deltaModel;
}
const messageModel = getTrimmedString(asRecord(firstChoice.message)?.model);
if (messageModel) return messageModel;
if (messageModel) {
return messageModel;
}
// avoid guessing from non-standard locations (metadata, etc.)
return undefined;
+211 -2
View File
@@ -1,5 +1,214 @@
export { ChatService } from './chat';
/**
*
* SERVICES
*
* Stateless service layer for API communication and data operations.
* Services handle protocol-level concerns (HTTP, WebSocket, MCP, IndexedDB)
* without managing reactive state that responsibility belongs to stores.
*
* **Design Principles:**
* - All methods are static no instance state
* - Pure I/O operations (network requests, database queries)
* - No Svelte runes or reactive primitives
* - Error handling at the protocol level; business-level error handling in stores
*
* **Architecture (bottom to top):**
* - **Services** (this layer): Stateless protocol communication
* - **Stores**: Reactive state management consuming services
* - **Components**: UI consuming stores
*
*/
/**
* **ChatService** - Chat Completions API communication layer
*
* Handles direct communication with the llama-server's `/v1/chat/completions` endpoint.
* Provides streaming and non-streaming response parsing, message format conversion
* (DatabaseMessage API format), and request lifecycle management.
*
* **Terminology - Chat vs Conversation:**
* - **Chat**: The active interaction space with the Chat Completions API. Ephemeral and
* runtime-focused sending messages, receiving streaming responses, managing request lifecycles.
* - **Conversation**: The persistent database entity storing all messages and metadata.
* Managed by conversationsStore, conversations persist across sessions.
*
* **Architecture & Relationships:**
* - **ChatService** (this class): Stateless API communication layer
* - Handles HTTP requests/responses with the llama-server
* - Manages streaming and non-streaming response parsing
* - Converts database messages to API format (multimodal, tool calls)
* - Handles error translation with user-friendly messages
*
* - **chatStore**: Primary consumer uses ChatService for all AI model communication
* - **agenticStore**: Uses ChatService for multi-turn agentic loop streaming
* - **conversationsStore**: Provides message context for API requests
*
* **Key Responsibilities:**
* - Streaming response handling with real-time content/reasoning/tool-call callbacks
* - Non-streaming response parsing with complete response extraction
* - Database message to API format conversion (attachments, tool calls, multimodal)
* - Tool call delta merging for incremental streaming aggregation
* - Request parameter assembly (sampling, penalties, custom params)
* - File attachment processing (images, PDFs, audio, text, MCP prompts/resources)
* - Reasoning content stripping from prompt history to avoid KV cache pollution
* - Error translation (network, timeout, server errors user-friendly messages)
*
* @see chatStore in stores/chat.svelte.ts primary consumer for chat state management
* @see agenticStore in stores/agentic.svelte.ts uses ChatService for agentic loop streaming
* @see conversationsStore in stores/conversations.svelte.ts provides message context
*/
export { ChatService } from './chat.service';
/**
* **DatabaseService** - IndexedDB persistence layer via Dexie ORM
*
* Provides stateless data access for conversations and messages using IndexedDB.
* Handles all low-level storage operations including branching tree structures,
* cascade deletions, and transaction safety for multi-table operations.
*
* **Architecture & Relationships (bottom to top):**
* - **DatabaseService** (this class): Stateless IndexedDB operations
* - Lowest layer direct Dexie/IndexedDB communication
* - Pure CRUD operations without business logic
* - Handles branching tree structure (parent-child relationships)
* - Provides transaction safety for multi-table operations
*
* - **conversationsStore**: Reactive state management layer
* - Uses DatabaseService for all persistence operations
* - Manages conversation list, active conversation, and messages in memory
*
* - **chatStore**: Active AI interaction management
* - Uses conversationsStore for conversation context
* - Directly uses DatabaseService for message CRUD during streaming
*
* **Key Responsibilities:**
* - Conversation CRUD (create, read, update, delete)
* - Message CRUD with branching support (parent-child relationships)
* - Root message and system prompt creation
* - Cascade deletion of message branches (descendants)
* - Transaction-safe multi-table operations
* - Conversation import with duplicate detection
*
* **Database Schema:**
* - `conversations`: id, lastModified, currNode, name
* - `messages`: id, convId, type, role, timestamp, parent, children
*
* **Branching Model:**
* Messages form a tree structure where each message can have multiple children,
* enabling conversation branching and alternative response paths. The conversation's
* `currNode` tracks the currently active branch endpoint.
*
* @see conversationsStore in stores/conversations.svelte.ts reactive layer on top of DatabaseService
* @see chatStore in stores/chat.svelte.ts uses DatabaseService directly for message CRUD during streaming
*/
export { DatabaseService } from './database.service';
/**
* **ModelsService** - Model management API communication
*
* Handles communication with model-related endpoints for both MODEL (single model)
* and ROUTER (multi-model) server modes. Provides model listing, loading/unloading,
* and status checking without managing any model state.
*
* **Architecture & Relationships:**
* - **ModelsService** (this class): Stateless HTTP communication
* - Sends requests to model endpoints
* - Parses and returns typed API responses
* - Provides model status utility methods
*
* - **modelsStore**: Primary consumer manages reactive model state
* - Calls ModelsService for all model API operations
* - Handles polling, caching, and state updates
*
* **Key Responsibilities:**
* - List available models via OpenAI-compatible `/v1/models` endpoint
* - Load/unload models via `/models/load` and `/models/unload` (ROUTER mode)
* - Model status queries (loaded, loading)
*
* **Server Mode Behavior:**
* - **MODEL mode**: Only `list()` is relevant single model always loaded
* - **ROUTER mode**: Full lifecycle `list()`, `listRouter()`, `load()`, `unload()`
*
* **Endpoints:**
* - `GET /v1/models` OpenAI-compatible model list (both modes)
* - `POST /models/load` Load a model (ROUTER mode only)
* - `POST /models/unload` Unload a model (ROUTER mode only)
*
* @see modelsStore in stores/models.svelte.ts primary consumer for reactive model state
*/
export { ModelsService } from './models.service';
/**
* **PropsService** - Server properties and capabilities retrieval
*
* Fetches server configuration, model information, and capabilities from the `/props`
* endpoint. Supports both global server props and per-model props (ROUTER mode).
*
* **Architecture & Relationships:**
* - **PropsService** (this class): Stateless HTTP communication
* - Fetches server properties from `/props` endpoint
* - Handles authentication and request parameters
* - Returns typed `ApiLlamaCppServerProps` responses
*
* - **serverStore**: Consumes global server properties (role detection, connection state)
* - **modelsStore**: Consumes per-model properties (modalities, context size)
* - **settingsStore**: Syncs default generation parameters from props response
*
* **Key Responsibilities:**
* - Fetch global server properties (default generation settings, modalities)
* - Fetch per-model properties in ROUTER mode via `?model=<id>` parameter
* - Handle autoload control to prevent unintended model loading
*
* **API Behavior:**
* - `GET /props` Global server props (MODEL mode: includes modalities)
* - `GET /props?model=<id>` Per-model props (ROUTER mode: model-specific modalities)
* - `&autoload=false` Prevents model auto-loading when querying props
*
* @see serverStore in stores/server.svelte.ts consumes global server props
* @see modelsStore in stores/models.svelte.ts consumes per-model props for modalities
* @see settingsStore in stores/settings.svelte.ts syncs default generation params from props
*/
export { PropsService } from './props.service';
export { ParameterSyncService, SYNCABLE_PARAMETERS } from './parameter-sync.service';
/**
* **ParameterSyncService** - Server defaults and user settings synchronization
*
* Manages the complex logic of merging server-provided default parameters with
* user-configured overrides. Ensures the UI reflects the actual server state
* while preserving user customizations. Tracks parameter sources (server default
* vs user override) for display in the settings UI.
*
* **Architecture & Relationships:**
* - **ParameterSyncService** (this class): Stateless sync logic
* - Pure functions for parameter extraction, merging, and diffing
* - No side effects receives data in, returns data out
* - Handles floating-point precision normalization
*
* - **settingsStore**: Primary consumer calls sync methods during:
* - Initial load (`syncWithServerDefaults`)
* - Settings reset (`forceSyncWithServerDefaults`)
* - Parameter info queries (`getParameterInfo`)
*
* - **PropsService**: Provides raw server props that feed into extraction
*
* **Key Responsibilities:**
* - Extract syncable parameters from server `/props` response
* - Merge server defaults with user overrides (user wins)
* - Track parameter source (Custom vs Default) for UI badges
* - Validate server parameter values by type (number, string, boolean)
* - Create diffs between current settings and server defaults
* - Floating-point precision normalization for consistent comparisons
*
* **Parameter Source Priority:**
* 1. **User Override** (Custom badge) explicitly set by user in settings
* 2. **Server Default** (Default badge) from `/props` endpoint
* 3. **App Default** hardcoded fallback when server props unavailable
*
* **Exports:**
* - `ParameterSyncService` class static methods for sync logic
* - `SYNCABLE_PARAMETERS` mapping of webui setting keys to server parameter keys
*
* @see settingsStore in stores/settings.svelte.ts primary consumer for settings sync
* @see ChatSettingsParameterSourceIndicator displays parameter source badges in UI
*/
export { ParameterSyncService } from './parameter-sync.service';
File diff suppressed because it is too large Load Diff
@@ -1,54 +1,38 @@
import { browser } from '$app/environment';
/**
* conversationsStore - Reactive State Store for Conversations
*
* Manages conversation lifecycle, persistence, navigation.
*
* **Architecture & Relationships:**
* - **DatabaseService**: Stateless IndexedDB layer
* - **conversationsStore** (this): Reactive state + business logic
* - **chatStore**: Chat-specific state (streaming, loading)
*
* **Key Responsibilities:**
* - Conversation CRUD (create, load, delete)
* - Message management and tree navigation
* - Import/Export functionality
* - Title management with confirmation
*
* @see DatabaseService in services/database.ts for IndexedDB operations
*/
import { goto } from '$app/navigation';
import { browser } from '$app/environment';
import { toast } from 'svelte-sonner';
import { DatabaseService } from '$lib/services/database.service';
import { config } from '$lib/stores/settings.svelte';
import { filterByLeafNodeId, findLeafNode } from '$lib/utils';
import { AttachmentType } from '$lib/enums';
import { MessageRole } from '$lib/enums';
/**
* conversationsStore - Persistent conversation data and lifecycle management
*
* **Terminology - Chat vs Conversation:**
* - **Chat**: The active interaction space with the Chat Completions API. Represents the
* real-time streaming session, loading states, and UI visualization of AI communication.
* Managed by chatStore, a "chat" is ephemeral and exists during active AI interactions.
* - **Conversation**: The persistent database entity storing all messages and metadata.
* A "conversation" survives across sessions, page reloads, and browser restarts.
* It contains the complete message history, branching structure, and conversation metadata.
*
* This store manages all conversation-level data and operations including creation, loading,
* deletion, and navigation. It maintains the list of conversations and the currently active
* conversation with its message history, providing reactive state for UI components.
*
* **Architecture & Relationships:**
* - **conversationsStore** (this class): Persistent conversation data management
* - Manages conversation list and active conversation state
* - Handles conversation CRUD operations via DatabaseService
* - Maintains active message array for current conversation
* - Coordinates branching navigation (currNode tracking)
*
* - **chatStore**: Uses conversation data as context for active AI streaming
* - **DatabaseService**: Low-level IndexedDB storage for conversations and messages
*
* **Key Features:**
* - **Conversation Lifecycle**: Create, load, update, delete conversations
* - **Message Management**: Active message array with branching support
* - **Import/Export**: JSON-based conversation backup and restore
* - **Branch Navigation**: Navigate between message tree branches
* - **Title Management**: Auto-update titles with confirmation dialogs
* - **Reactive State**: Svelte 5 runes for automatic UI updates
*
* **State Properties:**
* - `conversations`: All conversations sorted by last modified
* - `activeConversation`: Currently viewed conversation
* - `activeMessages`: Messages in current conversation path
* - `isInitialized`: Store initialization status
*/
class ConversationsStore {
// ─────────────────────────────────────────────────────────────────────────────
// State
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* State
*
*
*/
/** List of all conversations */
conversations = $state<DatabaseConversation[]>([]);
@@ -65,102 +49,110 @@ class ConversationsStore {
/** Callback for title update confirmation dialog */
titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise<boolean>;
// ─────────────────────────────────────────────────────────────────────────────
// Modalities
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Lifecycle
*
*
*/
/**
* Modalities used in the active conversation.
* Computed from attachments in activeMessages.
* Used to filter available models - models must support all used modalities.
* Initialize the store by loading conversations from database.
* Must be called once after app startup.
*/
usedModalities: ModelModalities = $derived.by(() => {
return this.calculateModalitiesFromMessages(this.activeMessages);
});
async init(): Promise<void> {
if (!browser) return;
if (this.isInitialized) return;
/**
* Calculate modalities from a list of messages.
* Helper method used by both usedModalities and getModalitiesUpToMessage.
*/
private calculateModalitiesFromMessages(messages: DatabaseMessage[]): ModelModalities {
const modalities: ModelModalities = { vision: false, audio: false };
for (const message of messages) {
if (!message.extra) continue;
for (const extra of message.extra) {
if (extra.type === AttachmentType.IMAGE) {
modalities.vision = true;
}
// PDF only requires vision if processed as images
if (extra.type === AttachmentType.PDF) {
const pdfExtra = extra as DatabaseMessageExtraPdfFile;
if (pdfExtra.processedAsImages) {
modalities.vision = true;
}
}
if (extra.type === AttachmentType.AUDIO) {
modalities.audio = true;
}
}
if (modalities.vision && modalities.audio) break;
}
return modalities;
}
/**
* Get modalities used in messages BEFORE the specified message.
* Used for regeneration - only consider context that was available when generating this message.
*/
getModalitiesUpToMessage(messageId: string): ModelModalities {
const messageIndex = this.activeMessages.findIndex((m) => m.id === messageId);
if (messageIndex === -1) {
return this.usedModalities;
}
const messagesBefore = this.activeMessages.slice(0, messageIndex);
return this.calculateModalitiesFromMessages(messagesBefore);
}
constructor() {
if (browser) {
this.initialize();
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Lifecycle
// ─────────────────────────────────────────────────────────────────────────────
/**
* Initializes the conversations store by loading conversations from the database
*/
async initialize(): Promise<void> {
try {
await this.loadConversations();
this.isInitialized = true;
} catch (error) {
console.error('Failed to initialize conversations store:', error);
console.error('Failed to initialize conversations:', error);
}
}
/**
* Alias for init() for backward compatibility.
*/
async initialize(): Promise<void> {
return this.init();
}
/**
*
*
* Message Array Operations
*
*
*/
/**
* Adds a message to the active messages array
*/
addMessageToActive(message: DatabaseMessage): void {
this.activeMessages.push(message);
}
/**
* Updates a message at a specific index in active messages
*/
updateMessageAtIndex(index: number, updates: Partial<DatabaseMessage>): void {
if (index !== -1 && this.activeMessages[index]) {
this.activeMessages[index] = { ...this.activeMessages[index], ...updates };
}
}
/**
* Finds the index of a message in active messages
*/
findMessageIndex(messageId: string): number {
return this.activeMessages.findIndex((m) => m.id === messageId);
}
/**
* Removes messages from active messages starting at an index
*/
sliceActiveMessages(startIndex: number): void {
this.activeMessages = this.activeMessages.slice(0, startIndex);
}
/**
* Removes a message from active messages by index
*/
removeMessageAtIndex(index: number): DatabaseMessage | undefined {
if (index !== -1) {
return this.activeMessages.splice(index, 1)[0];
}
return undefined;
}
/**
* Sets the callback function for title update confirmations
*/
setTitleUpdateConfirmationCallback(
callback: (currentTitle: string, newTitle: string) => Promise<boolean>
): void {
this.titleUpdateConfirmationCallback = callback;
}
/**
*
*
* Conversation CRUD
*
*
*/
/**
* Loads all conversations from the database
*/
async loadConversations(): Promise<void> {
this.conversations = await DatabaseService.getAllConversations();
const conversations = await DatabaseService.getAllConversations();
this.conversations = conversations;
}
// ─────────────────────────────────────────────────────────────────────────────
// Conversation CRUD
// ─────────────────────────────────────────────────────────────────────────────
/**
* Creates a new conversation and navigates to it
* @param name - Optional name for the conversation
@@ -170,7 +162,7 @@ class ConversationsStore {
const conversationName = name || `Chat ${new Date().toLocaleString()}`;
const conversation = await DatabaseService.createConversation(conversationName);
this.conversations.unshift(conversation);
this.conversations = [conversation, ...this.conversations];
this.activeConversation = conversation;
this.activeMessages = [];
@@ -196,13 +188,15 @@ class ConversationsStore {
if (conversation.currNode) {
const allMessages = await DatabaseService.getConversationMessages(convId);
this.activeMessages = filterByLeafNodeId(
const filteredMessages = filterByLeafNodeId(
allMessages,
conversation.currNode,
false
) as DatabaseMessage[];
this.activeMessages = filteredMessages;
} else {
this.activeMessages = await DatabaseService.getConversationMessages(convId);
const messages = await DatabaseService.getConversationMessages(convId);
this.activeMessages = messages;
}
return true;
@@ -213,169 +207,11 @@ class ConversationsStore {
}
/**
* Clears the active conversation and messages
* Used when navigating away from chat or starting fresh
* Clears the active conversation and messages.
*/
clearActiveConversation(): void {
this.activeConversation = null;
this.activeMessages = [];
// Active processing conversation is now managed by chatStore
}
// ─────────────────────────────────────────────────────────────────────────────
// Message Management
// ─────────────────────────────────────────────────────────────────────────────
/**
* Refreshes active messages based on currNode after branch navigation
*/
async refreshActiveMessages(): Promise<void> {
if (!this.activeConversation) return;
const allMessages = await DatabaseService.getConversationMessages(this.activeConversation.id);
if (allMessages.length === 0) {
this.activeMessages = [];
return;
}
const leafNodeId =
this.activeConversation.currNode ||
allMessages.reduce((latest: DatabaseMessage, msg: DatabaseMessage) =>
msg.timestamp > latest.timestamp ? msg : latest
).id;
const currentPath = filterByLeafNodeId(allMessages, leafNodeId, false) as DatabaseMessage[];
this.activeMessages.length = 0;
this.activeMessages.push(...currentPath);
}
/**
* Updates the name of a conversation
* @param convId - The conversation ID to update
* @param name - The new name for the conversation
*/
async updateConversationName(convId: string, name: string): Promise<void> {
try {
await DatabaseService.updateConversation(convId, { name });
const convIndex = this.conversations.findIndex((c) => c.id === convId);
if (convIndex !== -1) {
this.conversations[convIndex].name = name;
}
if (this.activeConversation?.id === convId) {
this.activeConversation.name = name;
}
} catch (error) {
console.error('Failed to update conversation name:', error);
}
}
/**
* Updates conversation title with optional confirmation dialog based on settings
* @param convId - The conversation ID to update
* @param newTitle - The new title content
* @param onConfirmationNeeded - Callback when user confirmation is needed
* @returns True if title was updated, false if cancelled
*/
async updateConversationTitleWithConfirmation(
convId: string,
newTitle: string,
onConfirmationNeeded?: (currentTitle: string, newTitle: string) => Promise<boolean>
): Promise<boolean> {
try {
const currentConfig = config();
if (currentConfig.askForTitleConfirmation && onConfirmationNeeded) {
const conversation = await DatabaseService.getConversation(convId);
if (!conversation) return false;
const shouldUpdate = await onConfirmationNeeded(conversation.name, newTitle);
if (!shouldUpdate) return false;
}
await this.updateConversationName(convId, newTitle);
return true;
} catch (error) {
console.error('Failed to update conversation title with confirmation:', error);
return false;
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Navigation
// ─────────────────────────────────────────────────────────────────────────────
/**
* Updates the current node of the active conversation
* @param nodeId - The new current node ID
*/
async updateCurrentNode(nodeId: string): Promise<void> {
if (!this.activeConversation) return;
await DatabaseService.updateCurrentNode(this.activeConversation.id, nodeId);
this.activeConversation.currNode = nodeId;
}
/**
* Updates conversation lastModified timestamp and moves it to top of list
*/
updateConversationTimestamp(): void {
if (!this.activeConversation) return;
const chatIndex = this.conversations.findIndex((c) => c.id === this.activeConversation!.id);
if (chatIndex !== -1) {
this.conversations[chatIndex].lastModified = Date.now();
const updatedConv = this.conversations.splice(chatIndex, 1)[0];
this.conversations.unshift(updatedConv);
}
}
/**
* Navigates to a specific sibling branch by updating currNode and refreshing messages
* @param siblingId - The sibling message ID to navigate to
*/
async navigateToSibling(siblingId: string): Promise<void> {
if (!this.activeConversation) return;
const allMessages = await DatabaseService.getConversationMessages(this.activeConversation.id);
const rootMessage = allMessages.find(
(m: DatabaseMessage) => m.type === 'root' && m.parent === null
);
const currentFirstUserMessage = this.activeMessages.find(
(m: DatabaseMessage) => m.role === 'user' && m.parent === rootMessage?.id
);
const currentLeafNodeId = findLeafNode(allMessages, siblingId);
await DatabaseService.updateCurrentNode(this.activeConversation.id, currentLeafNodeId);
this.activeConversation.currNode = currentLeafNodeId;
await this.refreshActiveMessages();
// Only show title dialog if we're navigating between different first user message siblings
if (rootMessage && this.activeMessages.length > 0) {
const newFirstUserMessage = this.activeMessages.find(
(m: DatabaseMessage) => m.role === 'user' && m.parent === rootMessage.id
);
if (
newFirstUserMessage &&
newFirstUserMessage.content.trim() &&
(!currentFirstUserMessage ||
newFirstUserMessage.id !== currentFirstUserMessage.id ||
newFirstUserMessage.content.trim() !== currentFirstUserMessage.content.trim())
) {
await this.updateConversationTitleWithConfirmation(
this.activeConversation.id,
newFirstUserMessage.content.trim(),
this.titleUpdateConfirmationCallback
);
}
}
}
/**
@@ -420,12 +256,192 @@ class ConversationsStore {
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Import/Export
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Message Management
*
*
*/
/**
* Downloads a conversation as JSON file
* Refreshes active messages based on currNode after branch navigation.
*/
async refreshActiveMessages(): Promise<void> {
if (!this.activeConversation) return;
const allMessages = await DatabaseService.getConversationMessages(this.activeConversation.id);
if (allMessages.length === 0) {
this.activeMessages = [];
return;
}
const leafNodeId =
this.activeConversation.currNode ||
allMessages.reduce((latest, msg) => (msg.timestamp > latest.timestamp ? msg : latest)).id;
const currentPath = filterByLeafNodeId(allMessages, leafNodeId, false) as DatabaseMessage[];
this.activeMessages = currentPath;
}
/**
* Gets all messages for a specific conversation
* @param convId - The conversation ID
* @returns Array of messages
*/
async getConversationMessages(convId: string): Promise<DatabaseMessage[]> {
return await DatabaseService.getConversationMessages(convId);
}
/**
*
*
* Title Management
*
*
*/
/**
* Updates the name of a conversation.
* @param convId - The conversation ID to update
* @param name - The new name for the conversation
*/
async updateConversationName(convId: string, name: string): Promise<void> {
try {
await DatabaseService.updateConversation(convId, { name });
const convIndex = this.conversations.findIndex((c) => c.id === convId);
if (convIndex !== -1) {
this.conversations[convIndex].name = name;
this.conversations = [...this.conversations];
}
if (this.activeConversation?.id === convId) {
this.activeConversation = { ...this.activeConversation, name };
}
} catch (error) {
console.error('Failed to update conversation name:', error);
}
}
/**
* Updates conversation title with optional confirmation dialog based on settings
* @param convId - The conversation ID to update
* @param newTitle - The new title content
* @returns True if title was updated, false if cancelled
*/
async updateConversationTitleWithConfirmation(
convId: string,
newTitle: string
): Promise<boolean> {
try {
const currentConfig = config();
if (currentConfig.askForTitleConfirmation && this.titleUpdateConfirmationCallback) {
const conversation = await DatabaseService.getConversation(convId);
if (!conversation) return false;
const shouldUpdate = await this.titleUpdateConfirmationCallback(
conversation.name,
newTitle
);
if (!shouldUpdate) return false;
}
await this.updateConversationName(convId, newTitle);
return true;
} catch (error) {
console.error('Failed to update conversation title with confirmation:', error);
return false;
}
}
/**
* Updates conversation lastModified timestamp and moves it to top of list
*/
updateConversationTimestamp(): void {
if (!this.activeConversation) return;
const chatIndex = this.conversations.findIndex((c) => c.id === this.activeConversation!.id);
if (chatIndex !== -1) {
this.conversations[chatIndex].lastModified = Date.now();
const updatedConv = this.conversations.splice(chatIndex, 1)[0];
this.conversations = [updatedConv, ...this.conversations];
}
}
/**
* Updates the current node of the active conversation
* @param nodeId - The new current node ID
*/
async updateCurrentNode(nodeId: string): Promise<void> {
if (!this.activeConversation) return;
await DatabaseService.updateCurrentNode(this.activeConversation.id, nodeId);
this.activeConversation = { ...this.activeConversation, currNode: nodeId };
}
/**
*
*
* Branch Navigation
*
*
*/
/**
* Navigates to a specific sibling branch by updating currNode and refreshing messages.
* @param siblingId - The sibling message ID to navigate to
*/
async navigateToSibling(siblingId: string): Promise<void> {
if (!this.activeConversation) return;
const allMessages = await DatabaseService.getConversationMessages(this.activeConversation.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
const currentFirstUserMessage = this.activeMessages.find(
(m) => m.role === MessageRole.USER && m.parent === rootMessage?.id
);
const currentLeafNodeId = findLeafNode(allMessages, siblingId);
await DatabaseService.updateCurrentNode(this.activeConversation.id, currentLeafNodeId);
this.activeConversation = { ...this.activeConversation, currNode: currentLeafNodeId };
await this.refreshActiveMessages();
if (rootMessage && this.activeMessages.length > 0) {
const newFirstUserMessage = this.activeMessages.find(
(m) => m.role === MessageRole.USER && m.parent === rootMessage.id
);
if (
newFirstUserMessage &&
newFirstUserMessage.content.trim() &&
(!currentFirstUserMessage ||
newFirstUserMessage.id !== currentFirstUserMessage.id ||
newFirstUserMessage.content.trim() !== currentFirstUserMessage.content.trim())
) {
await this.updateConversationTitleWithConfirmation(
this.activeConversation.id,
newFirstUserMessage.content.trim()
);
}
}
}
/**
*
*
* Import & Export
*
*
*/
/**
* Downloads a conversation as JSON file.
* @param convId - The conversation ID to download
*/
async downloadConversation(convId: string): Promise<void> {
@@ -456,7 +472,7 @@ class ConversationsStore {
}
const allData = await Promise.all(
allConversations.map(async (conv: DatabaseConversation) => {
allConversations.map(async (conv) => {
const messages = await DatabaseService.getConversationMessages(conv.id);
return { conv, messages };
})
@@ -536,15 +552,6 @@ class ConversationsStore {
});
}
/**
* Gets all messages for a specific conversation
* @param convId - The conversation ID
* @returns Array of messages
*/
async getConversationMessages(convId: string): Promise<DatabaseMessage[]> {
return await DatabaseService.getConversationMessages(convId);
}
/**
* Imports conversations from provided data (without file picker)
* @param data - Array of conversation data with messages
@@ -558,61 +565,8 @@ class ConversationsStore {
return result;
}
/**
* Adds a message to the active messages array
* Used by chatStore when creating new messages
* @param message - The message to add
*/
addMessageToActive(message: DatabaseMessage): void {
this.activeMessages.push(message);
}
/**
* Updates a message at a specific index in active messages
* Creates a new object to trigger Svelte 5 reactivity
* @param index - The index of the message to update
* @param updates - Partial message data to update
*/
updateMessageAtIndex(index: number, updates: Partial<DatabaseMessage>): void {
if (index !== -1 && this.activeMessages[index]) {
// Create new object to trigger Svelte 5 reactivity
this.activeMessages[index] = { ...this.activeMessages[index], ...updates };
}
}
/**
* Finds the index of a message in active messages
* @param messageId - The message ID to find
* @returns The index of the message, or -1 if not found
*/
findMessageIndex(messageId: string): number {
return this.activeMessages.findIndex((m) => m.id === messageId);
}
/**
* Removes messages from active messages starting at an index
* @param startIndex - The index to start removing from
*/
sliceActiveMessages(startIndex: number): void {
this.activeMessages = this.activeMessages.slice(0, startIndex);
}
/**
* Removes a message from active messages by index
* @param index - The index to remove
* @returns The removed message or undefined
*/
removeMessageAtIndex(index: number): DatabaseMessage | undefined {
if (index !== -1) {
return this.activeMessages.splice(index, 1)[0];
}
return undefined;
}
/**
* Triggers file download in browser
* @param data - The data to download
* @param filename - Optional filename for the download
*/
private triggerDownload(data: ExportedConversations, filename?: string): void {
const conversation =
@@ -641,26 +595,16 @@ class ConversationsStore {
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
// ─────────────────────────────────────────────────────────────────────────────
// Utilities
// ─────────────────────────────────────────────────────────────────────────────
/**
* Sets the callback function for title update confirmations
* @param callback - Function to call when confirmation is needed
*/
setTitleUpdateConfirmationCallback(
callback: (currentTitle: string, newTitle: string) => Promise<boolean>
): void {
this.titleUpdateConfirmationCallback = callback;
}
}
export const conversationsStore = new ConversationsStore();
// Auto-initialize in browser
if (browser) {
conversationsStore.init();
}
export const conversations = () => conversationsStore.conversations;
export const activeConversation = () => conversationsStore.activeConversation;
export const activeMessages = () => conversationsStore.activeMessages;
export const isConversationsInitialized = () => conversationsStore.isInitialized;
export const usedModalities = () => conversationsStore.usedModalities;
@@ -1,8 +1,9 @@
import { SvelteSet } from 'svelte/reactivity';
import { ModelsService } from '$lib/services/models.service';
import { PropsService } from '$lib/services/props.service';
import { ServerModelStatus, ModelModality } from '$lib/enums';
import { ModelsService, PropsService } from '$lib/services';
import { serverStore } from '$lib/stores/server.svelte';
import { TTLCache } from '$lib/utils';
import { MODEL_PROPS_CACHE_TTL_MS, MODEL_PROPS_CACHE_MAX_ENTRIES } from '$lib/constants/cache';
/**
* modelsStore - Reactive store for model management in both MODEL and ROUTER modes
@@ -32,9 +33,13 @@ import { serverStore } from '$lib/stores/server.svelte';
* - **Lazy loading**: ensureModelLoaded() loads models on demand
*/
class ModelsStore {
// ─────────────────────────────────────────────────────────────────────────────
// State
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* State
*
*
*/
models = $state<ModelOption[]>([]);
routerModels = $state<ApiModelDataEntry[]>([]);
@@ -48,10 +53,14 @@ class ModelsStore {
private modelLoadingStates = $state<Map<string, boolean>>(new Map());
/**
* Model-specific props cache
* Model-specific props cache with TTL
* Key: modelId, Value: props data including modalities
* TTL: 10 minutes - props don't change frequently
*/
private modelPropsCache = $state<Map<string, ApiLlamaCppServerProps>>(new Map());
private modelPropsCache = new TTLCache<string, ApiLlamaCppServerProps>({
ttlMs: MODEL_PROPS_CACHE_TTL_MS,
maxEntries: MODEL_PROPS_CACHE_MAX_ENTRIES
});
private modelPropsFetching = $state<Set<string>>(new Set());
/**
@@ -59,9 +68,13 @@ class ModelsStore {
*/
propsCacheVersion = $state(0);
// ─────────────────────────────────────────────────────────────────────────────
// Computed Getters
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Computed Getters
*
*
*/
get selectedModel(): ModelOption | null {
if (!this.selectedModelId) return null;
@@ -95,22 +108,24 @@ class ModelsStore {
return props.model_path.split(/(\\|\/)/).pop() || null;
}
// ─────────────────────────────────────────────────────────────────────────────
// Modalities
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Modalities
*
*
*/
/**
* Get modalities for a specific model
* Returns cached modalities from model props
*/
getModelModalities(modelId: string): ModelModalities | null {
// First check if modalities are stored in the model option
const model = this.models.find((m) => m.model === modelId || m.id === modelId);
if (model?.modalities) {
return model.modalities;
}
// Fall back to props cache
const props = this.modelPropsCache.get(modelId);
if (props?.modalities) {
return {
@@ -155,15 +170,17 @@ class ModelsStore {
* Get props for a specific model (from cache)
*/
getModelProps(modelId: string): ApiLlamaCppServerProps | null {
return this.modelPropsCache.get(modelId) ?? null;
return this.modelPropsCache.get(modelId);
}
/**
* Get context size (n_ctx) for a specific model from cached props
*/
getModelContextSize(modelId: string): number | null {
const props = this.modelPropsCache.get(modelId);
return props?.default_generation_settings?.n_ctx ?? null;
const props = this.getModelProps(modelId);
const nCtx = props?.default_generation_settings?.n_ctx;
return typeof nCtx === 'number' ? nCtx : null;
}
/**
@@ -181,9 +198,13 @@ class ModelsStore {
return this.modelPropsFetching.has(modelId);
}
// ─────────────────────────────────────────────────────────────────────────────
// Status Queries
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Status Queries
*
*
*/
isModelLoaded(modelId: string): boolean {
const model = this.routerModels.find((m) => m.id === modelId);
@@ -208,9 +229,13 @@ class ModelsStore {
return usage !== undefined && usage.size > 0;
}
// ─────────────────────────────────────────────────────────────────────────────
// Data Fetching
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Data Fetching
*
*
*/
/**
* Fetch list of models from server and detect server role
@@ -224,7 +249,6 @@ class ModelsStore {
this.error = null;
try {
// Ensure server props are loaded (for role detection and MODEL mode modalities)
if (!serverStore.props) {
await serverStore.fetch();
}
@@ -251,7 +275,6 @@ class ModelsStore {
this.models = models;
// In MODEL mode, populate modalities from serverStore.props (single model)
// WORKAROUND: In MODEL mode, /props returns modalities for the single model,
// but /v1/models doesn't include modalities. We bridge this gap here.
const serverProps = serverStore.props;
@@ -260,9 +283,7 @@ class ModelsStore {
vision: serverProps.modalities.vision ?? false,
audio: serverProps.modalities.audio ?? false
};
// Cache props for the single model
this.modelPropsCache.set(this.models[0].model, serverProps);
// Update model with modalities
this.models = this.models.map((model, index) =>
index === 0 ? { ...model, modalities } : model
);
@@ -302,7 +323,6 @@ class ModelsStore {
* @returns Props data or null if fetch failed or model not loaded
*/
async fetchModelProps(modelId: string): Promise<ApiLlamaCppServerProps | null> {
// Return cached props if available
const cached = this.modelPropsCache.get(modelId);
if (cached) return cached;
@@ -310,7 +330,6 @@ class ModelsStore {
return null;
}
// Avoid duplicate fetches
if (this.modelPropsFetching.has(modelId)) return null;
this.modelPropsFetching.add(modelId);
@@ -335,7 +354,6 @@ class ModelsStore {
const loadedModelIds = this.loadedModelIds;
if (loadedModelIds.length === 0) return;
// Fetch props for each loaded model in parallel
const propsPromises = loadedModelIds.map((modelId) => this.fetchModelProps(modelId));
try {
@@ -357,7 +375,6 @@ class ModelsStore {
return { ...model, modalities };
});
// Increment version to trigger reactivity
this.propsCacheVersion++;
} catch (error) {
console.warn('Failed to fetch modalities for loaded models:', error);
@@ -382,16 +399,19 @@ class ModelsStore {
model.model === modelId ? { ...model, modalities } : model
);
// Increment version to trigger reactivity
this.propsCacheVersion++;
} catch (error) {
console.warn(`Failed to update modalities for model ${modelId}:`, error);
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Model Selection
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Model Selection
*
*
*/
/**
* Select a model for new conversations
@@ -443,9 +463,13 @@ class ModelsStore {
return this.models.some((model) => model.model === modelName);
}
// ─────────────────────────────────────────────────────────────────────────────
// Loading/Unloading Models
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Loading/Unloading Models
*
*
*/
/**
* WORKAROUND: Polling for model status after load/unload operations.
@@ -486,7 +510,6 @@ class ModelsStore {
return;
}
// Wait before next poll
await new Promise((resolve) => setTimeout(resolve, ModelsStore.STATUS_POLL_INTERVAL));
}
@@ -511,8 +534,6 @@ class ModelsStore {
try {
await ModelsService.load(modelId);
// Poll until model is loaded
await this.pollForModelStatus(modelId, ServerModelStatus.LOADED);
await this.updateModelModalities(modelId);
@@ -562,9 +583,13 @@ class ModelsStore {
await this.loadModel(modelId);
}
// ─────────────────────────────────────────────────────────────────────────────
// Utilities
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Utilities
*
*
*/
private toDisplayName(id: string): string {
const segments = id.split(/\\|\//);
@@ -586,6 +611,14 @@ class ModelsStore {
this.modelPropsCache.clear();
this.modelPropsFetching.clear();
}
/**
* Prune expired entries from caches.
* Call periodically for proactive memory cleanup.
*/
pruneExpiredCache(): number {
return this.modelPropsCache.prune();
}
}
export const modelsStore = new ModelsStore();
@@ -18,9 +18,13 @@ import { ServerRole } from '$lib/enums';
* - **Default Params**: Server-wide generation defaults
*/
class ServerStore {
// ─────────────────────────────────────────────────────────────────────────────
// State
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* State
*
*
*/
props = $state<ApiLlamaCppServerProps | null>(null);
loading = $state(false);
@@ -28,16 +32,22 @@ class ServerStore {
role = $state<ServerRole | null>(null);
private fetchPromise: Promise<void> | null = null;
// ─────────────────────────────────────────────────────────────────────────────
// Getters
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Getters
*
*
*/
get defaultParams(): ApiLlamaCppServerProps['default_generation_settings']['params'] | null {
return this.props?.default_generation_settings?.params || null;
}
get contextSize(): number | null {
return this.props?.default_generation_settings?.n_ctx ?? null;
const nCtx = this.props?.default_generation_settings?.n_ctx;
return typeof nCtx === 'number' ? nCtx : null;
}
get webuiSettings(): Record<string, string | number | boolean> | undefined {
@@ -52,9 +62,13 @@ class ServerStore {
return this.role === ServerRole.MODEL;
}
// ─────────────────────────────────────────────────────────────────────────────
// Data Handling
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Data Handling
*
*
*/
async fetch(): Promise<void> {
if (this.fetchPromise) return this.fetchPromise;
@@ -115,9 +129,13 @@ class ServerStore {
this.fetchPromise = null;
}
// ─────────────────────────────────────────────────────────────────────────────
// Utilities
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Utilities
*
*
*/
private detectRole(props: ApiLlamaCppServerProps): void {
const newRole = props?.role === ServerRole.ROUTER ? ServerRole.ROUTER : ServerRole.MODEL;
@@ -47,18 +47,26 @@ import {
} from '$lib/constants/localstorage-keys';
class SettingsStore {
// ─────────────────────────────────────────────────────────────────────────────
// State
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* State
*
*
*/
config = $state<SettingsConfigType>({ ...SETTING_CONFIG_DEFAULT });
theme = $state<string>('auto');
isInitialized = $state(false);
userOverrides = $state<Set<string>>(new Set());
// ─────────────────────────────────────────────────────────────────────────────
// Utilities (private helpers)
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Utilities (private helpers)
*
*
*/
/**
* Helper method to get server defaults with null safety
@@ -76,9 +84,13 @@ class SettingsStore {
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Lifecycle
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Lifecycle
*
*
*/
/**
* Initialize the settings store by loading from localStorage
@@ -130,9 +142,13 @@ class SettingsStore {
this.theme = localStorage.getItem('theme') || 'auto';
}
// ─────────────────────────────────────────────────────────────────────────────
// Config Updates
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Config Updates
*
*
*/
/**
* Update a specific configuration setting
@@ -234,9 +250,13 @@ class SettingsStore {
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Reset
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Reset
*
*
*/
/**
* Reset configuration to defaults
@@ -285,9 +305,13 @@ class SettingsStore {
this.saveConfig();
}
// ─────────────────────────────────────────────────────────────────────────────
// Server Sync
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Server Sync
*
*
*/
/**
* Initialize settings with props defaults when server properties are first loaded
@@ -349,9 +373,13 @@ class SettingsStore {
this.saveConfig();
}
// ─────────────────────────────────────────────────────────────────────────────
// Utilities
// ─────────────────────────────────────────────────────────────────────────────
/**
*
*
* Utilities
*
*
*/
/**
* Get a specific configuration value
+19 -4
View File
@@ -1,8 +1,5 @@
import type { ErrorDialogType } from '$lib/enums';
import type { DatabaseMessage, DatabaseMessageExtra } from './database';
export type ChatMessageType = 'root' | 'text' | 'think' | 'system';
export type ChatRole = 'user' | 'assistant' | 'system';
import type { DatabaseMessageExtra } from './database';
export interface ChatUploadedFile {
id: string;
@@ -61,6 +58,9 @@ export interface ChatMessageTimings {
prompt_n?: number;
}
/**
* Callbacks for streaming chat responses
*/
export interface ChatStreamCallbacks {
onChunk?: (chunk: string) => void;
onReasoningChunk?: (chunk: string) => void;
@@ -77,12 +77,18 @@ export interface ChatStreamCallbacks {
onError?: (error: Error) => void;
}
/**
* Error dialog state for displaying server/timeout errors
*/
export interface ErrorDialogState {
type: ErrorDialogType;
message: string;
contextInfo?: { n_prompt_tokens: number; n_ctx: number };
}
/**
* Live processing stats during prompt evaluation
*/
export interface LiveProcessingStats {
tokensProcessed: number;
totalTokens: number;
@@ -91,17 +97,26 @@ export interface LiveProcessingStats {
etaSecs?: number;
}
/**
* Live generation stats during token generation
*/
export interface LiveGenerationStats {
tokensGenerated: number;
timeMs: number;
tokensPerSecond: number;
}
/**
* Options for getting attachment display items
*/
export interface AttachmentDisplayItemsOptions {
uploadedFiles?: ChatUploadedFile[];
attachments?: DatabaseMessageExtra[];
}
/**
* Result of file processing operation
*/
export interface FileProcessingResult {
extras: DatabaseMessageExtra[];
emptyFiles: string[];
+12 -2
View File
@@ -1,7 +1,12 @@
import type { AttachmentType } from '$lib/enums';
/**
* Common utility types used across the application
*/
/**
* Represents a key-value pair.
* Used for headers, environment variables, query parameters, etc.
*/
export interface KeyValuePair {
key: string;
@@ -9,16 +14,19 @@ export interface KeyValuePair {
}
/**
* Binary detection configuration options.
* Binary detection configuration options
*/
export interface BinaryDetectionOptions {
/** Number of characters to check from the beginning of the file */
prefixLength: number;
/** Maximum ratio of suspicious characters allowed (0.0 to 1.0) */
suspiciousCharThresholdRatio: number;
/** Maximum absolute number of null bytes allowed */
maxAbsoluteNullBytes: number;
}
/**
* Format for text attachments when copied to clipboard.
* Format for text attachments when copied to clipboard
*/
export interface ClipboardTextAttachment {
type: typeof AttachmentType.TEXT;
@@ -33,3 +41,5 @@ export interface ParsedClipboardContent {
message: string;
textAttachments: ClipboardTextAttachment[];
}
export type MimeTypeUnion = MimeTypeAudio | MimeTypeImage | MimeTypeApplication | MimeTypeText;

Some files were not shown because too many files have changed in this diff Show More