Compare commits

32 Commits

Author SHA1 Message Date
Aman Gupta 2759ccdb4a CUDA: avoid mul + bias fusion when doing fusion (#16935) 2025-11-04 10:53:48 +08:00
lhez c5023daf60 opencl: support imrope (#16914)
* opencl: support imrope

* opencl: fix whitespace
2025-11-03 11:47:57 -08:00
Aleksander Grygier e7da30b584 fix: Viewing multiple PDF attachments (#16974) 2025-11-03 18:53:26 +01:00
Daniel Bevenius ed8aa63320 model-conversion : pass config to from_pretrained (#16963)
This commit modifies the script `run-org-model.py` to ensure that the
model configuration is explicitly passed to the `from_pretrained` method
when loading the model. It also removes a duplicate configuration
loading which was a mistake.

The motivation for this change is that it enables the config object to be
modified and then passed to the model loading function, which can be
useful when testing new models.
2025-11-03 18:01:59 +01:00
Georgi Gerganov 48bd26501b server : add props.model_alias (#16943)
* server : add props.model_alias

* webui : npm run format
2025-11-03 14:38:23 +01:00
theo77186 622cd010ff ggml: CUDA: add head size 72 for flash-attn (#16962) 2025-11-03 14:29:11 +01:00
Xuan-Son Nguyen 070ff4d535 mtmd: add --image-min/max-tokens (#16921) 2025-11-03 11:11:18 +01:00
Xuan-Son Nguyen bf7b0c9725 mtmd: pad mask for qwen2.5vl (#16954)
* mtmd: pad mask for qwen2.5vl

* improve
2025-11-03 10:25:55 +01:00
Jinyang He fcfce040e8 ggml : LoongArch fixes (#16958)
* Fix test-quantize-fns f16 and q4_0 failures when using LSX

* Fix LoongArch set float intrinsic when using LSX/LASX
2025-11-03 08:40:02 +02:00
Olivier Chafik ee3a5a10ad sync: minja (glm 4.6 & minmax m2 templates) (#16949)
* sync: minja

* Sync https://github.com/ochafik/minja/pull/7 (MinMax M2)
2025-11-03 07:33:56 +02:00
shani-f 7e994168b1 SYCL: optimized repeat_back kernel (3× fewer asm instructions, 2× faster) Feature/sycl repeat back opt (#16869)
* SYCL repeat_back v1 — add core op + switch case

* Implement repeat_back SYCL operation and minor fixes

* SYCL: optimize repeat_back kernel

* Remove Hebrew comment from repeat_back.cpp

* Remove comments for code clarity

Removed comments to clean up the code.

* Fix formatting in ggml-sycl.cpp

* Formatted lambda according to legacy style. No logic changes

* Remove blank line in repeat_back.cpp

Remove unnecessary blank line before assigning acc to dst_dd.
2025-11-03 09:35:33 +08:00
Sascha Rogmann bcfa87622a feat(webui): improve LaTeX rendering with currency detection (#16508)
* webui : Revised LaTeX formula recognition

* webui : Further examples containing amounts

* webui : vitest for maskInlineLaTeX

* webui: Moved preprocessLaTeX to lib/utils

* webui: LaTeX in table-cells

* chore: update webui build output (use theirs)

* webui: backslash in LaTeX-preprocessing

* chore: update webui build output

* webui: look-behind backslash-check

* chore: update webui build output

* Apply suggestions from code review

Code maintenance (variable names, code formatting, string handling)

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* webui: Moved constants to lib/constants.

* webui: package woff2 inside base64 data

* webui: LaTeX-line-break in display formula

* chore: update webui build output

* webui: Bugfix (font embedding)

* webui: Bugfix (font embedding)

* webui: vite embeds assets

* webui: don't suppress 404 (fonts)

* refactor: KaTeX integration with SCSS

Moves KaTeX styling to SCSS for better customization and font embedding.

This change includes:
- Adding `sass` as a dev dependency.
- Introducing a custom SCSS file to override KaTeX variables and disable TTF/WOFF fonts, relying solely on WOFF2 for embedding.
- Adjusting the Vite configuration to resolve `katex-fonts` alias and inject SCSS variables.

* fix: LaTeX processing within blockquotes

* webui: update webui build output

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2025-11-03 00:41:08 +01:00
Shagun Bera a2054e3a8f test-backend-ops : fix segfault in moe-expert-reduce test in support mode and coverage (#16936)
* tests: fix segfault in moe-expert-reduce test in support mode and --show-coverage

* tests: init gf and filter out fusion tests for support mode

* tests: filter out fusion cases before calling eval_support

* tests: filter out fusion cases from show_test_coverage as well, fix lint
2025-11-03 00:10:30 +01:00
Sigbjørn Skjæret dd52868050 ci : disable failing riscv cross build (#16952) 2025-11-02 23:11:21 +01:00
Zhiyong Wang 6b9a52422b model: add Janus Pro for image understanding (#16906)
* Add support for Janus Pro

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Address reviewer suggestions

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Add JANUS_PRO constant

* Update clip model handling

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* Update tools/mtmd/clip.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

* Refactor JANUS_PRO handling in clip.cpp

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* Update tools/mtmd/clip.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* em whitespace

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-11-02 22:08:04 +01:00
Georgi Gerganov 2f966b8ed8 clip : use FA (#16837)
* clip : use FA

* cont : add warning about unsupported ops

* implement "auto" mode for clip flash attn

* clip : print more detailed op support info during warmup

* cont : remove obsolete comment [no ci]

* improve debugging message

* trailing space

* metal : remove stray return

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-11-02 21:21:48 +01:00
Georgi Gerganov cd5e3b5754 server : support unified cache across slots (#16736)
* server : support unified context across slots

* cont : fix speculative decoding initialization

* context : fix n_ctx_per_seq computation

* server : purge slots one by one

* tests : add unified cache server tests

* llama : update per-seq context computation

* test-thread-safety : handle tiny training context of the input model

* server : fix server_tokens clear()

* server : use 4 slots + unified KV by default

* llama : add note about context size queries

* cont : update todos [no ci]

* context : do not cap the size of the context

* tests : adjust parameters to be CI friendlier

* context : add warning
2025-11-02 18:14:04 +02:00
Aldehir Rojas 87c9efc3b2 common : move gpt-oss reasoning processing to init params (#16937) 2025-11-02 16:56:28 +02:00
Adrian Lundberg 76af40aaaa docs: remove llama_sampler_accept reference in sampling sample usage (#16920)
Commit 5fb5e24811 (llama : minor
sampling refactor (2) (#9386)) moved the llama_sampler_accept call
into llama_sampler_sample, but the sample usage for sampling in llama.h
was not updated accordingly.
2025-11-02 11:28:37 +02:00
mnehete32 7db35a7958 CUDA: add FLOOR, CEIL, ROUND, TRUNC unary ops (#16917) 2025-11-02 11:12:57 +08:00
Aaron Teo a864132ba5 devops: fix failing s390x docker build (#16918) 2025-11-02 08:48:46 +08:00
Aaron Teo d38d9f0877 ggml: add s390x cpu-feats (#16774) 2025-11-02 08:48:23 +08:00
Georgi Gerganov 7fd205a8e8 scripts : add script to bench models (#16894) 2025-11-02 00:15:31 +02:00
Pascal 2f68ce7cfd webui: auto-refresh /props on inference start to resync model metadata (#16784)
* webui: auto-refresh /props on inference start to resync model metadata

- Add no-cache headers to /props and /slots
- Throttle slot checks to 30s
- Prevent concurrent fetches with promise guard
- Trigger refresh from chat streaming for legacy and ModelSelector
- Show dynamic serverWarning when using cached data

* fix: restore proper legacy behavior in webui by using unified /props refresh

Updated assistant message bubbles to show each message's stored model when available,
falling back to the current server model only when the per-message value is missing

When the model selector is disabled, now fetches /props and prioritizes that model name
over chunk metadata, then persists it with the streamed message so legacy mode properly
reflects the backend configuration

* fix: detect first valid SSE chunk and refresh server props once

* fix: removed the slots availability throttle constant and state

* webui: purge ai-generated cruft

* chore: update webui static build
2025-11-01 19:49:51 +01:00
Pascal e4a71599e5 webui: add HTML/JS preview support to MarkdownContent with sandboxed iframe (#16757)
* webui: add HTML/JS preview support to MarkdownContent with sandboxed iframe dialog

Extended MarkdownContent to flag previewable code languages,
add a preview button alongside copy controls, manage preview
dialog state, and share styling for the new button group

Introduced CodePreviewDialog.svelte, a sandboxed iframe modal
for rendering HTML/JS previews with consistent dialog controls

* webui: fullscreen HTML preview dialog using bits-ui

* Update tools/server/webui/src/lib/components/app/misc/CodePreviewDialog.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* webui: pedantic style tweak for CodePreviewDialog close button

* webui: remove overengineered preview language logic

* chore: update webui static build

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2025-11-01 17:14:54 +01:00
Adrien Gallouët dd5e8cab51 vendor : update cpp-httplib to 0.27.0 (#16846)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-11-01 16:52:17 +01:00
Xuan-Son Nguyen cf659bbb8e mtmd: refactor preprocessing + support max/min pixels (#16878)
* mtmd: refactor preprocessing + support max/min pixels

* fix mlp type

* implement min/max pixels

* improve hparams

* better image preproc for qwen

* fix

* fix out of bound composite

* fix (2)

* fix token calculation

* get_merge_kernel_size()

* fix llama4 and lfm2

* gonna fix them all

* use simple resize for qwen

* qwen: increase min tokens

* no resize if dst size == src size

* restore to initial min/max tokens value for qwen
2025-11-01 15:51:36 +01:00
Aleksander Grygier d8b860a219 Add a setting to display message generation statistics (#16901)
* feat: Add setting to display message generation statistics

* chore: build static webui output
2025-11-01 15:35:57 +01:00
Jaromír Hradílek 1ae74882f8 webui: recognize AsciiDoc files as valid text files (#16850)
* webui: recognize AsciiDoc files as valid text files

* webui: add an updated static webui build

* webui: add the updated dependency list

* webui: re-add an updated static webui build

This also reverts commit 742dbb8379.
2025-11-01 15:02:57 +01:00
Sigbjørn Skjæret 961660b8c3 common : allow --system-prompt-file for diffusion-cli (#16903) 2025-11-01 11:01:42 +01:00
Sigbjørn Skjæret 74fef4129f codeowners : update after refactor (#16905) 2025-11-01 09:55:25 +02:00
Jeff Bolz 5d8bb900bc vulkan: Fix multi_add invalid descriptor usage (#16899) 2025-11-01 06:52:14 +01:00
80 changed files with 3437 additions and 800 deletions
+4 -1
@@ -24,8 +24,9 @@ RUN --mount=type=cache,target=/root/.ccache \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-DLLAMA_BUILD_TESTS=OFF \
-DGGML_BACKEND_DL=OFF \
-DGGML_NATIVE=OFF \
-DGGML_BACKEND_DL=ON \
-DGGML_CPU_ALL_VARIANTS=ON \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS && \
cmake --build build --config Release -j $(nproc) && \
@@ -103,6 +104,7 @@ FROM base AS light
WORKDIR /llama.cpp/bin
# Copy llama.cpp binaries and libraries
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin
ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
@@ -116,6 +118,7 @@ ENV LLAMA_ARG_HOST=0.0.0.0
WORKDIR /llama.cpp/bin
# Copy llama.cpp binaries and libraries
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
EXPOSE 8080
+37 -37
@@ -4,49 +4,49 @@ on:
workflow_call:
jobs:
ubuntu-24-riscv64-cpu-cross:
runs-on: ubuntu-24.04
# ubuntu-24-riscv64-cpu-cross:
# runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- name: Setup Riscv
run: |
sudo dpkg --add-architecture riscv64
# steps:
# - uses: actions/checkout@v4
# - name: Setup Riscv
# run: |
# sudo dpkg --add-architecture riscv64
# Add arch-specific repositories for non-amd64 architectures
cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
EOF
# # Add arch-specific repositories for non-amd64 architectures
# cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
# EOF
sudo apt-get update || true ;# Prevent failure due to missing URLs.
# sudo apt-get update || true ;# Prevent failure due to missing URLs.
sudo apt-get install -y --no-install-recommends \
build-essential \
gcc-14-riscv64-linux-gnu \
g++-14-riscv64-linux-gnu
# sudo apt-get install -y --no-install-recommends \
# build-essential \
# gcc-14-riscv64-linux-gnu \
# g++-14-riscv64-linux-gnu
- name: Build
run: |
cmake -B build -DLLAMA_CURL=OFF \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=riscv64 \
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
# - name: Build
# run: |
# cmake -B build -DLLAMA_CURL=OFF \
# -DCMAKE_BUILD_TYPE=Release \
# -DGGML_OPENMP=OFF \
# -DLLAMA_BUILD_EXAMPLES=ON \
# -DLLAMA_BUILD_TOOLS=ON \
# -DLLAMA_BUILD_TESTS=OFF \
# -DCMAKE_SYSTEM_NAME=Linux \
# -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
# -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
# -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
# -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
# -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
# -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
# -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
# -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)
# cmake --build build --config Release -j $(nproc)
# ubuntu-24-riscv64-vulkan-cross:
# runs-on: ubuntu-24.04
+2 -2
@@ -134,8 +134,8 @@ jobs:
include:
- build: 'x64'
os: ubuntu-22.04
- build: 's390x-z15' # z15 because our CI runners are on z15
os: ubuntu-22.04-s390x
- build: 's390x'
os: ubuntu-24.04-s390x
# GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
# - build: 'arm64'
# os: ubuntu-22.04-arm
+1
@@ -89,6 +89,7 @@
/src/llama-model-loader.* @slaren
/src/llama-model.* @CISC
/src/llama-vocab.* @CISC
/src/models/ @CISC
/tests/ @ggerganov
/tests/test-backend-ops.cpp @slaren
/tests/test-thread-safety.cpp @slaren
+15 -1
@@ -2030,7 +2030,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.system_prompt.pop_back();
}
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
add_opt(common_arg(
{"--in-file"}, "FNAME",
"an input file (repeat to specify multiple files)",
@@ -2768,6 +2768,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.image.emplace_back(value);
}
).set_examples({LLAMA_EXAMPLE_MTMD}));
add_opt(common_arg(
{"--image-min-tokens"}, "N",
"minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)",
[](common_params & params, int value) {
params.image_min_tokens = value;
}
).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MIN_TOKENS"));
add_opt(common_arg(
{"--image-max-tokens"}, "N",
"maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)",
[](common_params & params, int value) {
params.image_max_tokens = value;
}
).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MAX_TOKENS"));
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
+17 -2
@@ -313,7 +313,6 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
}
if (!msg.reasoning_content.empty()) {
jmsg["reasoning_content"] = msg.reasoning_content;
jmsg["thinking"] = msg.reasoning_content; // gpt-oss
}
if (!msg.tool_name.empty()) {
jmsg["name"] = msg.tool_name;
@@ -1810,7 +1809,23 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto prompt = apply(tmpl, inputs);
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) {
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
if (has_reasoning_content && has_tool_calls) {
auto adjusted_message = msg;
adjusted_message["thinking"] = msg.at("reasoning_content");
adjusted_messages.push_back(adjusted_message);
} else {
adjusted_messages.push_back(msg);
}
}
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
// Check if we need to replace the return token with end token during
// inference and without generation prompt. For more details see:
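The message adjustment in this hunk can be illustrated with a small Python sketch. It mirrors only the logic visible in the diff (copy `reasoning_content` into `thinking` when a message also carries `tool_calls`); the function name `copy_reasoning_to_thinking` is hypothetical and not part of llama.cpp.

```python
from typing import Any


def copy_reasoning_to_thinking(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return a new message list where tool-calling messages expose "thinking".

    Hypothetical helper: it imitates the loop shown in the diff, not the
    actual C++ implementation.
    """
    adjusted = []
    for msg in messages:
        has_reasoning = isinstance(msg.get("reasoning_content"), str)
        has_tool_calls = isinstance(msg.get("tool_calls"), list)
        if has_reasoning and has_tool_calls:
            # Copy the message so the caller's input is left untouched.
            adjusted.append({**msg, "thinking": msg["reasoning_content"]})
        else:
            adjusted.append(msg)
    return adjusted
```

Messages that have reasoning but no tool calls pass through unchanged, which matches the `else` branch in the hunk.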
+2
@@ -406,6 +406,8 @@ struct common_params {
bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model
std::vector<std::string> image; // path to image file(s)
int image_min_tokens = -1;
int image_max_tokens = -1;
// finetune
struct lr_opt lr;
+107
@@ -9802,6 +9802,113 @@ class CogVLMModel(LlamaModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("JanusForConditionalGeneration")
class JanusProModel(LlamaModel):
    model_arch = gguf.MODEL_ARCH.LLAMA  # reuse Llama arch

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # Skip vision, aligner, and generation tensors
        skip_prefixes = (
            'model.vision_model.',
            'model.aligner.',
            'model.vqmodel.',
            'model.generation_embeddings.',
            'model.generation_aligner.',
            'model.generation_head.',
        )
        if name.startswith(skip_prefixes):
            return []

        if name.startswith('model.language_model.'):
            name = name.replace('model.language_model.', 'model.')
        elif name.startswith('language_model.'):
            name = name.replace('language_model.', '')

        return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("JanusForConditionalGeneration")
class JanusProVisionModel(MmprojModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.hparams_vision is not None
        if "intermediate_size" not in self.hparams_vision:
            mlp_ratio = self.hparams_vision.get("mlp_ratio")
            hidden_size = self.hparams_vision.get("hidden_size")
            if mlp_ratio is not None and hidden_size is not None:
                self.hparams_vision["intermediate_size"] = int(round(hidden_size * mlp_ratio))

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        assert self.hparams_vision is not None

        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JANUS_PRO)

        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))

        hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
        if hidden_act == "gelu":
            self.gguf_writer.add_vision_use_gelu(True)
        elif hidden_act == "silu":
            self.gguf_writer.add_vision_use_silu(True)

    def _map_aligner_tensor(self, data_torch: Tensor, name: str) -> Iterable[tuple[str, Tensor]]:
        """Map aligner tensors to projector format"""
        suffix = ".bias" if name.endswith(".bias") else ".weight"

        if name.startswith("model.aligner."):
            local_name = name[len("model.aligner."):]
        elif name.startswith("aligner."):
            local_name = name[len("aligner."):]
        else:
            raise ValueError(f"Unsupported Janus aligner prefix: {name}")

        if local_name.startswith("fc1."):
            mm_index = 0
        elif local_name.startswith("hidden_layers."):
            parts = local_name.split(".", 2)
            if len(parts) < 3:
                raise ValueError(f"Unexpected Janus aligner tensor name: {name}")
            mm_index = int(parts[1]) + 1
        else:
            raise ValueError(f"Unsupported Janus aligner tensor: {name}")

        tensor_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, mm_index, suffix=suffix)
        return [(tensor_name, data_torch)]

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        # Skip language model tensors as they will be handled by `JanusProModel`
        if name.startswith(('model.language_model.', 'language_model.')):
            return []

        # Skip generation-related components
        skip_generation_prefixes = (
            'model.vqmodel.',
            'vqmodel.',
            'model.generation_embeddings.',
            'generation_embeddings.',
            'model.generation_aligner.',
            'generation_aligner.',
            'model.generation_head.',
            'generation_head.',
        )
        if name.startswith(skip_generation_prefixes):
            return []

        # Handle aligner tensors
        if name.startswith(('model.aligner.', 'aligner.')):
            return list(self._map_aligner_tensor(data_torch, name))

        # Handle vision tensors
        if name.startswith(('model.vision_model.', 'vision_model.')):
            return [(self.map_tensor_name(name), data_torch)]

        return []
###### CONVERSION LOGIC ######
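The index rule used by `_map_aligner_tensor` in the hunk above (`fc1` becomes projector layer 0, `hidden_layers.N` becomes layer N + 1) can be checked in isolation. The following is a minimal standalone sketch; `aligner_mm_index` is a hypothetical name, and it returns only the index rather than the formatted GGUF tensor name.

```python
def aligner_mm_index(name: str) -> int:
    """Return the projector layer index for a Janus aligner tensor name.

    Hypothetical helper that reproduces only the index arithmetic from the
    diff; it does not depend on gguf or on the converter classes.
    """
    for prefix in ("model.aligner.", "aligner."):
        if name.startswith(prefix):
            local_name = name[len(prefix):]
            break
    else:
        raise ValueError(f"Unsupported Janus aligner prefix: {name}")

    if local_name.startswith("fc1."):
        return 0
    if local_name.startswith("hidden_layers."):
        # "hidden_layers.<N>.<weight|bias>" splits into exactly three parts.
        parts = local_name.split(".", 2)
        if len(parts) < 3:
            raise ValueError(f"Unexpected Janus aligner tensor name: {name}")
        return int(parts[1]) + 1
    raise ValueError(f"Unsupported Janus aligner tensor: {name}")
```

For example, `aligner_mm_index("model.aligner.hidden_layers.0.bias")` yields 1, so the first hidden layer lands directly after `fc1`.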
+3 -3
@@ -7,9 +7,9 @@
## Images
We have three Docker images available for this project:
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`)
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`)
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`)
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
Additionally, there are the following images, similar to the above:
@@ -138,6 +138,9 @@ if model_path is None:
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
)
print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
print("Model type: ", config.model_type)
@@ -147,10 +150,6 @@ print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if unreleased_model_name:
model_name_lower = unreleased_model_name.lower()
unreleased_module_path = (
@@ -171,7 +170,7 @@ if unreleased_model_name:
exit(1)
else:
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
)
for name, module in model.named_modules():
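The commit message for this change says that passing `config` explicitly allows the config object to be modified before loading. A rough sketch of that workflow follows. The helper name `load_model_with_config_overrides` and the idea of passing overrides as keyword arguments are assumptions made for illustration; only the `from_pretrained` calls and their arguments come from the diff.

```python
def load_model_with_config_overrides(model_path: str, **overrides):
    """Load a causal LM after patching fields on its config object.

    Hypothetical helper, not part of run-org-model.py.
    """
    # Imported lazily so this module can be imported without transformers.
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    for key, value in overrides.items():
        # Example: num_hidden_layers=2 to shrink a model for a quick test.
        setattr(config, key, value)

    return AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        device_map="auto",
        offload_folder="offload",
        trust_remote_code=True,
    )
```

A call such as `load_model_with_config_overrides(model_path, num_hidden_layers=2)` would then load with the patched config. Whether a given field can be changed safely depends on the model architecture.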
+6 -3
@@ -308,6 +308,10 @@ function(ggml_add_cpu_backend_variant tag_name)
set(GGML_INTERNAL_${feat} ON)
endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
foreach (feat VXE2 NNPA)
set(GGML_INTERNAL_${feat} OFF)
endforeach()
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
@@ -377,9 +381,8 @@ if (GGML_CPU_ALL_VARIANTS)
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
# ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
# ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
ggml_add_cpu_backend_variant(z15 Z15 VXE2)
ggml_add_cpu_backend_variant(z16 Z16 VXE2 NNPA)
else()
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
endif()
+10 -3
@@ -504,11 +504,18 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
endforeach()
endif()
if (GGML_VXE OR GGML_INTERNAL_VXE)
message(STATUS "VX/VXE/VXE2 enabled")
if (GGML_VXE OR GGML_INTERNAL_VXE2)
message(STATUS "VXE2 enabled")
list(APPEND ARCH_FLAGS -mvx -mzvector)
list(APPEND ARCH_DEFINITIONS GGML_VXE)
list(APPEND ARCH_DEFINITIONS GGML_USE_VXE2)
endif()
if (GGML_INTERNAL_NNPA)
message(STATUS "NNPA enabled")
list(APPEND ARCH_DEFINITIONS GGML_USE_NNPA)
endif()
ggml_add_cpu_backend_features(${GGML_CPU_NAME} s390 ${ARCH_DEFINITIONS})
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
message(STATUS "Wasm detected")
list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)
+4 -5
@@ -700,7 +700,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
for (; ib + 1 < nb; ib += 2) {
// Compute combined scale for the block 0 and 1
const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
const float ft0 = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);
const __m128 d_0_1 = (__m128)(v4f32){ft0, ft0, ft0, ft0};
const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
@@ -714,11 +715,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
bx_1 = __lsx_vsub_b(bx_1, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
//_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
//_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
// Compute combined scale for the block 2 and 3
const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );
const float ft1 = GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d);
const __m128 d_2_3 = (__m128)(v4f32){ft1, ft1, ft1, ft1};
const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
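One plausible reading of why the scale broadcast changed here, stated as an assumption rather than taken from the commit: the replaced intrinsic replicates a general-purpose integer register, so a `float` argument would first be converted to an integer and its bits then reinterpreted as a float. The Python sketch below models that effect for a single lane using `struct`; both function names are hypothetical.

```python
import struct


def broadcast_via_int_register(value: float) -> float:
    """Model one lane when a float is passed through an integer parameter.

    Assumption: C converts the float to int (truncating toward zero), and
    the vector cast then reinterprets those integer bits as a float.
    """
    as_int = int(value)
    return struct.unpack("<f", struct.pack("<i", as_int))[0]


def broadcast_as_float(value: float) -> float:
    """Model one lane of a float vector literal such as {v, v, v, v}."""
    return struct.unpack("<f", struct.pack("<f", value))[0]
```

Under this model a scale of 0.5 becomes 0.0 on the integer path and stays 0.5 on the float path, which would be consistent with the test failures named in the commit message.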
+50
@@ -0,0 +1,50 @@
#include "ggml-backend-impl.h"

#if defined(__s390x__)
#include <sys/auxv.h>

// find hwcap bits in asm/elf.h
#ifndef HWCAP_VXRS_EXT2
#define HWCAP_VXRS_EXT2 (1 << 15)
#endif

#ifndef HWCAP_NNPA
#define HWCAP_NNPA (1 << 20)
#endif

struct s390x_features {
    bool has_vxe2 = false;
    bool has_nnpa = false;

    s390x_features() {
        uint32_t hwcap = getauxval(AT_HWCAP);
        // NOTE: use hwcap2 with DFLT for z17 and later
        // uint32_t hwcap2 = getauxval(AT_HWCAP2);

        has_vxe2 = !!(hwcap & HWCAP_VXRS_EXT2);
        has_nnpa = !!(hwcap & HWCAP_NNPA);
    }
};

static int ggml_backend_cpu_s390x_score() {
    int score = 1;
    s390x_features sf;

    // IBM z15 / LinuxONE 3
#ifdef GGML_USE_VXE2
    if (!sf.has_vxe2) { return 0; }
    score += 1 << 1;
#endif

    // IBM z16 / LinuxONE 4 and z17 / LinuxONE 5
#ifdef GGML_USE_NNPA
    if (!sf.has_nnpa) { return 0; }
    score += 1 << 2;
#endif

    return score;
}

GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_s390x_score)

#endif // __s390x__
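The scoring function in this new file can be modelled in a few lines of Python to see which CPU variant would win on which machine. This is a sketch under stated assumptions: the compile-time `GGML_USE_*` macros are represented as boolean arguments, the function names are hypothetical, and it assumes the backend loader prefers the variant with the highest non-zero score.

```python
HWCAP_VXRS_EXT2 = 1 << 15  # bit values as defined in the diff
HWCAP_NNPA = 1 << 20


def s390x_variant_score(hwcap: int, built_with_vxe2: bool, built_with_nnpa: bool) -> int:
    """Mirror ggml_backend_cpu_s390x_score for one build variant."""
    has_vxe2 = bool(hwcap & HWCAP_VXRS_EXT2)
    has_nnpa = bool(hwcap & HWCAP_NNPA)

    score = 1
    if built_with_vxe2:
        if not has_vxe2:
            return 0  # variant needs a feature this CPU lacks
        score += 1 << 1
    if built_with_nnpa:
        if not has_nnpa:
            return 0
        score += 1 << 2
    return score


def pick_variant(hwcap: int) -> str:
    """Choose between the two variants registered in the CMake hunk."""
    variants = {"z15": (True, False), "z16": (True, True)}
    scores = {name: s390x_variant_score(hwcap, *flags) for name, flags in variants.items()}
    return max(scores, key=scores.get)
```

With only the VXE2 bit set, the `z16` variant scores 0 and `z15` is chosen; with both bits set, `z16` scores 7 against 3 for `z15`.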
+3 -1
@@ -500,13 +500,15 @@ inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
#endif
#if defined(__loongarch_asx)
#if defined(__loongarch_sx)
/* float type data load instructions */
static __m128 __lsx_vreplfr2vr_s(const float val) {
v4f32 res = {val, val, val, val};
return (__m128)res;
}
#endif
#if defined(__loongarch_asx)
static __m256 __lasx_xvreplfr2vr_s(const float val) {
v8f32 res = {val, val, val, val, val, val, val, val};
return (__m256)res;
+25 -25
@@ -956,7 +956,7 @@ do { \
#define GGML_F32Cx8 __m256
#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
__m256i a;
@@ -999,34 +999,34 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
#define GGML_F32x4 __m128
#define GGML_F32x4_ZERO (__m128)__lsx_vldi(0)
#define GGML_F32x4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
#define GGML_F32x4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
#define GGML_F32x4_LOAD(x) (__m128)__lsx_vld((x), 0)
#define GGML_F32x4_STORE(x, y) __lsx_vst(y, x, 0)
#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
#define GGML_F32x4_ADD __lsx_vfadd_s
#define GGML_F32x4_MUL __lsx_vfmul_s
#define GGML_F32x4_REDUCE(res, x) \
{ \
int offset = GGML_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
} \
__m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \
tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \
tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
const __m128 t0 = (__m128)__lsx_vshuf4i_w(tmp, 0x88); \
tmp = __lsx_vsrli_d((__m128i) t0, 32); \
tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \
tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
#define GGML_F32x4_REDUCE(res, x) \
{ \
int offset = GGML_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
} \
__m128i t0 = __lsx_vpickev_w((__m128i)x[0], (__m128i)x[0]); \
__m128i t1 = __lsx_vpickod_w((__m128i)x[0], (__m128i)x[0]); \
__m128 t2 = __lsx_vfadd_s((__m128)t0, (__m128)t1); \
__m128i t3 = __lsx_vpickev_w((__m128i)t2, (__m128i)t2); \
__m128i t4 = __lsx_vpickod_w((__m128i)t2, (__m128i)t2); \
__m128 t5 = __lsx_vfadd_s((__m128)t3, (__m128)t4); \
res = (ggml_float) ((v4f32)t5)[0]; \
}
#define GGML_F32_VEC GGML_F32x4
@@ -1068,7 +1068,7 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
#define GGML_F32Cx4 __m128
#define GGML_F32Cx4_ZERO (__m128)__lsx_vldi(0)
#define GGML_F32Cx4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
#define GGML_F32Cx4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
#define GGML_F32Cx4_LOAD(x) (__m128)__lsx_f16x4_load(x)
#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
#define GGML_F32Cx4_FMA GGML_F32x4_FMA
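The final steps of the rewritten `GGML_F32x4_REDUCE` macro sum the four lanes of one vector using even and odd lane picks. A small Python model shows why lane 0 ends up holding the full sum. The lane ordering assumed for `pickev_w` and `pickod_w` is an assumption about the intrinsics; since both operands are the same vector in the macro, the result does not depend on which operand fills which half.

```python
def pickev_w(a: list[float], b: list[float]) -> list[float]:
    """Model of an even-lane pick: even lanes of b, then even lanes of a."""
    return [b[0], b[2], a[0], a[2]]


def pickod_w(a: list[float], b: list[float]) -> list[float]:
    """Model of an odd-lane pick: odd lanes of b, then odd lanes of a."""
    return [b[1], b[3], a[1], a[3]]


def vfadd_s(a: list[float], b: list[float]) -> list[float]:
    """Lane-wise float addition."""
    return [x + y for x, y in zip(a, b)]


def reduce_f32x4(x: list[float]) -> float:
    """Follow the t0..t5 steps of the macro for a single 4-lane vector."""
    t0 = pickev_w(x, x)   # [x0, x2, x0, x2]
    t1 = pickod_w(x, x)   # [x1, x3, x1, x3]
    t2 = vfadd_s(t0, t1)  # [x0+x1, x2+x3, x0+x1, x2+x3]
    t3 = pickev_w(t2, t2)
    t4 = pickod_w(t2, t2)
    t5 = vfadd_s(t3, t4)  # every lane holds x0+x1+x2+x3
    return t5[0]
```

The model covers only the last stage; the preceding loops in the macro, which fold the accumulator array pairwise, are omitted.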
+4
@@ -14,6 +14,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
} break;
case 72: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 72, 72>(ctx, dst);
} break;
case 80: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
+29 -2
@@ -6,7 +6,7 @@
// nbatch_K == number of K columns to load in parallel for KQ calculation
// TODO optimize kernel parameters for FP16 NVIDIA (P100)
// TODO optimize kernel parameters for head sizes 40, 80, 96, 112
// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112
// The ROCm compiler cannot handle templating in __launch_bounds__.
// As a workaround, define a macro to package the kernel parameters as uint32_t:
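The comment above describes packaging the kernel parameters into a single `uint32_t`. The real bit layout is not visible in this diff, so the following is only a sketch under a HYPOTHETICAL layout; it also assumes the last four `GGML_CUDA_FATTN_TILE_CONFIG_CASE` arguments are `nthreads`, `occupancy`, `nbatch_fa` and `nbatch_K`. The function names are made up for illustration:

```cpp
#include <cassert>
#include <cstdint>

// HYPOTHETICAL layout, chosen for illustration only:
//   bits  0-9  : nthreads
//   bits 10-13 : occupancy
//   bits 14-22 : nbatch_fa
//   bits 23-31 : nbatch_K
static constexpr uint32_t pack_config(uint32_t nthreads, uint32_t occupancy,
                                      uint32_t nbatch_fa, uint32_t nbatch_K) {
    return (nthreads & 0x3FFu) | ((occupancy & 0xFu) << 10) |
           ((nbatch_fa & 0x1FFu) << 14) | ((nbatch_K & 0x1FFu) << 23);
}

// Accessors that undo the packing; all usable in constant expressions.
static constexpr uint32_t config_nthreads (uint32_t c) { return c & 0x3FFu; }
static constexpr uint32_t config_occupancy(uint32_t c) { return (c >> 10) & 0xFu; }
static constexpr uint32_t config_nbatch_fa(uint32_t c) { return (c >> 14) & 0x1FFu; }
static constexpr uint32_t config_nbatch_K (uint32_t c) { return (c >> 23) & 0x1FFu; }

static_assert(config_nthreads(pack_config(256, 2, 64, 72)) == 256, "round trip");
```

Because every accessor is `constexpr`, a value unpacked this way can feed a compile-time attribute without templating on each parameter separately, which is the kind of workaround the comment alludes to.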
@@ -32,6 +32,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40)
@@ -80,6 +86,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
@@ -130,6 +142,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
@@ -185,6 +204,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
@@ -723,7 +749,7 @@ static __global__ void flash_attn_tile(
if (
#ifdef GGML_USE_WMMA_FATTN
(ncols2 != 1 && DV != 40 && DV != 512) ||
(ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
#endif // GGML_USE_WMMA_FATTN
(use_logit_softcap && !(DV == 128 || DV == 256))
) {
@@ -1198,6 +1224,7 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
extern DECL_FATTN_TILE_CASE( 40, 40);
extern DECL_FATTN_TILE_CASE( 64, 64);
extern DECL_FATTN_TILE_CASE( 72, 72);
extern DECL_FATTN_TILE_CASE( 80, 80);
extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
+3 -2
@@ -223,6 +223,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
switch (K->ne[0]) {
case 40:
case 64:
case 72:
case 80:
case 96:
case 128:
@@ -275,7 +276,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
// If Turing tensor cores available, use them:
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
@@ -301,7 +302,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
+33
@@ -2115,6 +2115,14 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
ggml_backend_buft_is_cuda_split(src1->buffer->buft);
//TODO: add support for fusion for split buffers
if (split) {
return false;
}
//we only support fusion for ncols_dst = 1
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
return false;
@@ -2154,6 +2162,15 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
return false;
}
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
ggml_backend_buft_is_cuda_split(src1->buffer->buft);
//TODO: add support for fusion for split buffers
if (split) {
return false;
}
return use_mul_mat_vec_q;
}
@@ -2499,6 +2516,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_XIELU:
ggml_cuda_op_xielu(ctx, dst);
break;
case GGML_UNARY_OP_FLOOR:
ggml_cuda_op_floor(ctx, dst);
break;
case GGML_UNARY_OP_CEIL:
ggml_cuda_op_ceil(ctx, dst);
break;
case GGML_UNARY_OP_ROUND:
ggml_cuda_op_round(ctx, dst);
break;
case GGML_UNARY_OP_TRUNC:
ggml_cuda_op_trunc(ctx, dst);
break;
default:
return false;
}
@@ -3769,6 +3798,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(72, 72);
@@ -3,7 +3,7 @@
from glob import glob
import os
HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
@@ -81,6 +81,8 @@ for ncols in [8, 16, 32, 64]:
for head_size_kq in HEAD_SIZES_KQ:
if head_size_kq == 40:
continue
if head_size_kq == 72:
continue
if head_size_kq != 576 and ncols2 == 16:
continue
if head_size_kq == 576 and ncols2 != 16:
+32
@@ -85,6 +85,22 @@ static __device__ __forceinline__ float op_elu(float x) {
return (x > 0.f) ? x : expm1f(x);
}
static __device__ __forceinline__ float op_floor(float x) {
return floorf(x);
}
static __device__ __forceinline__ float op_ceil(float x) {
return ceilf(x);
}
static __device__ __forceinline__ float op_round(float x) {
return round(x);
}
static __device__ __forceinline__ float op_trunc(float x) {
return trunc(x);
}
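The four new unary ops differ only in how they treat the fractional part. A host-side sketch using the standard `<cmath>` functions, assuming the device functions share their semantics; `Rounded` and `round_all` are illustrative names:

```cpp
#include <cassert>
#include <cmath>

struct Rounded {
    float floor_v; // largest integer <= x
    float ceil_v;  // smallest integer >= x
    float round_v; // nearest integer, halfway cases away from zero
    float trunc_v; // fractional part dropped (rounds toward zero)
};

static Rounded round_all(float x) {
    return { std::floor(x), std::ceil(x), std::round(x), std::trunc(x) };
}
```

The differences show up on negative and halfway inputs: for `-1.5f`, floor and round give `-2` while ceil and trunc give `-1`.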
template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -201,6 +217,22 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_elu>(ctx, dst);
}
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_floor>(ctx, dst);
}
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_ceil>(ctx, dst);
}
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_round>(ctx, dst);
}
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_trunc>(ctx, dst);
}
/* gated ops */
template <float (*op)(float), typename T>
+8
@@ -63,6 +63,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+1
@@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
if (op->src[0]->ne[0] != 32 &&
op->src[0]->ne[0] != 40 &&
op->src[0]->ne[0] != 64 &&
op->src[0]->ne[0] != 72 &&
op->src[0]->ne[0] != 80 &&
op->src[0]->ne[0] != 96 &&
op->src[0]->ne[0] != 112 &&
+8
@@ -5362,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, hal
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
@@ -5374,6 +5375,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
@@ -5387,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
@@ -5400,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
@@ -5412,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
@@ -5424,6 +5429,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
@@ -5436,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
@@ -5448,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
+6
@@ -8399,6 +8399,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
const bool is_neox = mode & 2;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
const int is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
@@ -8489,9 +8490,14 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
// both mrope and vision kernels have sections
if (is_mrope || is_vision) {
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));
}
// only mrope has is_imrope
if (is_mrope && !is_vision) {
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope));
}
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
+50 -24
@@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32(
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
int4 sections,
int is_imrope
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
@@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32(
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
theta_base = (float) pos[i2 + ne02 * 1];
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
theta_base = (float) pos[i2 + ne02 * 2];
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
theta_base = (float) pos[i2 + ne02 * 0];
} else { // e
theta_base = (float) pos[i2 + ne02 * 3];
}
} else {
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
@@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16(
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
int4 sections,
int is_imrope
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
@@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16(
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
theta_base = (float) pos[i2 + ne02 * 1];
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
theta_base = (float) pos[i2 + ne02 * 2];
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
theta_base = (float) pos[i2 + ne02 * 0];
} else { // e
theta_base = (float) pos[i2 + ne02 * 3];
}
} else {
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
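The branch added to both kernels picks which of the four position streams (t, h, w, e) a sector reads from. A host-side sketch of that selection, where the returned index `k` corresponds to the `pos[i2 + ne02 * k]` lookup. `position_stream` is an illustrative name, and `sec_w = sections[0] + sections[1]` is an assumption, since its definition lies outside the hunk:

```cpp
#include <array>
#include <cassert>

// Returns the position stream for a sector: 0 = t, 1 = h, 2 = w, 3 = e.
static int position_stream(int sector, const std::array<int, 4> & sections, bool is_imrope) {
    assert(sector >= 0);
    if (is_imrope) {
        // Interleaved layout: t, h, w repeat every three sectors while
        // the corresponding section still has entries left.
        if (sector % 3 == 1 && sector < 3 * sections[1]) return 1; // h
        if (sector % 3 == 2 && sector < 3 * sections[2]) return 2; // w
        if (sector % 3 == 0 && sector < 3 * sections[0]) return 0; // t
        return 3;                                                  // e
    }
    // Contiguous layout: all t sectors, then h, then w, then e.
    const int sec_w = sections[0] + sections[1]; // ASSUMPTION, see lead-in
    if (sector < sections[0])         return 0;
    if (sector < sec_w)               return 1;
    if (sector < sec_w + sections[2]) return 2;
    return 3;
}
```

With sections `{2, 3, 3, 0}`, sector 3 maps to t in the interleaved layout but to h in the contiguous one.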
+45 -25
@@ -2,26 +2,43 @@
#include "common.hpp"
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
#define GGML_ASSERT_TENSOR_FITS_INT(t) \
GGML_ASSERT((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX)
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const float * src0_dd = (const float *) dst->src[0]->data;
float * dst_dd = (float *) dst->data;
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
ne03 = dst->src[0]->ne[3];
GGML_ASSERT_TENSOR_FITS_INT(dst);
GGML_ASSERT_TENSOR_FITS_INT(dst->src[0]);
const int nr0 = (int) (ne00 / ne0);
const int nr1 = (int) (ne01 / ne1);
const int nr2 = (int) (ne02 / ne2);
const int nr3 = (int) (ne03 / ne3);
const int ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
const int ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
ne03 = dst->src[0]->ne[3];
const size_t total = ne0 * ne1 * ne2 * ne3;
const int BLOCK_SIZE = 256;
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int nr0 = ne00 / ne0;
const int nr1 = ne01 / ne1;
const int nr2 = ne02 / ne2;
const int nr3 = ne03 / ne3;
const int nb0 = dst->src[0]->nb[0];
const int nb1 = dst->src[0]->nb[1];
const int nb2 = dst->src[0]->nb[2];
const int nb3 = dst->src[0]->nb[3];
const char * base = (const char *) src0_dd;
const size_t total = (size_t) ne0 * ne1 * ne2 * ne3;
constexpr int BLOCK_SIZE = 256;
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
const float inv_ne0 = 1.0f / ne0;
const float inv_ne_01 = 1.0f / (ne0 * ne1);
const float inv_ne_012 = 1.0f / (ne0 * ne1 * ne2);
const int repeat_count = nr0 * nr1 * nr2 * nr3;
queue_ptr stream = ctx.stream();
@@ -33,24 +50,27 @@ void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst
return;
}
const int i0 = i % ne0;
const int i1 = (i / ne0) % ne1;
const int i2 = (i / (ne0 * ne1)) % ne2;
const int i3 = i / (ne0 * ne1 * ne2);
const int i3 = (int) (i * inv_ne_012);
const int i2 = (int) (i * inv_ne_01) - i3 * ne2;
const int i1 = (int) (i * inv_ne0) - (int) (i * inv_ne_01) * ne1;
const int i0 = i - (int) (i * inv_ne0) * ne0;
int j0 = 0, j1 = 0, j2 = 0, j3 = 0;
float acc = 0.0f;
for (int j3 = 0; j3 < nr3; ++j3) {
for (int j2 = 0; j2 < nr2; ++j2) {
for (int j1 = 0; j1 < nr1; ++j1) {
for (int j0 = 0; j0 < nr0; ++j0) {
acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
(i3 + j3 * ne3) * ne00 * ne01 * ne02];
}
}
}
}
for (int j = 0; j < repeat_count; ++j) {
const float * ptr = (const float *) (base + (i0 + j0 * ne0) * nb0 + (i1 + j1 * ne1) * nb1 +
(i2 + j2 * ne2) * nb2 + (i3 + j3 * ne3) * nb3);
acc += *ptr;
int carry = (++j0 >= nr0);
j0 -= carry * nr0;
carry = (carry && (++j1 >= nr1));
j1 -= carry * nr1;
carry = (carry && (++j2 >= nr2));
j2 -= carry * nr2;
j3 += carry;
}
dst_dd[i] = acc;
});
}
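The rewritten kernel replaces four nested loops with one flat loop whose counters advance like an odometer, carrying into the next counter on overflow. A small host-side sketch, checking that both forms visit the same `(j0, j1, j2, j3)` tuples in the same order; `Idx`, `nested` and `flattened` are illustrative names:

```cpp
#include <array>
#include <cassert>
#include <vector>

using Idx = std::array<int, 4>;

// Reference: four nested loops, j0 innermost.
static std::vector<Idx> nested(int nr0, int nr1, int nr2, int nr3) {
    std::vector<Idx> out;
    for (int j3 = 0; j3 < nr3; ++j3)
        for (int j2 = 0; j2 < nr2; ++j2)
            for (int j1 = 0; j1 < nr1; ++j1)
                for (int j0 = 0; j0 < nr0; ++j0)
                    out.push_back({j0, j1, j2, j3});
    return out;
}

// Flattened: a single loop; each counter wraps and carries into the next.
static std::vector<Idx> flattened(int nr0, int nr1, int nr2, int nr3) {
    assert(nr0 > 0 && nr1 > 0 && nr2 > 0 && nr3 > 0);
    std::vector<Idx> out;
    int j0 = 0, j1 = 0, j2 = 0, j3 = 0;
    const int repeat_count = nr0 * nr1 * nr2 * nr3;
    for (int j = 0; j < repeat_count; ++j) {
        out.push_back({j0, j1, j2, j3});
        int carry = (++j0 >= nr0);         // did j0 overflow?
        j0 -= carry * nr0;                 // wrap j0 back to 0 if so
        carry = (carry && (++j1 >= nr1));  // j1 only advances on a carry
        j1 -= carry * nr1;
        carry = (carry && (++j2 >= nr2));
        j2 -= carry * nr2;
        j3 += carry;
    }
    return out;
}
```

The short-circuit `&&` matters here: the increment on the right-hand side runs only when the previous counter carried.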
-2
@@ -4274,8 +4274,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
vk12_features.runtimeDescriptorArray &&
device->vendor_id != VK_VENDOR_ID_INTEL &&
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
device->shader_int64 = device_features2.features.shaderInt64;
@@ -23,16 +23,100 @@ layout (push_constant) uniform parameter2
uint rms_partials;
} p;
// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
// No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
layout (binding = 0) buffer A0 {A_TYPE data_a[];} a0;
layout (binding = 1) buffer A1 {A_TYPE data_a[];} a1;
layout (binding = 2) buffer A2 {A_TYPE data_a[];} a2;
layout (binding = 3) buffer A3 {A_TYPE data_a[];} a3;
layout (binding = 4) buffer A4 {A_TYPE data_a[];} a4;
layout (binding = 5) buffer A5 {A_TYPE data_a[];} a5;
layout (binding = 6) buffer A6 {A_TYPE data_a[];} a6;
layout (binding = 7) buffer A7 {A_TYPE data_a[];} a7;
layout (binding = 8) buffer A8 {A_TYPE data_a[];} a8;
layout (binding = 9) buffer A9 {A_TYPE data_a[];} a9;
layout (binding = 10) buffer A10 {A_TYPE data_a[];} a10;
layout (binding = 11) buffer A11 {A_TYPE data_a[];} a11;
layout (binding = 0) buffer D0 {D_TYPE data_d[];} d0;
layout (binding = 1) buffer D1 {D_TYPE data_d[];} d1;
layout (binding = 2) buffer D2 {D_TYPE data_d[];} d2;
layout (binding = 3) buffer D3 {D_TYPE data_d[];} d3;
layout (binding = 4) buffer D4 {D_TYPE data_d[];} d4;
layout (binding = 5) buffer D5 {D_TYPE data_d[];} d5;
layout (binding = 6) buffer D6 {D_TYPE data_d[];} d6;
layout (binding = 7) buffer D7 {D_TYPE data_d[];} d7;
layout (binding = 8) buffer D8 {D_TYPE data_d[];} d8;
layout (binding = 9) buffer D9 {D_TYPE data_d[];} d9;
layout (binding = 10) buffer D10 {D_TYPE data_d[];} d10;
layout (binding = 11) buffer D11 {D_TYPE data_d[];} d11;
layout (binding = 0, std430) buffer PartialBuf0 {float partial_sums[];} partials0;
layout (binding = 1, std430) buffer PartialBuf1 {float partial_sums[];} partials1;
layout (binding = 2, std430) buffer PartialBuf2 {float partial_sums[];} partials2;
layout (binding = 3, std430) buffer PartialBuf3 {float partial_sums[];} partials3;
layout (binding = 4, std430) buffer PartialBuf4 {float partial_sums[];} partials4;
layout (binding = 5, std430) buffer PartialBuf5 {float partial_sums[];} partials5;
layout (binding = 6, std430) buffer PartialBuf6 {float partial_sums[];} partials6;
layout (binding = 7, std430) buffer PartialBuf7 {float partial_sums[];} partials7;
layout (binding = 8, std430) buffer PartialBuf8 {float partial_sums[];} partials8;
layout (binding = 9, std430) buffer PartialBuf9 {float partial_sums[];} partials9;
layout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10;
layout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11;
layout(constant_id = 0) const uint num_srcs = 2;
FLOAT_TYPE load_a(uint b, uint i) {
switch (b) {
case 0: return FLOAT_TYPE(a0.data_a[i]);
case 1: return FLOAT_TYPE(a1.data_a[i]);
case 2: return FLOAT_TYPE(a2.data_a[i]);
case 3: return FLOAT_TYPE(a3.data_a[i]);
case 4: return FLOAT_TYPE(a4.data_a[i]);
case 5: return FLOAT_TYPE(a5.data_a[i]);
case 6: return FLOAT_TYPE(a6.data_a[i]);
case 7: return FLOAT_TYPE(a7.data_a[i]);
case 8: return FLOAT_TYPE(a8.data_a[i]);
case 9: return FLOAT_TYPE(a9.data_a[i]);
case 10: return FLOAT_TYPE(a10.data_a[i]);
case 11: return FLOAT_TYPE(a11.data_a[i]);
default: return FLOAT_TYPE(0);
}
}
void store_d(uint b, uint i, FLOAT_TYPE v) {
switch (b) {
case 0: d0.data_d[i] = D_TYPE(v); break;
case 1: d1.data_d[i] = D_TYPE(v); break;
case 2: d2.data_d[i] = D_TYPE(v); break;
case 3: d3.data_d[i] = D_TYPE(v); break;
case 4: d4.data_d[i] = D_TYPE(v); break;
case 5: d5.data_d[i] = D_TYPE(v); break;
case 6: d6.data_d[i] = D_TYPE(v); break;
case 7: d7.data_d[i] = D_TYPE(v); break;
case 8: d8.data_d[i] = D_TYPE(v); break;
case 9: d9.data_d[i] = D_TYPE(v); break;
case 10: d10.data_d[i] = D_TYPE(v); break;
case 11: d11.data_d[i] = D_TYPE(v); break;
default: break;
}
}
void store_partial(uint b, uint i, float v) {
switch (b) {
case 0: partials0.partial_sums[i] = v; break;
case 1: partials1.partial_sums[i] = v; break;
case 2: partials2.partial_sums[i] = v; break;
case 3: partials3.partial_sums[i] = v; break;
case 4: partials4.partial_sums[i] = v; break;
case 5: partials5.partial_sums[i] = v; break;
case 6: partials6.partial_sums[i] = v; break;
case 7: partials7.partial_sums[i] = v; break;
case 8: partials8.partial_sums[i] = v; break;
case 9: partials9.partial_sums[i] = v; break;
case 10: partials10.partial_sums[i] = v; break;
case 11: partials11.partial_sums[i] = v; break;
default: break;
}
}
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
}
@@ -78,10 +162,10 @@ void main() {
FLOAT_TYPE sum = FLOAT_TYPE(0);
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
sum += load_a(s, src_idx(s, i00, i01, i02, i03));
}
sum_sq += sum*sum;
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum);
idx += num_threads;
}
@@ -104,7 +188,7 @@ void main() {
}
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq);
}
}
#endif
+1
View File
@@ -3186,6 +3186,7 @@ class VisionProjectorType:
KIMIVL = "kimivl"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
# Items here are (block size, type size)
+2
View File
@@ -1183,6 +1183,7 @@ class TensorNameMap:
"model.mm_projector.mlp.mlp.{bid}",
"vision_model.vision_adapter.mlp.fc{bid}", # llama 4
"mlp1.{bid}", # InternVL
"model.aligner.fc1.hidden_layers.{bid}", # Janus Pro
),
MODEL_TENSOR.V_MMPROJ_PEG: (
@@ -1291,6 +1292,7 @@ class TensorNameMap:
"model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
"model.vision_model.encoder.layers.{bid}.self_attn.projection_layer", # Janus Pro
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
+4 -3
View File
@@ -461,7 +461,10 @@ extern "C" {
LLAMA_API bool llama_supports_gpu_offload(void);
LLAMA_API bool llama_supports_rpc (void);
// NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
// In some cases the requested values via llama_context_params may differ from the actual values used by the context
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@@ -585,7 +588,7 @@ extern "C" {
LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
// NOTE: loaded adapters will be freed when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
// Get the invocation tokens if the current lora is an alora
@@ -1111,8 +1114,6 @@ extern "C" {
// // sample from the logits of the last token in the batch
// const llama_token id = llama_sampler_sample(smpl, ctx, -1);
//
// // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
// llama_sampler_accept(smpl, id);
// ...
// }
//
+74
View File
@@ -0,0 +1,74 @@
#!/usr/bin/env bash
RESULTS="bench-models-results.txt"
: > "$RESULTS"
ARGS_BB="-c 270336 -npp 512,4096,8192 -npl 1,2,4,8,16,32 -ntg 32"
ARGS_B="-d 0,4096,8192,16384,32768 -p 2048 -n 32"
QUICK=0
while (( "$#" )); do
case "$1" in
--quick) QUICK=1; shift ;;
*) shift ;;
esac
done
if (( QUICK )); then
ARGS_BB="-c 20480 -npp 512,4096 -npl 1,2,4 -ntg 32"
ARGS_B="-d 0 -p 2048 -n 32"
fi
run_model() {
local HFR=$1
local HFF=$2
printf "## %s\n" "${HFR}" | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
printf "Model: https://huggingface.co/%s\n" "${HFR}" | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
printf -- "- \`llama-batched-bench\`\n" | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
./bin/llama-batched-bench \
-hfr "${HFR}" -hff "${HFF}" \
-m "${HFF}" -fa 1 -ub 2048 --no-mmap \
${ARGS_BB} | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
printf -- "- \`llama-bench\`\n" | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
./bin/llama-bench \
-m "${HFF}" -fa 1 -ub 2048 -mmp 0 \
${ARGS_B} | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
printf "\n"
}
run_model "ggml-org/gpt-oss-20b-GGUF" "gpt-oss-20b-mxfp4.gguf"
run_model "ggml-org/gpt-oss-120b-GGUF" "gpt-oss-120b-mxfp4-00001-of-00003.gguf"
run_model "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF" "qwen3-coder-30b-a3b-instruct-q8_0.gguf"
run_model "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF" "qwen2.5-coder-7b-q8_0.gguf"
run_model "ggml-org/gemma-3-4b-it-qat-GGUF" "gemma-3-4b-it-qat-Q4_0.gguf"
if [[ -f models-extra.txt ]]; then
while read -r HFR HFF; do
[[ -z "$HFR" ]] && continue
run_model "$HFR" "$HFF"
done < models-extra.txt
fi
printf "\n=====================================\n"
printf "\n"
cat "$RESULTS"
printf "\n"
printf "Done! Results are written to %s\n" "$RESULTS"
printf "\n"
+27 -10
View File
@@ -112,11 +112,24 @@ llama_context::llama_context(
}
}
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
if (cparams.kv_unified) {
cparams.n_ctx_seq = cparams.n_ctx;
} else {
cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
if (cparams.n_ctx_seq == 0) {
throw std::runtime_error("n_ctx_seq == 0");
}
if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
}
}
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -125,14 +138,14 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
if (n_ctx_per_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}
if (n_ctx_per_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}
if (!hparams.vocab_only) {
@@ -453,8 +466,8 @@ uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
uint32_t llama_context::n_ctx_per_seq() const {
return cparams.n_ctx / cparams.n_seq_max;
uint32_t llama_context::n_ctx_seq() const {
return cparams.n_ctx_seq;
}
uint32_t llama_context::n_batch() const {
@@ -2383,6 +2396,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
return ctx->n_ctx();
}
uint32_t llama_n_ctx_seq(const llama_context * ctx) {
return ctx->n_ctx_seq();
}
uint32_t llama_n_batch(const llama_context * ctx) {
return ctx->n_batch();
}
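The per-sequence context rule added to the `llama_context` constructor in this file can be modelled outside of C++. The following is a minimal Python sketch of that rule, not the library's code: the function name `resolve_n_ctx_seq` and the tuple return value are assumptions made for illustration, and the warning is reduced to a comment.

```python
from typing import Tuple


def resolve_n_ctx_seq(n_ctx: int, n_seq_max: int, kv_unified: bool) -> Tuple[int, int]:
    """Return (n_ctx, n_ctx_seq) following the branch structure of the hunk above.

    Hypothetical helper: it mirrors the C++ logic only and is not part of llama.cpp.
    """
    if kv_unified:
        # With a unified KV cache, a single sequence may use the whole context.
        return n_ctx, n_ctx

    # Otherwise the context is split evenly between sequences (integer division).
    n_ctx_seq = n_ctx // n_seq_max
    if n_ctx_seq == 0:
        raise RuntimeError("n_ctx_seq == 0")

    # If n_ctx is not divisible by n_seq_max it is rounded down
    # (the C++ code logs a warning at this point).
    return n_ctx_seq * n_seq_max, n_ctx_seq
```

Under this model, a request for `n_ctx = 10` with `n_seq_max = 4` yields `(8, 2)`, while the unified case leaves `n_ctx` untouched.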
+5 -5
View File
@@ -43,11 +43,11 @@ struct llama_context {
ggml_backend_sched_t get_sched() const;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
uint32_t n_ctx() const;
uint32_t n_ctx_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
uint32_t n_threads() const;
uint32_t n_threads_batch() const;
+1
View File
@@ -8,6 +8,7 @@
struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_ctx_seq; // context for a single sequence
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
+4 -10
View File
@@ -6712,14 +6712,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
}
ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
const uint32_t n_ctx_seq = cparams.n_ctx_seq;
// choose long/short freq factors based on the context size
if (layers[il].rope_freqs != nullptr) {
return layers[il].rope_freqs;
}
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
return layers[il].rope_long;
}
@@ -6795,12 +6795,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
uint32_t n_ctx_per_stream = cparams.n_ctx;
if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
}
llama_memory_i::layer_reuse_cb reuse = nullptr;
if (arch == LLM_ARCH_GEMMA3N) {
@@ -6824,7 +6818,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
@@ -6840,7 +6834,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
1,
hparams.n_swa,
+21 -2
View File
@@ -1454,6 +1454,8 @@ struct test_case {
ggml_context_ptr ctx(ggml_init(params)); // smart ptr
GGML_ASSERT(ctx);
gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);
ggml_tensor * out = build_graph(ctx.get());
current_op_name = op_desc(out);
@@ -7225,8 +7227,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
}
for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
@@ -7569,6 +7571,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
if (mode == MODE_SUPPORT) {
auto test_cases = make_test_cases_eval();
filter_test_cases(test_cases, params_filter);
// Filter out fusion cases
test_cases.erase(
std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr<test_case> & tc) {
return tc->run_whole_graph();
}),
test_cases.end()
);
for (auto & test : test_cases) {
test->eval_support(backend, op_names_filter, output_printer);
}
@@ -7619,6 +7630,14 @@ static void show_test_coverage() {
all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));
}
auto test_cases = make_test_cases_eval();
// Filter out fusion cases
test_cases.erase(
std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr<test_case> & tc) {
return tc->run_whole_graph();
}),
test_cases.end()
);
std::set<std::string> tested_ops;
ggml_init_params params = {
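The head-size loops in `make_test_cases_eval` above skip most (hsk, hsv) combinations. As a rough illustration of which pairs survive once head size 72 is added, here is a small Python sketch that replays the three `continue` conditions shown in the hunk; the function name `head_size_pairs` is hypothetical, and the remaining test parameters of the real loop are ignored.

```python
HSK = [40, 64, 72, 80, 96, 128, 192, 256, 576]
HSV = [40, 64, 72, 80, 96, 128, 192, 256, 512]


def head_size_pairs():
    """Enumerate the (hsk, hsv) pairs that pass the filters in the diff."""
    pairs = []
    for hsk in HSK:
        for hsv in HSV:
            # Ordinary head sizes are only tested with hsk == hsv.
            if hsk not in (192, 576) and hsk != hsv:
                continue
            # hsk == 192 is paired with hsv 128 or 192 only.
            if hsk == 192 and hsv not in (128, 192):
                continue
            # hsk == 576 is paired with hsv 512 only (DeepSeek MLA).
            if hsk == 576 and hsv != 512:
                continue
            pairs.append((hsk, hsv))
    return pairs
```

With the lists above, the new head size contributes exactly one pair, `(72, 72)`.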
+8 -1
View File
@@ -131,7 +131,14 @@ int main(int argc, char ** argv) {
}
batch = llama_batch_get_one(&token, 1);
if (llama_decode(ctx.get(), batch)) {
int ret = llama_decode(ctx.get(), batch);
if (ret == 1 && i > 0) {
LOG_INF("Context full, stopping generation.\n");
break;
}
if (ret != 0) {
LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
failed.store(true);
return;
-2
View File
@@ -221,7 +221,5 @@ int main(int argc, char ** argv) {
llama_backend_free();
LOG("\n\n");
return 0;
}
+3 -1
View File
@@ -154,8 +154,9 @@ enum projector_type {
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_UNKNOWN,
PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_UNKNOWN,
};
static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
@@ -180,6 +181,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {
+660 -369
View File
File diff suppressed because it is too large
+10
View File
@@ -1,6 +1,7 @@
#pragma once
#include "ggml.h"
#include <stddef.h>
#include <stdint.h>
@@ -22,9 +23,18 @@ enum clip_modality {
CLIP_MODALITY_AUDIO,
};
enum clip_flash_attn_type {
CLIP_FLASH_ATTN_TYPE_AUTO = -1,
CLIP_FLASH_ATTN_TYPE_DISABLED = 0,
CLIP_FLASH_ATTN_TYPE_ENABLED = 1,
};
struct clip_context_params {
bool use_gpu;
enum ggml_log_level verbosity;
enum clip_flash_attn_type flash_attn_type;
int image_min_tokens;
int image_max_tokens;
};
struct clip_init_result {
+7 -4
View File
@@ -132,10 +132,13 @@ struct mtmd_cli_context {
void init_vision_context(common_params & params) {
const char * clip_path = params.mmproj.path.c_str();
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params.mmproj_use_gpu;
mparams.print_timings = true;
mparams.n_threads = params.cpuparams.n_threads;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.use_gpu = params.mmproj_use_gpu;
mparams.print_timings = true;
mparams.n_threads = params.cpuparams.n_threads;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.flash_attn_type = params.flash_attn_type;
mparams.image_min_tokens = params.image_min_tokens;
mparams.image_max_tokens = params.image_max_tokens;
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_vision.get()) {
LOG_ERR("Failed to load vision model from %s\n", clip_path);
+20 -6
View File
@@ -19,7 +19,6 @@
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <vector>
// represents raw image data, layout is RGBRGBRGB...
@@ -92,6 +91,15 @@ const char * mtmd_default_marker() {
return "<__media__>";
}
static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
switch (flash_attn_type) {
case LLAMA_FLASH_ATTN_TYPE_AUTO: return CLIP_FLASH_ATTN_TYPE_AUTO;
case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
case LLAMA_FLASH_ATTN_TYPE_ENABLED: return CLIP_FLASH_ATTN_TYPE_ENABLED;
}
return CLIP_FLASH_ATTN_TYPE_AUTO;
}
mtmd_context_params mtmd_context_params_default() {
mtmd_context_params params;
params.use_gpu = true;
@@ -100,6 +108,9 @@ mtmd_context_params mtmd_context_params_default() {
params.verbosity = GGML_LOG_LEVEL_INFO;
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
params.media_marker = mtmd_default_marker();
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
params.image_min_tokens = -1;
params.image_max_tokens = -1;
return params;
}
@@ -162,8 +173,13 @@ struct mtmd_context {
}
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
// custom image token limits
ctx_clip_params.image_min_tokens = ctx_params.image_min_tokens;
ctx_clip_params.image_max_tokens = ctx_params.image_max_tokens;
auto res = clip_init(mmproj_fname, ctx_clip_params);
ctx_v = res.ctx_v;
ctx_a = res.ctx_a;
@@ -378,9 +394,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
}
void mtmd_free(mtmd_context * ctx) {
if (ctx) {
delete ctx;
}
delete ctx;
}
struct mtmd_tokenizer {
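The helper `mtmd_get_clip_flash_attn_type` added in this file translates one enum into another and falls back to the automatic setting for anything it does not recognise. A minimal Python sketch of the same pattern follows. The class and function names are hypothetical, and the numeric values on the llama side are assumed to mirror the clip values declared in this diff.

```python
from enum import IntEnum


class LlamaFlashAttnType(IntEnum):
    # Assumption: mirrors the values of enum llama_flash_attn_type.
    AUTO = -1
    DISABLED = 0
    ENABLED = 1


class ClipFlashAttnType(IntEnum):
    # Values as declared for clip_flash_attn_type in this diff.
    AUTO = -1
    DISABLED = 0
    ENABLED = 1


_TO_CLIP = {
    LlamaFlashAttnType.AUTO: ClipFlashAttnType.AUTO,
    LlamaFlashAttnType.DISABLED: ClipFlashAttnType.DISABLED,
    LlamaFlashAttnType.ENABLED: ClipFlashAttnType.ENABLED,
}


def to_clip_flash_attn_type(value: int) -> ClipFlashAttnType:
    """Map a llama-side setting to the clip-side one, defaulting to AUTO."""
    try:
        key = LlamaFlashAttnType(value)
    except ValueError:
        # Unknown value: same fallback as the return after the C++ switch.
        return ClipFlashAttnType.AUTO
    return _TO_CLIP[key]
```

Keeping the two enums separate, as the diff does, means the clip code does not depend on the llama enum type directly.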
+5
View File
@@ -82,6 +82,11 @@ struct mtmd_context_params {
enum ggml_log_level verbosity;
const char * image_marker; // deprecated, use media_marker instead
const char * media_marker;
enum llama_flash_attn_type flash_attn_type;
// limit number of image tokens, only for vision models with dynamic resolution
int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
};
MTMD_API const char * mtmd_default_marker(void);
Binary file not shown.
+80 -18
View File
@@ -2407,7 +2407,7 @@ struct server_context {
params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2452,10 +2452,13 @@ struct server_context {
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.flash_attn_type = params_base.flash_attn_type;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
@@ -2495,10 +2498,16 @@ struct server_context {
}
void init() {
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
const int n_ctx_train = llama_model_n_ctx_train(model);
int n_ctx_slot = llama_n_ctx_seq(ctx);
if (n_ctx_slot > n_ctx_train) {
SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
n_ctx_slot = n_ctx_train;
}
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
@@ -2527,7 +2536,7 @@ struct server_context {
}
}
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task();
@@ -2699,6 +2708,39 @@ struct server_context {
return ret;
}
// return true if at least one slot has been purged
// TODO: improve logic
// - smarter decision which slot to purge (LRU or longest prompt?)
// - move slot to level 2 cache instead of removing?
// - instead of purging, try to store and resume later?
bool try_purge_idle_slots() {
bool res = false;
if (!params_base.kv_unified) {
return res;
}
for (auto & slot : slots) {
if (slot.is_processing()) {
continue;
}
if (slot.prompt.n_tokens() > 0) {
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
slot.prompt.tokens.clear();
res = true;
// purge slots one by one
break;
}
}
return res;
}
bool launch_slot_with_task(server_slot & slot, server_task && task) {
slot.reset();
@@ -3635,9 +3677,10 @@ struct server_context {
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
float alora_scale = -1.0f;
float alora_scale = -1.0f;
size_t alora_disabled_id = 0;
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
@@ -3914,8 +3957,11 @@ struct server_context {
// truncate any tokens that are beyond n_past for this slot
const llama_pos p0 = slot.prompt.tokens.pos_next();
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
// there is no common part left
@@ -3924,8 +3970,6 @@ struct server_context {
slot.prompt.tokens.clear();
}
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
// check if we should process the image
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
@@ -4126,6 +4170,8 @@ struct server_context {
std::string err;
if (n_batch == 1 && ret == 1) {
// TODO: try to terminate only the largest active slot/sequence and continue with the rest
// need to remove the tokens from the current batch too
err = "Context size has been exceeded.";
}
@@ -4141,17 +4187,23 @@ struct server_context {
// TODO: handle ret == 2 (abort) when we start aborting
if (!err.empty()) {
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
for (auto & slot : slots) {
send_error(slot, err);
slot.release();
if (slot.is_processing()) {
send_error(slot, err);
slot.release();
}
}
break;
}
}
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
if (!try_purge_idle_slots()) {
n_batch /= 2;
}
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
@@ -4391,6 +4443,15 @@ int main(int argc, char ** argv) {
return 1;
}
// TODO: should we have a separate n_parallel parameter for the server?
// https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
if (params.n_parallel == 1 && params.kv_unified == false) {
LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
params.n_parallel = 4;
params.kv_unified = true;
}
common_init();
// struct that contains llama context and inference
@@ -4849,6 +4910,7 @@ int main(int argc, char ** argv) {
json data = {
{ "default_generation_settings", default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_alias", ctx_server.params_base.model_alias },
{ "model_path", ctx_server.params_base.model.path },
{ "modalities", json {
{"vision", ctx_server.oai_parser_opt.allow_image},
@@ -4944,7 +5006,7 @@ int main(int argc, char ** argv) {
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto n_prompt_tokens = inputs[i].size();
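The server changes above alter the retry path taken when the KV cache has no free space: the server first tries to purge one idle slot, and halves `n_batch` only if nothing could be purged. The Python sketch below models that policy under simplifying assumptions. The `Slot` dataclass and the helper `next_n_batch` are hypothetical stand-ins, and clearing the token list stands in for the `llama_memory_seq_rm` call.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Slot:
    """Hypothetical stand-in for server_slot; only the fields the policy reads."""
    id: int
    processing: bool = False
    tokens: List[int] = field(default_factory=list)


def try_purge_idle_slots(slots: List[Slot], kv_unified: bool) -> bool:
    """Clear the first idle slot that still holds cached tokens."""
    if not kv_unified:
        # Purging only applies when the KV cache is unified.
        return False
    for slot in slots:
        if slot.processing:
            continue
        if slot.tokens:
            slot.tokens.clear()  # stands in for llama_memory_seq_rm + tokens.clear()
            return True          # purge slots one by one
    return False


def next_n_batch(n_batch: int, purged: bool) -> int:
    """Halve the batch size only when no idle slot could be purged."""
    return n_batch if purged else n_batch // 2
```

Purging one slot at a time keeps as much cached prompt data as possible. The TODO in the diff notes that the choice of slot could be smarter, for example LRU or longest prompt.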
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
@pytest.mark.parametrize(
"n_batch,batch_count,reuse_cache",
[
(64, 15, False),
(64, 3, False),
(64, 1, True),
]
)
def test_return_progresssss(n_batch, batch_count, reuse_cache):
def test_return_progress(n_batch, batch_count, reuse_cache):
global server
server.n_batch = n_batch
server.n_ctx = 2048
server.n_ctx = 256
server.n_slots = 1
server.start()
def make_cmpl_request():
return server.make_stream_request("POST", "/chat/completions", data={
"max_tokens": 10,
"messages": [
{"role": "user", "content": "This is a test" * 100},
{"role": "user", "content": "This is a test" * 10},
],
"stream": True,
"return_progress": True,
@@ -368,6 +368,37 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
# assert match_regex(re_content, res.body["content"])
@pytest.mark.parametrize(
"n_ctx,n_slots,n_predict_vals,expected_success",
[
(256, 4, [80, 40, 80, 80], [True, True, True, True]),
(256, 4, [70, 70, 70, 70], [False, False, False, False]),
(256, 4, [90, 90, 40, 90], [False, False, True, False]),
(256, 4, [90, 90, 40, 75], [True, True, True, True]),
],
)
def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
global server
server.n_slots = n_slots
server.kv_unified = True
server.n_ctx = n_ctx
server.start()
prompt = "A"
tasks = []
for n_predict in n_predict_vals:
tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
results = parallel_function_calls(tasks)
for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
if expect_ok:
assert res.status_code == 200
assert "content" in res.body
if "timings" in res.body:
assert res.body["timings"]["predicted_n"] == n_predict
else:
assert res.status_code == 500
assert "content" not in res.body
@pytest.mark.parametrize(
"prompt,n_predict,response_fields",
[
+2 -2
View File
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
def test_infill_with_input_extra():
@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(Dad|excited|park)+", res.body["content"])
assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
@pytest.mark.parametrize("input_extra", [
+3
View File
@@ -78,6 +78,7 @@ class ServerProcess:
server_embeddings: bool | None = False
server_reranking: bool | None = False
server_metrics: bool | None = False
kv_unified: bool | None = False
server_slots: bool | None = False
pooling: str | None = None
draft: int | None = None
@@ -159,6 +160,8 @@ class ServerProcess:
server_args.append("--reranking")
if self.server_metrics:
server_args.append("--metrics")
if self.kv_unified:
server_args.append("--kv-unified")
if self.server_slots:
server_args.append("--slots")
else:
+2 -1
View File
@@ -1212,7 +1212,7 @@ public:
for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
auto * chunk = tokens.map_idx_to_media[it->first].get();
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
}
}
}
@@ -1244,6 +1244,7 @@ public:
}
void clear() {
map_idx_to_media.clear();
tokens.clear();
}
+361
View File
@@ -59,6 +59,7 @@
"prettier-plugin-tailwindcss": "^0.6.11",
"rehype-katex": "^7.0.1",
"remark-math": "^6.0.0",
"sass": "^1.93.3",
"storybook": "^9.0.17",
"svelte": "^5.0.0",
"svelte-check": "^4.0.0",
@@ -1176,6 +1177,330 @@
"node": ">= 8"
}
},
"node_modules/@parcel/watcher": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher/-/watcher-2.5.1.tgz",
"integrity": "sha512-dfUnCxiN9H4ap84DvD2ubjw+3vUNpstxa0TneY/Paat8a3R4uQZDLSvWjmznAY/DoahqTHl9V46HF/Zs3F29pg==",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
"optional": true,
"dependencies": {
"detect-libc": "^1.0.3",
"is-glob": "^4.0.3",
"micromatch": "^4.0.5",
"node-addon-api": "^7.0.0"
},
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
},
"optionalDependencies": {
"@parcel/watcher-android-arm64": "2.5.1",
"@parcel/watcher-darwin-arm64": "2.5.1",
"@parcel/watcher-darwin-x64": "2.5.1",
"@parcel/watcher-freebsd-x64": "2.5.1",
"@parcel/watcher-linux-arm-glibc": "2.5.1",
"@parcel/watcher-linux-arm-musl": "2.5.1",
"@parcel/watcher-linux-arm64-glibc": "2.5.1",
"@parcel/watcher-linux-arm64-musl": "2.5.1",
"@parcel/watcher-linux-x64-glibc": "2.5.1",
"@parcel/watcher-linux-x64-musl": "2.5.1",
"@parcel/watcher-win32-arm64": "2.5.1",
"@parcel/watcher-win32-ia32": "2.5.1",
"@parcel/watcher-win32-x64": "2.5.1"
}
},
"node_modules/@parcel/watcher-android-arm64": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-android-arm64/-/watcher-android-arm64-2.5.1.tgz",
"integrity": "sha512-KF8+j9nNbUN8vzOFDpRMsaKBHZ/mcjEjMToVMJOhTozkDonQFFrRcfdLWn6yWKCmJKmdVxSgHiYvTCef4/qcBA==",
"cpu": [
"arm64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-darwin-arm64": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-darwin-arm64/-/watcher-darwin-arm64-2.5.1.tgz",
"integrity": "sha512-eAzPv5osDmZyBhou8PoF4i6RQXAfeKL9tjb3QzYuccXFMQU0ruIc/POh30ePnaOyD1UXdlKguHBmsTs53tVoPw==",
"cpu": [
"arm64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-darwin-x64": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-darwin-x64/-/watcher-darwin-x64-2.5.1.tgz",
"integrity": "sha512-1ZXDthrnNmwv10A0/3AJNZ9JGlzrF82i3gNQcWOzd7nJ8aj+ILyW1MTxVk35Db0u91oD5Nlk9MBiujMlwmeXZg==",
"cpu": [
"x64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-freebsd-x64": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-freebsd-x64/-/watcher-freebsd-x64-2.5.1.tgz",
"integrity": "sha512-SI4eljM7Flp9yPuKi8W0ird8TI/JK6CSxju3NojVI6BjHsTyK7zxA9urjVjEKJ5MBYC+bLmMcbAWlZ+rFkLpJQ==",
"cpu": [
"x64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"freebsd"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-linux-arm-glibc": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-arm-glibc/-/watcher-linux-arm-glibc-2.5.1.tgz",
"integrity": "sha512-RCdZlEyTs8geyBkkcnPWvtXLY44BCeZKmGYRtSgtwwnHR4dxfHRG3gR99XdMEdQ7KeiDdasJwwvNSF5jKtDwdA==",
"cpu": [
"arm"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-linux-arm-musl": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-arm-musl/-/watcher-linux-arm-musl-2.5.1.tgz",
"integrity": "sha512-6E+m/Mm1t1yhB8X412stiKFG3XykmgdIOqhjWj+VL8oHkKABfu/gjFj8DvLrYVHSBNC+/u5PeNrujiSQ1zwd1Q==",
"cpu": [
"arm"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-linux-arm64-glibc": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-arm64-glibc/-/watcher-linux-arm64-glibc-2.5.1.tgz",
"integrity": "sha512-LrGp+f02yU3BN9A+DGuY3v3bmnFUggAITBGriZHUREfNEzZh/GO06FF5u2kx8x+GBEUYfyTGamol4j3m9ANe8w==",
"cpu": [
"arm64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-linux-arm64-musl": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-arm64-musl/-/watcher-linux-arm64-musl-2.5.1.tgz",
"integrity": "sha512-cFOjABi92pMYRXS7AcQv9/M1YuKRw8SZniCDw0ssQb/noPkRzA+HBDkwmyOJYp5wXcsTrhxO0zq1U11cK9jsFg==",
"cpu": [
"arm64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-linux-x64-glibc": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-x64-glibc/-/watcher-linux-x64-glibc-2.5.1.tgz",
"integrity": "sha512-GcESn8NZySmfwlTsIur+49yDqSny2IhPeZfXunQi48DMugKeZ7uy1FX83pO0X22sHntJ4Ub+9k34XQCX+oHt2A==",
"cpu": [
"x64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-linux-x64-musl": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-x64-musl/-/watcher-linux-x64-musl-2.5.1.tgz",
"integrity": "sha512-n0E2EQbatQ3bXhcH2D1XIAANAcTZkQICBPVaxMeaCVBtOpBZpWJuf7LwyWPSBDITb7In8mqQgJ7gH8CILCURXg==",
"cpu": [
"x64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-win32-arm64": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-win32-arm64/-/watcher-win32-arm64-2.5.1.tgz",
"integrity": "sha512-RFzklRvmc3PkjKjry3hLF9wD7ppR4AKcWNzH7kXR7GUe0Igb3Nz8fyPwtZCSquGrhU5HhUNDr/mKBqj7tqA2Vw==",
"cpu": [
"arm64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-win32-ia32": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-win32-ia32/-/watcher-win32-ia32-2.5.1.tgz",
"integrity": "sha512-c2KkcVN+NJmuA7CGlaGD1qJh1cLfDnQsHjE89E60vUEMlqduHGCdCLJCID5geFVM0dOtA3ZiIO8BoEQmzQVfpQ==",
"cpu": [
"ia32"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher-win32-x64": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/@parcel/watcher-win32-x64/-/watcher-win32-x64-2.5.1.tgz",
"integrity": "sha512-9lHBdJITeNR++EvSQVUcaZoWupyHfXe1jZvGZ06O/5MflPcuPLtEphScIBL+AiCWBO46tDSHzWyD0uDmmZqsgA==",
"cpu": [
"x64"
],
"dev": true,
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/parcel"
}
},
"node_modules/@parcel/watcher/node_modules/detect-libc": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-1.0.3.tgz",
"integrity": "sha512-pGjwhsmsp4kL2RTz08wcOlGN83otlqHeD/Z5T8GXZB+/YcpQ/dgo+lbU8ZsGxV0HIvqqxo9l7mqYwyYMD9bKDg==",
"dev": true,
"license": "Apache-2.0",
"optional": true,
"bin": {
"detect-libc": "bin/detect-libc.js"
},
"engines": {
"node": ">=0.10"
}
},
"node_modules/@playwright/test": {
"version": "1.54.1",
"resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.54.1.tgz",
@@ -4697,6 +5022,13 @@
"node": ">= 4"
}
},
"node_modules/immutable": {
"version": "5.1.4",
"resolved": "https://registry.npmjs.org/immutable/-/immutable-5.1.4.tgz",
"integrity": "sha512-p6u1bG3YSnINT5RQmx/yRZBpenIl30kVxkTLDyHLIMk0gict704Q9n+thfDI7lTRm9vXdDYutVzXhzcThxTnXA==",
"dev": true,
"license": "MIT"
},
"node_modules/import-fresh": {
"version": "3.3.1",
"resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz",
@@ -6462,6 +6794,14 @@
"tslib": "^2.0.3"
}
},
"node_modules/node-addon-api": {
"version": "7.1.1",
"resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-7.1.1.tgz",
"integrity": "sha512-5m3bsyrjFWE1xf7nz7YXdN4udnVtXK6/Yfgn5qnahL6bCkf2yKt4k3nuTKAtT4r3IG8JNR2ncsIMdZuAzJjHQQ==",
"dev": true,
"license": "MIT",
"optional": true
},
"node_modules/object-inspect": {
"version": "1.13.4",
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
@@ -7484,6 +7824,27 @@
"dev": true,
"license": "MIT"
},
"node_modules/sass": {
"version": "1.93.3",
"resolved": "https://registry.npmjs.org/sass/-/sass-1.93.3.tgz",
"integrity": "sha512-elOcIZRTM76dvxNAjqYrucTSI0teAF/L2Lv0s6f6b7FOwcwIuA357bIE871580AjHJuSvLIRUosgV+lIWx6Rgg==",
"dev": true,
"license": "MIT",
"dependencies": {
"chokidar": "^4.0.0",
"immutable": "^5.0.2",
"source-map-js": ">=0.6.2 <2.0.0"
},
"bin": {
"sass": "sass.js"
},
"engines": {
"node": ">=14.0.0"
},
"optionalDependencies": {
"@parcel/watcher": "^2.4.1"
}
},
"node_modules/scheduler": {
"version": "0.26.0",
"resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.26.0.tgz",
+1
@@ -61,6 +61,7 @@
"prettier-plugin-tailwindcss": "^0.6.11",
"rehype-katex": "^7.0.1",
"remark-math": "^6.0.0",
"sass": "^1.93.3",
"storybook": "^9.0.17",
"svelte": "^5.0.0",
"svelte-check": "^4.0.0",
@@ -134,6 +134,15 @@
}
}
$effect(() => {
if (open) {
pdfImages = [];
pdfImagesLoading = false;
pdfImagesError = null;
pdfViewMode = 'pages';
}
});
$effect(() => {
if (open && isPdf && pdfViewMode === 'pages') {
loadPdfImages();
@@ -3,7 +3,16 @@
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
import { isLoading } from '$lib/stores/chat.svelte';
import { fade } from 'svelte/transition';
import { Check, Copy, Package, X } from '@lucide/svelte';
import {
Check,
Copy,
Package,
X,
Gauge,
Clock,
WholeWord,
ChartNoAxesColumn
} from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Checkbox } from '$lib/components/ui/checkbox';
import { INPUT_CLASSES } from '$lib/constants/input-classes';
@@ -76,8 +85,8 @@
let displayedModel = $derived((): string | null => {
if (!currentConfig.showModelInfo) return null;
if (currentConfig.modelSelectorEnabled) {
return message.model ?? null;
if (message.model) {
return message.model;
}
return serverModel;
@@ -160,22 +169,58 @@
</div>
{/if}
{#if displayedModel()}
<span class="mt-6 mb-4 inline-flex items-center gap-1 text-xs text-muted-foreground">
<Package class="h-3.5 w-3.5" />
<div class="info my-6 grid gap-4">
{#if displayedModel()}
<span class="inline-flex items-center gap-2 text-xs text-muted-foreground">
<span class="inline-flex items-center gap-1">
<Package class="h-3.5 w-3.5" />
<span>Model used:</span>
<span>Model used:</span>
</span>
<button
class="inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
onclick={handleCopyModel}
>
{displayedModel()}
<button
class="inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
onclick={handleCopyModel}
>
{displayedModel()}
<Copy class="ml-1 h-3 w-3 " />
</button>
</span>
{/if}
<Copy class="ml-1 h-3 w-3 " />
</button>
</span>
{/if}
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
{@const tokensPerSecond = (message.timings.predicted_n / message.timings.predicted_ms) * 1000}
<span class="inline-flex items-center gap-2 text-xs text-muted-foreground">
<span class="inline-flex items-center gap-1">
<ChartNoAxesColumn class="h-3.5 w-3.5" />
<span>Statistics:</span>
</span>
<div class="inline-flex flex-wrap items-center gap-2 text-xs text-muted-foreground">
<span
class="inline-flex items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
>
<Gauge class="h-3 w-3" />
{tokensPerSecond.toFixed(2)} tokens/s
</span>
<span
class="inline-flex items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
>
<WholeWord class="h-3 w-3" />
{message.timings.predicted_n} tokens
</span>
<span
class="inline-flex items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
>
<Clock class="h-3 w-3" />
{(message.timings.predicted_ms / 1000).toFixed(2)}s
</span>
</div>
</span>
{/if}
</div>
{#if message.timestamp && !isEditing}
<ChatMessageActions
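The statistics row in the hunk above derives its three displayed values from `message.timings`. As a minimal sketch of that arithmetic in isolation (the `formatMessageStats` helper and the `MessageTimings` / `MessageStats` shapes are hypothetical names introduced only for illustration, not part of this change):

```typescript
// Hypothetical shape: only the two fields the template reads.
interface MessageTimings {
  predicted_n?: number; // number of generated tokens
  predicted_ms?: number; // generation time in milliseconds
}

// Hypothetical result shape for the three badges in the statistics row.
interface MessageStats {
  tokensPerSecond: string;
  tokens: number;
  seconds: string;
}

// Returns null when either value is missing or zero, mirroring the truthiness
// checks in the `{#if ...}` guard (which also rules out a division by zero).
function formatMessageStats(timings?: MessageTimings): MessageStats | null {
  if (!timings || !timings.predicted_n || !timings.predicted_ms) return null;

  // Same formula as the `{@const tokensPerSecond = ...}` line in the template.
  const tokensPerSecond = (timings.predicted_n / timings.predicted_ms) * 1000;

  return {
    tokensPerSecond: tokensPerSecond.toFixed(2),
    tokens: timings.predicted_n,
    seconds: (timings.predicted_ms / 1000).toFixed(2)
  };
}
```

For example, 50 tokens generated in 2000 ms would render as `25.00 tokens/s`, `50 tokens` and `2.00s`.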
@@ -52,6 +52,11 @@
{ value: 'dark', label: 'Dark', icon: Moon }
]
},
{
key: 'showMessageStats',
label: 'Show message generation statistics',
type: 'checkbox'
},
{
key: 'showTokensPerSecond',
label: 'Show tokens per second',
@@ -0,0 +1,93 @@
<script lang="ts">
import { Dialog as DialogPrimitive } from 'bits-ui';
import XIcon from '@lucide/svelte/icons/x';
interface Props {
open: boolean;
code: string;
language: string;
onOpenChange?: (open: boolean) => void;
}
let { open = $bindable(), code, language, onOpenChange }: Props = $props();
let iframeRef = $state<HTMLIFrameElement | null>(null);
$effect(() => {
if (!iframeRef) return;
if (open) {
iframeRef.srcdoc = code;
} else {
iframeRef.srcdoc = '';
}
});
function handleOpenChange(nextOpen: boolean) {
open = nextOpen;
onOpenChange?.(nextOpen);
}
</script>
<DialogPrimitive.Root {open} onOpenChange={handleOpenChange}>
<DialogPrimitive.Portal>
<DialogPrimitive.Overlay class="code-preview-overlay" />
<DialogPrimitive.Content class="code-preview-content">
<iframe
bind:this={iframeRef}
title="Preview {language}"
sandbox="allow-scripts"
class="code-preview-iframe"
></iframe>
<DialogPrimitive.Close
class="code-preview-close absolute top-4 right-4 border-none bg-transparent text-white opacity-70 mix-blend-difference transition-opacity hover:opacity-100 focus-visible:ring-0 focus-visible:ring-offset-0 focus-visible:outline-none disabled:pointer-events-none [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-8"
aria-label="Close preview"
>
<XIcon />
<span class="sr-only">Close preview</span>
</DialogPrimitive.Close>
</DialogPrimitive.Content>
</DialogPrimitive.Portal>
</DialogPrimitive.Root>
<style lang="postcss">
:global(.code-preview-overlay) {
position: fixed;
inset: 0;
background-color: transparent;
z-index: 100000;
}
:global(.code-preview-content) {
position: fixed;
inset: 0;
top: 0 !important;
left: 0 !important;
width: 100dvw;
height: 100dvh;
margin: 0;
padding: 0;
border: none;
border-radius: 0;
background-color: transparent;
box-shadow: none;
display: block;
overflow: hidden;
transform: none !important;
z-index: 100001;
}
:global(.code-preview-iframe) {
display: block;
width: 100dvw;
height: 100dvh;
border: 0;
}
:global(.code-preview-close) {
position: absolute;
z-index: 100002;
}
</style>
@@ -8,13 +8,15 @@
import rehypeKatex from 'rehype-katex';
import rehypeStringify from 'rehype-stringify';
import { copyCodeToClipboard } from '$lib/utils/copy';
import { preprocessLaTeX } from '$lib/utils/latex-protection';
import { browser } from '$app/environment';
import 'katex/dist/katex.min.css';
import '$styles/katex-custom.scss';
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
import githubLightCss from 'highlight.js/styles/github.css?inline';
import { mode } from 'mode-watcher';
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
import CodePreviewDialog from './CodePreviewDialog.svelte';
interface Props {
content: string;
@@ -25,6 +27,9 @@
let containerRef = $state<HTMLDivElement>();
let processedHtml = $state('');
let previewDialogOpen = $state(false);
let previewCode = $state('');
let previewLanguage = $state('text');
function loadHighlightTheme(isDark: boolean) {
if (!browser) return;
@@ -117,7 +122,6 @@
const rawCode = codeElement.textContent || '';
const codeId = `code-${Date.now()}-${index}`;
codeElement.setAttribute('data-code-id', codeId);
codeElement.setAttribute('data-raw-code', rawCode);
@@ -138,11 +142,30 @@
copyButton.setAttribute('type', 'button');
copyButton.innerHTML = `
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-copy-icon lucide-copy"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>
`;
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-copy-icon lucide-copy"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>
`;
const actions = document.createElement('div');
actions.className = 'code-block-actions';
actions.appendChild(copyButton);
if (language.toLowerCase() === 'html') {
const previewButton = document.createElement('button');
previewButton.className = 'preview-code-btn';
previewButton.setAttribute('data-code-id', codeId);
previewButton.setAttribute('title', 'Preview code');
previewButton.setAttribute('type', 'button');
previewButton.innerHTML = `
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-eye lucide-eye-icon"><path d="M2.062 12.345a1 1 0 0 1 0-.69C3.5 7.73 7.36 5 12 5s8.5 2.73 9.938 6.655a1 1 0 0 1 0 .69C20.5 16.27 16.64 19 12 19s-8.5-2.73-9.938-6.655"/><circle cx="12" cy="12" r="3"/></svg>
`;
actions.appendChild(previewButton);
}
header.appendChild(languageLabel);
header.appendChild(copyButton);
header.appendChild(actions);
wrapper.appendChild(header);
const clonedPre = pre.cloneNode(true) as HTMLElement;
@@ -154,19 +177,9 @@
return mutated ? tempDiv.innerHTML : html;
}
function normalizeMathDelimiters(text: string): string {
return text
.replace(/(^|[^\\])\\\[((?:\\.|[\s\S])*?)\\\]/g, (_, prefix: string, content: string) => {
return `${prefix}$$${content}$$`;
})
.replace(/(^|[^\\])\\\(((?:\\.|[\s\S])*?)\\\)/g, (_, prefix: string, content: string) => {
return `${prefix}$${content}$`;
});
}
async function processMarkdown(text: string): Promise<string> {
try {
const normalized = normalizeMathDelimiters(text);
let normalized = preprocessLaTeX(text);
const result = await processor().process(normalized);
const html = String(result);
const enhancedLinks = enhanceLinks(html);
@@ -180,49 +193,105 @@
}
}
function setupCopyButtons() {
function getCodeInfoFromTarget(target: HTMLElement) {
const wrapper = target.closest('.code-block-wrapper');
if (!wrapper) {
console.error('No wrapper found');
return null;
}
const codeElement = wrapper.querySelector<HTMLElement>('code[data-code-id]');
if (!codeElement) {
console.error('No code element found in wrapper');
return null;
}
const rawCode = codeElement.getAttribute('data-raw-code');
if (rawCode === null) {
console.error('No raw code found');
return null;
}
const languageLabel = wrapper.querySelector<HTMLElement>('.code-language');
const language = languageLabel?.textContent?.trim() || 'text';
return { rawCode, language };
}
async function handleCopyClick(event: Event) {
event.preventDefault();
event.stopPropagation();
const target = event.currentTarget as HTMLButtonElement | null;
if (!target) {
return;
}
const info = getCodeInfoFromTarget(target);
if (!info) {
return;
}
try {
await copyCodeToClipboard(info.rawCode);
} catch (error) {
console.error('Failed to copy code:', error);
}
}
function handlePreviewClick(event: Event) {
event.preventDefault();
event.stopPropagation();
const target = event.currentTarget as HTMLButtonElement | null;
if (!target) {
return;
}
const info = getCodeInfoFromTarget(target);
if (!info) {
return;
}
previewCode = info.rawCode;
previewLanguage = info.language;
previewDialogOpen = true;
}
function setupCodeBlockActions() {
if (!containerRef) return;
const copyButtons = containerRef.querySelectorAll('.copy-code-btn');
const wrappers = containerRef.querySelectorAll<HTMLElement>('.code-block-wrapper');
for (const button of copyButtons) {
button.addEventListener('click', async (e) => {
e.preventDefault();
e.stopPropagation();
for (const wrapper of wrappers) {
const copyButton = wrapper.querySelector<HTMLButtonElement>('.copy-code-btn');
const previewButton = wrapper.querySelector<HTMLButtonElement>('.preview-code-btn');
const target = e.currentTarget as HTMLButtonElement;
const codeId = target.getAttribute('data-code-id');
if (copyButton && copyButton.dataset.listenerBound !== 'true') {
copyButton.dataset.listenerBound = 'true';
copyButton.addEventListener('click', handleCopyClick);
}
if (!codeId) {
console.error('No code ID found on button');
return;
}
if (previewButton && previewButton.dataset.listenerBound !== 'true') {
previewButton.dataset.listenerBound = 'true';
previewButton.addEventListener('click', handlePreviewClick);
}
}
}
// Find the code element within the same wrapper
const wrapper = target.closest('.code-block-wrapper');
if (!wrapper) {
console.error('No wrapper found');
return;
}
function handlePreviewDialogOpenChange(open: boolean) {
previewDialogOpen = open;
const codeElement = wrapper.querySelector('code[data-code-id]');
if (!codeElement) {
console.error('No code element found in wrapper');
return;
}
const rawCode = codeElement.getAttribute('data-raw-code');
if (!rawCode) {
console.error('No raw code found');
return;
}
try {
await copyCodeToClipboard(rawCode);
} catch (error) {
console.error('Failed to copy code:', error);
}
});
if (!open) {
previewCode = '';
previewLanguage = 'text';
}
}
@@ -243,7 +312,7 @@
$effect(() => {
if (containerRef && processedHtml) {
setupCopyButtons();
setupCodeBlockActions();
}
});
</script>
@@ -253,6 +322,13 @@
{@html processedHtml}
</div>
<CodePreviewDialog
open={previewDialogOpen}
code={previewCode}
language={previewLanguage}
onOpenChange={handlePreviewDialogOpenChange}
/>
<style>
/* Base typography styles */
div :global(p:not(:last-child)) {
@@ -472,7 +548,14 @@
letter-spacing: 0.05em;
}
div :global(.copy-code-btn) {
div :global(.code-block-actions) {
display: flex;
align-items: center;
gap: 0.5rem;
}
div :global(.copy-code-btn),
div :global(.preview-code-btn) {
display: flex;
align-items: center;
justify-content: center;
@@ -483,11 +566,13 @@
transition: all 0.2s ease;
}
div :global(.copy-code-btn:hover) {
div :global(.copy-code-btn:hover),
div :global(.preview-code-btn:hover) {
transform: scale(1.05);
}
div :global(.copy-code-btn:active) {
div :global(.copy-code-btn:active),
div :global(.preview-code-btn:active) {
transform: scale(0.95);
}
@@ -0,0 +1,35 @@
/**
* Matches common Markdown code blocks to exclude them from further processing (e.g. LaTeX).
* - Fenced: ```...```
* - Inline: `...` (does NOT support nested backticks or multi-backtick syntax)
*
* Note: This pattern does not handle advanced cases like:
* `` `code with `backticks` `` or \\``...\\``
*/
export const CODE_BLOCK_REGEXP = /(```[\s\S]*?```|`[^`\n]+`)/g;
/**
* Matches LaTeX math delimiters \(...\) and \[...\] only when not preceded by a backslash (i.e., not escaped),
* while also capturing code blocks (```, `...`) so they can be skipped during processing.
*
* Uses the negative lookbehind `(?<!\\)` to skip delimiters that are preceded by a
* backslash, so \\( and \\[ are not treated as math, e.g.
* `Definitions\\(also called macros)` (title of chapter 20 in The TeXbook)
* or `\\[4pt]` (LaTeX line-break).
*
* group 1: code-block
* group 2: square-bracket
* group 3: round-bracket
*/
export const LATEX_MATH_AND_CODE_PATTERN =
/(```[\S\s]*?```|`.*?`)|(?<!\\)\\\[([\S\s]*?[^\\])\\]|(?<!\\)\\\((.*?)\\\)/g;
/** Regex to capture the content of a $$...\\\\...$$ block (display-formula with line-break) */
export const LATEX_LINEBREAK_REGEXP = /\$\$([\s\S]*?\\\\[\s\S]*?)\$\$/;
/** Map from mhchem regexp to replacement */
export const MHCHEM_PATTERN_MAP: readonly [RegExp, string][] = [
[/(\s)\$\\ce{/g, '$1$\\\\ce{'],
[/(\s)\$\\pu{/g, '$1$\\\\pu{']
] as const;
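`preprocessLaTeX` itself is not shown in this diff, so the following is only a sketch of how a pattern with these three capture groups could be applied. It assumes the goal is the same as the removed `normalizeMathDelimiters`: turn `\[...\]` into `$$...$$` and `\(...\)` into `$...$`, while leaving code spans untouched. The `convertMathDelimiters` name is hypothetical, and the pattern is copied here so the block is self-contained.

```typescript
// Copied from the constants file above so this sketch runs on its own.
const LATEX_MATH_AND_CODE_PATTERN =
  /(```[\S\s]*?```|`.*?`)|(?<!\\)\\\[([\S\s]*?[^\\])\\]|(?<!\\)\\\((.*?)\\\)/g;

// Hypothetical helper: rewrites LaTeX delimiters to dollar delimiters.
// group 1 (code) is returned unchanged, group 2 becomes display math,
// group 3 becomes inline math.
function convertMathDelimiters(text: string): string {
  return text.replace(
    LATEX_MATH_AND_CODE_PATTERN,
    (match: string, code?: string, square?: string, round?: string): string => {
      if (code !== undefined) return code;
      if (square !== undefined) return `$$${square}$$`;
      if (round !== undefined) return `$${round}$`;
      return match;
    }
  );
}
```

Because the replacement is a function, the returned `$` characters are inserted literally rather than being interpreted as replacement patterns. Matching code spans in the same pass means a delimiter inside backticks is consumed as part of the code match and never reaches the math branches.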
@@ -8,6 +8,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
showThoughtInProgress: false,
disableReasoningFormat: false,
keepStatsVisible: false,
showMessageStats: true,
askForTitleConfirmation: false,
pasteLongTextToFileLen: 2500,
pdfAsImage: false,
@@ -82,6 +83,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
disableReasoningFormat:
'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.',
keepStatsVisible: 'Keep processing statistics visible after generation finishes.',
showMessageStats:
'Display generation statistics (tokens/second, token count, duration) below each assistant message.',
askForTitleConfirmation:
'Ask for confirmation before automatically changing conversation title when editing the first message.',
pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).',
@@ -69,6 +69,10 @@ export const TEXT_FILE_TYPES = {
extensions: [FileExtensionText.MD],
mimeTypes: [MimeTypeText.MARKDOWN]
},
[FileTypeText.ASCIIDOC]: {
extensions: [FileExtensionText.ADOC],
mimeTypes: [MimeTypeText.ASCIIDOC]
},
[FileTypeText.JAVASCRIPT]: {
extensions: [FileExtensionText.JS],
mimeTypes: [MimeTypeText.JAVASCRIPT, MimeTypeText.JAVASCRIPT_APP]
@@ -33,6 +33,7 @@ export enum FileTypePdf {
export enum FileTypeText {
PLAIN_TEXT = 'plainText',
MARKDOWN = 'markdown',
ASCIIDOC = 'asciidoc',
JAVASCRIPT = 'javascript',
TYPESCRIPT = 'typescript',
JSX = 'jsx',
@@ -86,6 +87,7 @@ export enum FileExtensionPdf {
export enum FileExtensionText {
TXT = '.txt',
MD = '.md',
ADOC = '.adoc',
JS = '.js',
TS = '.ts',
JSX = '.jsx',
@@ -147,6 +149,7 @@ export enum MimeTypeImage {
export enum MimeTypeText {
PLAIN = 'text/plain',
MARKDOWN = 'text/markdown',
ASCIIDOC = 'text/asciidoc',
JAVASCRIPT = 'text/javascript',
JAVASCRIPT_APP = 'application/javascript',
TYPESCRIPT = 'text/typescript',
+16 -4
@@ -54,6 +54,7 @@ export class ChatService {
onError,
onReasoningChunk,
onModel,
onFirstValidChunk,
// Generation parameters
temperature,
max_tokens,
@@ -201,6 +202,7 @@ export class ChatService {
onError,
onReasoningChunk,
onModel,
onFirstValidChunk,
conversationId,
abortController.signal
);
@@ -267,6 +269,7 @@ export class ChatService {
onError?: (error: Error) => void,
onReasoningChunk?: (chunk: string) => void,
onModel?: (model: string) => void,
onFirstValidChunk?: () => void,
conversationId?: string,
abortSignal?: AbortSignal
): Promise<void> {
@@ -283,6 +286,7 @@ export class ChatService {
let lastTimings: ChatMessageTimings | undefined;
let streamFinished = false;
let modelEmitted = false;
let firstValidChunkEmitted = false;
try {
let chunk = '';
@@ -311,10 +315,12 @@ export class ChatService {
try {
const parsed: ApiChatCompletionStreamChunk = JSON.parse(data);
const chunkModel = this.extractModelName(parsed);
if (chunkModel && !modelEmitted) {
modelEmitted = true;
onModel?.(chunkModel);
if (!firstValidChunkEmitted && parsed.object === 'chat.completion.chunk') {
firstValidChunkEmitted = true;
if (!abortSignal?.aborted) {
onFirstValidChunk?.();
}
}
const content = parsed.choices[0]?.delta?.content;
@@ -322,6 +328,12 @@ export class ChatService {
const timings = parsed.timings;
const promptProgress = parsed.prompt_progress;
const chunkModel = this.extractModelName(parsed);
if (chunkModel && !modelEmitted) {
modelEmitted = true;
onModel?.(chunkModel);
}
if (timings || promptProgress) {
this.updateProcessingState(timings, promptProgress, conversationId);
if (timings) {
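The `firstValidChunkEmitted` guard in the hunks above fires `onFirstValidChunk` at most once per stream. A minimal sketch of that guard outside the streaming loop, assuming a trimmed-down chunk shape (`StreamChunk` and `createFirstChunkNotifier` are hypothetical names; the real code works on `ApiChatCompletionStreamChunk` and an `AbortSignal`):

```typescript
// Hypothetical, trimmed-down chunk shape with only the field the guard inspects.
interface StreamChunk {
  object?: string;
}

// Returns a handler that invokes `onFirstValidChunk` at most once, for the first
// chunk whose `object` is 'chat.completion.chunk'. The flag is set even when the
// request was aborted, so a later chunk cannot trigger the callback afterwards.
function createFirstChunkNotifier(
  onFirstValidChunk: () => void,
  isAborted: () => boolean = () => false
): (chunk: StreamChunk) => void {
  let firstValidChunkEmitted = false;

  return (chunk: StreamChunk): void => {
    if (firstValidChunkEmitted || chunk.object !== 'chat.completion.chunk') return;
    firstValidChunkEmitted = true;
    if (!isAborted()) onFirstValidChunk();
  };
}
```

Keeping the flag separate from `modelEmitted` lets the caller react to the start of streaming even when a chunk carries no model name.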
@@ -1,6 +1,7 @@
import { DatabaseStore } from '$lib/stores/database';
import { chatService, slotsService } from '$lib/services';
import { config } from '$lib/stores/settings.svelte';
import { serverStore } from '$lib/stores/server.svelte';
import { normalizeModelName } from '$lib/utils/model-names';
import { filterByLeafNodeId, findLeafNode, findDescendantMessages } from '$lib/utils/branching';
import { browser } from '$app/environment';
@@ -362,9 +363,41 @@ class ChatStore {
let resolvedModel: string | null = null;
let modelPersisted = false;
const currentConfig = config();
const preferServerPropsModel = !currentConfig.modelSelectorEnabled;
let serverPropsRefreshed = false;
let updateModelFromServerProps: ((persistImmediately?: boolean) => void) | null = null;
const recordModel = (modelName: string, persistImmediately = true): void => {
const normalizedModel = normalizeModelName(modelName);
const refreshServerPropsOnce = () => {
if (serverPropsRefreshed) {
return;
}
serverPropsRefreshed = true;
const hasExistingProps = serverStore.serverProps !== null;
serverStore
.fetchServerProps({ silent: hasExistingProps })
.then(() => {
updateModelFromServerProps?.(true);
})
.catch((error) => {
console.warn('Failed to refresh server props after streaming started:', error);
});
};
const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => {
const serverModelName = serverStore.modelName;
const preferredModelSource = preferServerPropsModel
? (serverModelName ?? modelName ?? null)
: (modelName ?? serverModelName ?? null);
if (!preferredModelSource) {
return;
}
const normalizedModel = normalizeModelName(preferredModelSource);
if (!normalizedModel || normalizedModel === resolvedModel) {
return;
@@ -388,6 +421,20 @@ class ChatStore {
}
};
if (preferServerPropsModel) {
updateModelFromServerProps = (persistImmediately = true) => {
const currentServerModel = serverStore.modelName;
if (!currentServerModel) {
return;
}
recordModel(currentServerModel, persistImmediately);
};
updateModelFromServerProps(false);
}
slotsService.startStreaming();
slotsService.setActiveConversation(assistantMessage.convId);
@@ -396,6 +443,9 @@ class ChatStore {
{
...this.getApiOptions(),
onFirstValidChunk: () => {
refreshServerPropsOnce();
},
onChunk: (chunk: string) => {
streamedContent += chunk;
this.setConversationStreaming(
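The `recordModel` change above chooses between two possible sources for the model name, depending on whether the model selector is enabled. The selection rule can be sketched as a pure function (`pickModelSource` is a hypothetical name used only for illustration):

```typescript
// Picks which model name to record for a message.
// - preferServerPropsModel = true  (model selector disabled): server props win.
// - preferServerPropsModel = false (model selector enabled): the streamed name wins.
// In both cases the other source is used as a fallback, and null means "nothing to record".
function pickModelSource(
  preferServerPropsModel: boolean,
  streamedModel: string | null | undefined,
  serverModel: string | null
): string | null {
  return preferServerPropsModel
    ? (serverModel ?? streamedModel ?? null)
    : (streamedModel ?? serverModel ?? null);
}
```

For instance, `pickModelSource(true, 'from-stream', 'from-props')` yields `'from-props'`, while the same call with `false` yields `'from-stream'`.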
@@ -52,6 +52,7 @@ class ServerStore {
private _error = $state<string | null>(null);
private _serverWarning = $state<string | null>(null);
private _slotsEndpointAvailable = $state<boolean | null>(null);
private fetchServerPropsPromise: Promise<void> | null = null;
private readCachedServerProps(): ApiLlamaCppServerProps | null {
if (!browser) return null;
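The `fetchServerPropsPromise` field added in this hunk is used by `fetchServerProps` further down to share one in-flight request between concurrent callers. A minimal sketch of that pattern, assuming a generic loader function (`InFlightFetcher` and its members are hypothetical and stand in for the store and `ChatService.getServerProps()`):

```typescript
// Hypothetical stand-in for the store: deduplicates concurrent fetches by
// handing every caller the same pending promise.
class InFlightFetcher {
  private inFlight: Promise<void> | null = null;
  private started = 0;

  constructor(private readonly load: () => Promise<void>) {}

  // Number of times the underlying loader was actually started.
  get startedCount(): number {
    return this.started;
  }

  fetch(): Promise<void> {
    // A request is already running: reuse it instead of starting another one.
    if (this.inFlight) return this.inFlight;

    this.started += 1;
    this.inFlight = (async () => {
      try {
        // Yield once so the cleanup below can never run before the assignment,
        // even if `load` were to throw synchronously.
        await Promise.resolve();
        await this.load();
      } finally {
        // Clear the slot so the next call starts a fresh request.
        this.inFlight = null;
      }
    })();

    return this.inFlight;
  }
}
```

Note that `fetch` is deliberately not declared `async`: an `async` method would wrap the stored promise in a new one on every call, which still deduplicates the work but hides the fact that callers share a single request.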
@@ -98,6 +99,9 @@ class ServerStore {
}
get modelName(): string | null {
if (this._serverProps?.model_alias) {
return this._serverProps.model_alias;
}
if (!this._serverProps?.model_path) return null;
return this._serverProps.model_path.split(/(\\|\/)/).pop() || null;
}
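With this hunk, the `modelName` getter prefers `model_alias` and only falls back to the file name taken from `model_path`. The resolution order can be sketched as a standalone function (`ServerPropsLike` and `resolveModelName` are hypothetical names; the real getter reads `this._serverProps`):

```typescript
// Hypothetical subset of the server props: only the two fields the getter reads.
interface ServerPropsLike {
  model_alias?: string;
  model_path?: string;
}

function resolveModelName(props: ServerPropsLike | null): string | null {
  // 1. An explicit alias wins.
  if (props?.model_alias) return props.model_alias;
  // 2. Without a path there is nothing to derive a name from.
  if (!props?.model_path) return null;
  // 3. Split on either path separator; the last segment is the file name.
  //    An empty last segment (path ending in a separator) becomes null.
  return props.model_path.split(/(\\|\/)/).pop() || null;
}
```

Because the split pattern contains a capturing group, the separators themselves appear in the resulting array, but `pop()` still returns the final path segment.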
@@ -171,73 +175,65 @@ class ServerStore {
/**
* Fetches server properties from the server
*/
async fetchServerProps(): Promise<void> {
this._loading = true;
this._error = null;
this._serverWarning = null;
async fetchServerProps(options: { silent?: boolean } = {}): Promise<void> {
const { silent = false } = options;
const isSilent = silent && this._serverProps !== null;
try {
console.log('Fetching server properties...');
const props = await ChatService.getServerProps();
this._serverProps = props;
this.persistServerProps(props);
console.log('Server properties loaded:', props);
if (this.fetchServerPropsPromise) {
return this.fetchServerPropsPromise;
}
// Check slots endpoint availability after server props are loaded
await this.checkSlotsEndpointAvailability();
} catch (error) {
const hadCachedProps = this._serverProps !== null;
let errorMessage = 'Failed to connect to server';
let isOfflineLikeError = false;
let isServerSideError = false;
if (!isSilent) {
this._loading = true;
this._error = null;
this._serverWarning = null;
}
if (error instanceof Error) {
// Handle specific error types with user-friendly messages
if (error.name === 'TypeError' && error.message.includes('fetch')) {
errorMessage = 'Server is not running or unreachable';
isOfflineLikeError = true;
} else if (error.message.includes('ECONNREFUSED')) {
errorMessage = 'Connection refused - server may be offline';
isOfflineLikeError = true;
} else if (error.message.includes('ENOTFOUND')) {
errorMessage = 'Server not found - check server address';
isOfflineLikeError = true;
} else if (error.message.includes('ETIMEDOUT')) {
errorMessage = 'Request timed out - the server took too long to respond';
isOfflineLikeError = true;
} else if (error.message.includes('503')) {
errorMessage = 'Server temporarily unavailable - try again shortly';
isServerSideError = true;
} else if (error.message.includes('500')) {
errorMessage = 'Server error - check server logs';
isServerSideError = true;
} else if (error.message.includes('404')) {
errorMessage = 'Server endpoint not found';
} else if (error.message.includes('403') || error.message.includes('401')) {
errorMessage = 'Access denied';
const hadProps = this._serverProps !== null;
const fetchPromise = (async () => {
try {
const props = await ChatService.getServerProps();
this._serverProps = props;
this.persistServerProps(props);
this._error = null;
this._serverWarning = null;
await this.checkSlotsEndpointAvailability();
} catch (error) {
if (isSilent && hadProps) {
console.warn('Silent server props refresh failed, keeping cached data:', error);
return;
}
this.handleFetchServerPropsError(error, hadProps);
} finally {
if (!isSilent) {
this._loading = false;
}
this.fetchServerPropsPromise = null;
}
})();
let cachedProps: ApiLlamaCppServerProps | null = null;
this.fetchServerPropsPromise = fetchPromise;
if (!hadCachedProps) {
cachedProps = this.readCachedServerProps();
if (cachedProps) {
this._serverProps = cachedProps;
this._error = null;
await fetchPromise;
}
if (isOfflineLikeError || isServerSideError) {
this._serverWarning = errorMessage;
}
/**
* Handles fetch failures by attempting to recover cached server props and
* updating the user-facing error or warning state appropriately.
*/
private handleFetchServerPropsError(error: unknown, hadProps: boolean): void {
const { errorMessage, isOfflineLikeError, isServerSideError } = this.normalizeFetchError(error);
console.warn(
'Failed to refresh server properties, using cached values from localStorage:',
errorMessage
);
} else {
this._error = errorMessage;
}
} else {
let cachedProps: ApiLlamaCppServerProps | null = null;
if (!hadProps) {
cachedProps = this.readCachedServerProps();
if (cachedProps) {
this._serverProps = cachedProps;
this._error = null;
if (isOfflineLikeError || isServerSideError) {
@@ -245,14 +241,66 @@ class ServerStore {
}
console.warn(
'Failed to refresh server properties, continuing with cached values:',
'Failed to refresh server properties, using cached values from localStorage:',
errorMessage
);
} else {
this._error = errorMessage;
}
console.error('Error fetching server properties:', error);
} finally {
this._loading = false;
} else {
this._error = null;
if (isOfflineLikeError || isServerSideError) {
this._serverWarning = errorMessage;
}
console.warn(
'Failed to refresh server properties, continuing with cached values:',
errorMessage
);
}
console.error('Error fetching server properties:', error);
}
private normalizeFetchError(error: unknown): {
errorMessage: string;
isOfflineLikeError: boolean;
isServerSideError: boolean;
} {
let errorMessage = 'Failed to connect to server';
let isOfflineLikeError = false;
let isServerSideError = false;
if (error instanceof Error) {
const message = error.message || '';
if (error.name === 'TypeError' && message.includes('fetch')) {
errorMessage = 'Server is not running or unreachable';
isOfflineLikeError = true;
} else if (message.includes('ECONNREFUSED')) {
errorMessage = 'Connection refused - server may be offline';
isOfflineLikeError = true;
} else if (message.includes('ENOTFOUND')) {
errorMessage = 'Server not found - check server address';
isOfflineLikeError = true;
} else if (message.includes('ETIMEDOUT')) {
errorMessage = 'Request timed out - the server took too long to respond';
isOfflineLikeError = true;
} else if (message.includes('503')) {
errorMessage = 'Server temporarily unavailable - try again shortly';
isServerSideError = true;
} else if (message.includes('500')) {
errorMessage = 'Server error - check server logs';
isServerSideError = true;
} else if (message.includes('404')) {
errorMessage = 'Server endpoint not found';
} else if (message.includes('403') || message.includes('401')) {
errorMessage = 'Access denied';
}
}
return { errorMessage, isOfflineLikeError, isServerSideError };
}
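The way a normalized error turns into user-facing state can be sketched as a pure function. This is a minimal illustration, assuming the store only distinguishes "usable props exist" (live or restored from cache) from "no props at all"; `NormalizedError`, `ErrorState` and `resolveErrorState` are hypothetical names that do not appear in the store. One simplification: the sketch returns `warning: null` in cases where the store leaves its existing warning untouched.

```typescript
// Hypothetical types and function, for illustration only.
interface NormalizedError {
  errorMessage: string;
  isOfflineLikeError: boolean;
  isServerSideError: boolean;
}

interface ErrorState {
  error: string | null;
  warning: string | null;
}

function resolveErrorState(hasUsableProps: boolean, normalized: NormalizedError): ErrorState {
  // Without any props (live or cached) the failure is surfaced as a hard error.
  if (!hasUsableProps) {
    return { error: normalized.errorMessage, warning: null };
  }
  // With usable props the UI keeps working; offline-like and server-side
  // failures are downgraded to a warning, anything else stays silent.
  const recoverable = normalized.isOfflineLikeError || normalized.isServerSideError;
  return { error: null, warning: recoverable ? normalized.errorMessage : null };
}
```

Keeping the decision in a pure function like this would make the cached and uncached branches easy to unit-test without mocking `localStorage`.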
/**
@@ -264,6 +312,7 @@ class ServerStore {
this._serverWarning = null;
this._loading = false;
this._slotsEndpointAvailable = null;
this.fetchServerPropsPromise = null;
this.persistServerProps(null);
}
}
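The `fetchServerPropsPromise` field suggests an in-flight promise guard, so that concurrent callers share one request. The following is a minimal sketch of that pattern, assuming the loader is an `async` function (so it never throws synchronously); `PropsLoader` and its `load` callback are hypothetical names, not part of `ServerStore`.

```typescript
// Hypothetical class illustrating in-flight request de-duplication.
class PropsLoader<T> {
  value: T | null = null;
  private inFlight: Promise<void> | null = null;
  private readonly load: () => Promise<T>;

  constructor(load: () => Promise<T>) {
    this.load = load;
  }

  fetch(): Promise<void> {
    // Reuse the pending request instead of starting a second one.
    if (this.inFlight) {
      return this.inFlight;
    }
    const promise = (async () => {
      try {
        this.value = await this.load();
      } finally {
        // Clear the guard on success and on failure so later calls can retry.
        this.inFlight = null;
      }
    })();
    this.inFlight = promise;
    return promise;
  }
}
```

Clearing the guard in `finally` mirrors the hunk above, where the stored promise is reset both after a fetch and when the store is cleared.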
@@ -186,6 +186,7 @@ export interface ApiChatCompletionRequest {
}
export interface ApiChatCompletionStreamChunk {
object?: string;
model?: string;
choices: Array<{
model?: string;
@@ -42,6 +42,7 @@ export interface SettingsChatServiceOptions {
onChunk?: (chunk: string) => void;
onReasoningChunk?: (chunk: string) => void;
onModel?: (model: string) => void;
onFirstValidChunk?: () => void;
onComplete?: (response: string, reasoningContent?: string, timings?: ChatMessageTimings) => void;
onError?: (error: Error) => void;
}
@@ -0,0 +1,355 @@
/* eslint-disable no-irregular-whitespace */
import { describe, it, expect, test } from 'vitest';
import { maskInlineLaTeX, preprocessLaTeX } from './latex-protection';
describe('maskInlineLaTeX', () => {
it('should protect LaTeX $x + y$ but not money $3.99', () => {
const latexExpressions: string[] = [];
const input = 'I have $10, $3.99 and $x + y$ and $100x$. The amount is $2,000.';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('I have $10, $3.99 and <<LATEX_0>> and <<LATEX_1>>. The amount is $2,000.');
expect(latexExpressions).toEqual(['$x + y$', '$100x$']);
});
it('should ignore money like $5 and $12.99', () => {
const latexExpressions: string[] = [];
const input = 'Prices are $12.99 and $5. Tax?';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('Prices are $12.99 and $5. Tax?');
expect(latexExpressions).toEqual([]);
});
it('should protect inline math $a^2 + b^2$ even after text', () => {
const latexExpressions: string[] = [];
const input = 'Pythagorean: $a^2 + b^2 = c^2$.';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('Pythagorean: <<LATEX_0>>.');
expect(latexExpressions).toEqual(['$a^2 + b^2 = c^2$']);
});
it('should not protect math that has letter after closing $ (e.g. units)', () => {
const latexExpressions: string[] = [];
const input = 'The cost is $99 and change.';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('The cost is $99 and change.');
expect(latexExpressions).toEqual([]);
});
it('should allow $x$ followed by punctuation', () => {
const latexExpressions: string[] = [];
const input = 'We know $x$, right?';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('We know <<LATEX_0>>, right?');
expect(latexExpressions).toEqual(['$x$']);
});
it('should work across multiple lines', () => {
const latexExpressions: string[] = [];
const input = `Emma buys cupcakes for $3 each.\nHow much is $x + y$?`;
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe(`Emma buys cupcakes for $3 each.\nHow much is <<LATEX_0>>?`);
expect(latexExpressions).toEqual(['$x + y$']);
});
it('should not protect $100 but protect $matrix$', () => {
const latexExpressions: string[] = [];
const input = '$100 and $\\mathrm{GL}_2(\\mathbb{F}_7)$ are different.';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('$100 and <<LATEX_0>> are different.');
expect(latexExpressions).toEqual(['$\\mathrm{GL}_2(\\mathbb{F}_7)$']);
});
it('should skip if $ is followed by digit and alphanumeric after close (money)', () => {
const latexExpressions: string[] = [];
const input = 'I paid $5 quickly.';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('I paid $5 quickly.');
expect(latexExpressions).toEqual([]);
});
it('should protect LaTeX even with special chars inside', () => {
const latexExpressions: string[] = [];
const input = 'Consider $\\alpha_1 + \\beta_2$ now.';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('Consider <<LATEX_0>> now.');
expect(latexExpressions).toEqual(['$\\alpha_1 + \\beta_2$']);
});
it('short text', () => {
const latexExpressions: string[] = ['$0$'];
const input = '$a$\n$a$ and $b$';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('<<LATEX_1>>\n<<LATEX_2>> and <<LATEX_3>>');
expect(latexExpressions).toEqual(['$0$', '$a$', '$a$', '$b$']);
});
it('empty text', () => {
const latexExpressions: string[] = [];
const input = '$\n$$\n';
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe('$\n$$\n');
expect(latexExpressions).toEqual([]);
});
it('LaTeX-spacer preceded by backslash', () => {
const latexExpressions: string[] = [];
const input = `\\[
\\boxed{
\\begin{aligned}
N_{\\text{att}}^{\\text{(MHA)}} &=
h \\bigl[\\, d_{\\text{model}}\\;d_{k} + d_{\\text{model}}\\;d_{v}\\, \\bigr] && (\\text{Q,K,V })\\\\
&\\quad+ h(d_{k}+d_{k}+d_{v}) && (\\text{ Q,K,V}\\\\[4pt]
&\\quad+ (h d_{v})\\, d_{\\text{model}} && (\\text{ }W^{O})\\\\
&\\quad+ d_{\\text{model}} && (\\text{ }b^{O})
\\end{aligned}}
\\]`;
const output = maskInlineLaTeX(input, latexExpressions);
expect(output).toBe(input);
expect(latexExpressions).toEqual([]);
});
});
describe('preprocessLaTeX', () => {
test('converts inline \\( ... \\) to $...$', () => {
const input =
'\\( \\mathrm{GL}_2(\\mathbb{F}_7) \\): Group of invertible matrices with entries in \\(\\mathbb{F}_7\\).';
const output = preprocessLaTeX(input);
expect(output).toBe(
'$ \\mathrm{GL}_2(\\mathbb{F}_7) $: Group of invertible matrices with entries in $\\mathbb{F}_7$.'
);
});
test("don't inline \\\\( ... \\) to $...$", () => {
const input =
'Chapter 20 of The TeXbook, in source "Definitions\\\\(also called Macros)", contains the formula \\((x_1,\\ldots,x_n)\\).';
const output = preprocessLaTeX(input);
expect(output).toBe(
'Chapter 20 of The TeXbook, in source "Definitions\\\\(also called Macros)", contains the formula $(x_1,\\ldots,x_n)$.'
);
});
test('preserves display math \\[ ... \\] and protects adjacent text', () => {
const input = `Some kernel of \\(\\mathrm{SL}_2(\\mathbb{F}_7)\\):
\\[
\\left\\{ \\begin{pmatrix} 1 & 0 \\\\ 0 & 1 \\end{pmatrix}, \\begin{pmatrix} -1 & 0 \\\\ 0 & -1 \\end{pmatrix} \\right\\} = \\{\\pm I\\}
\\]`;
const output = preprocessLaTeX(input);
expect(output).toBe(`Some kernel of $\\mathrm{SL}_2(\\mathbb{F}_7)$:
$$
\\left\\{ \\begin{pmatrix} 1 & 0 \\\\ 0 & 1 \\end{pmatrix}, \\begin{pmatrix} -1 & 0 \\\\ 0 & -1 \\end{pmatrix} \\right\\} = \\{\\pm I\\}
$$`);
});
test('handles standalone display math equation', () => {
const input = `Algebra:
\\[
x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}
\\]`;
const output = preprocessLaTeX(input);
expect(output).toBe(`Algebra:
$$
x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}
$$`);
});
test('does not interpret currency values as LaTeX', () => {
const input = 'I have $10, $3.99 and $x + y$ and $100x$. The amount is $2,000.';
const output = preprocessLaTeX(input);
expect(output).toBe('I have \\$10, \\$3.99 and $x + y$ and $100x$. The amount is \\$2,000.');
});
test('ignores dollar signs followed by digits (money), but keeps valid math $x + y$', () => {
const input = 'I have $10, $3.99 and $x + y$ and $100x$. The amount is $2,000.';
const output = preprocessLaTeX(input);
expect(output).toBe('I have \\$10, \\$3.99 and $x + y$ and $100x$. The amount is \\$2,000.');
});
test('handles real-world word problems with amounts and no math delimiters', () => {
const input =
'Emma buys 2 cupcakes for $3 each and 1 cookie for $1.50. How much money does she spend in total?';
const output = preprocessLaTeX(input);
expect(output).toBe(
'Emma buys 2 cupcakes for \\$3 each and 1 cookie for \\$1.50. How much money does she spend in total?'
);
});
test('handles decimal amounts in word problem correctly', () => {
const input =
'Maria has $20. She buys a notebook for $4.75 and a pack of pencils for $3.25. How much change does she receive?';
const output = preprocessLaTeX(input);
expect(output).toBe(
'Maria has \\$20. She buys a notebook for \\$4.75 and a pack of pencils for \\$3.25. How much change does she receive?'
);
});
test('preserves display math with surrounding non-ASCII text', () => {
const input = `1kg の質量は
\\[
E = (1\\ \\text{kg}) \\times (3.0 \\times 10^8\\ \\text{m/s})^2 \\approx 9.0 \\times 10^{16}\\ \\text{J}
\\]
21 TNT `;
const output = preprocessLaTeX(input);
expect(output).toBe(
`1kg の質量は
$$
E = (1\\ \\text{kg}) \\times (3.0 \\times 10^8\\ \\text{m/s})^2 \\approx 9.0 \\times 10^{16}\\ \\text{J}
$$
21 TNT `
);
});
test('LaTeX-spacer preceded by backslash', () => {
const input = `\\[
\\boxed{
\\begin{aligned}
N_{\\text{att}}^{\\text{(MHA)}} &=
h \\bigl[\\, d_{\\text{model}}\\;d_{k} + d_{\\text{model}}\\;d_{v}\\, \\bigr] && (\\text{Q,K,V })\\\\
&\\quad+ h(d_{k}+d_{k}+d_{v}) && (\\text{ Q,K,V}\\\\[4pt]
&\\quad+ (h d_{v})\\, d_{\\text{model}} && (\\text{ }W^{O})\\\\
&\\quad+ d_{\\text{model}} && (\\text{ }b^{O})
\\end{aligned}}
\\]`;
const output = preprocessLaTeX(input);
expect(output).toBe(
`$$
\\boxed{
\\begin{aligned}
N_{\\text{att}}^{\\text{(MHA)}} &=
h \\bigl[\\, d_{\\text{model}}\\;d_{k} + d_{\\text{model}}\\;d_{v}\\, \\bigr] && (\\text{Q,K,V })\\\\
&\\quad+ h(d_{k}+d_{k}+d_{v}) && (\\text{ Q,K,V}\\\\[4pt]
&\\quad+ (h d_{v})\\, d_{\\text{model}} && (\\text{ }W^{O})\\\\
&\\quad+ d_{\\text{model}} && (\\text{ }b^{O})
\\end{aligned}}
$$`
);
});
test('converts \\[ ... \\] even when preceded by text without space', () => {
const input = 'Some line ...\nAlgebra: \\[x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}\\]';
const output = preprocessLaTeX(input);
expect(output).toBe(
'Some line ...\nAlgebra: \n$$x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}$$\n'
);
});
test('converts \\[ ... \\] in table-cells', () => {
const input = `| ID | Expression |\n| #1 | \\[
x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}
\\] |`;
const output = preprocessLaTeX(input);
expect(output).toBe(
'| ID | Expression |\n| #1 | $x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}$ |'
);
});
test('escapes isolated $ before digits ($5 → \\$5), but not valid math', () => {
const input = 'This costs $5 and this is math $x^2$. $100 is money.';
const output = preprocessLaTeX(input);
expect(output).toBe('This costs \\$5 and this is math $x^2$. \\$100 is money.');
// Note: Since $x^2$ is detected as valid LaTeX, it's preserved.
// $5 becomes \$5 only *after* real math is masked — but here it's correct because the masking logic avoids treating $5 as math.
});
test('display with LaTeX-line-breaks', () => {
const input = String.raw`- Algebraic topology, Homotopy Groups of $\mathbb{S}^3$:
$$\pi_n(\mathbb{S}^3) = \begin{cases}
\mathbb{Z} & n = 3 \\
0 & n > 3, n \neq 4 \\
\mathbb{Z}_2 & n = 4 \\
\end{cases}$$`;
const output = preprocessLaTeX(input);
// If the formula contains '\\' the $$-delimiters should be in their own line.
expect(output).toBe(`- Algebraic topology, Homotopy Groups of $\\mathbb{S}^3$:
$$\n\\pi_n(\\mathbb{S}^3) = \\begin{cases}
\\mathbb{Z} & n = 3 \\\\
0 & n > 3, n \\neq 4 \\\\
\\mathbb{Z}_2 & n = 4 \\\\
\\end{cases}\n$$`);
});
test('handles mhchem notation safely if present', () => {
const input = 'Chemical reaction: \\( \\ce{H2O} \\) and $\\ce{CO2}$';
const output = preprocessLaTeX(input);
expect(output).toBe('Chemical reaction: $ \\ce{H2O} $ and $\\ce{CO2}$');
});
test('preserves code blocks', () => {
const input = 'Inline code: `sum $total` and block:\n```\ndollar $amount\n```\nEnd.';
const output = preprocessLaTeX(input);
expect(output).toBe(input); // Code blocks prevent misinterpretation
});
test('escape backslash in mhchem ce', () => {
const input = 'mhchem ce:\n$\\ce{2H2(g) + O2(g) -> 2H2O(l)}$';
const output = preprocessLaTeX(input);
// mhchem-escape would insert a backslash here.
expect(output).toBe('mhchem ce:\n$\\ce{2H2(g) + O2(g) -> 2H2O(l)}$');
});
test('escape backslash in mhchem pu', () => {
const input = 'mhchem pu:\n$\\pu{-572 kJ mol^{-1}}$';
const output = preprocessLaTeX(input);
// mhchem-escape would insert a backslash here.
expect(output).toBe('mhchem pu:\n$\\pu{-572 kJ mol^{-1}}$');
});
test('LaTeX in blockquotes with display math', () => {
const input =
'> **Definition (limit):** \n> \\[\n> \\lim_{x\\to a} f(x) = L\n> \\]\n> means that as \\(x\\) gets close to \\(a\\).';
const output = preprocessLaTeX(input);
// Blockquote markers should be preserved, LaTeX should be converted
expect(output).toContain('> **Definition (limit):**');
expect(output).toContain('$$');
expect(output).toContain('$x$');
expect(output).not.toContain('\\[');
expect(output).not.toContain('\\]');
expect(output).not.toContain('\\(');
expect(output).not.toContain('\\)');
});
test('LaTeX in blockquotes with inline math', () => {
const input =
"> The derivative \\(f'(x)\\) at point \\(x=a\\) measures slope.\n> Formula: \\(f'(a)=\\lim_{h\\to 0}\\frac{f(a+h)-f(a)}{h}\\)";
const output = preprocessLaTeX(input);
// Blockquote markers should be preserved, inline LaTeX converted to $...$
expect(output).toContain("> The derivative $f'(x)$ at point $x=a$ measures slope.");
expect(output).toContain("> Formula: $f'(a)=\\lim_{h\\to 0}\\frac{f(a+h)-f(a)}{h}$");
});
test('Mixed content with blockquotes and regular text', () => {
const input =
'Regular text with \\(x^2\\).\n\n> Quote with \\(y^2\\).\n\nMore text with \\(z^2\\).';
const output = preprocessLaTeX(input);
// All LaTeX should be converted, blockquote markers preserved
expect(output).toBe('Regular text with $x^2$.\n\n> Quote with $y^2$.\n\nMore text with $z^2$.');
});
});
@@ -0,0 +1,267 @@
import {
CODE_BLOCK_REGEXP,
LATEX_MATH_AND_CODE_PATTERN,
LATEX_LINEBREAK_REGEXP,
MHCHEM_PATTERN_MAP
} from '$lib/constants/latex-protection';
/**
* Replaces inline LaTeX expressions enclosed in `$...$` with placeholders, avoiding dollar signs
* that appear to be part of monetary values or identifiers.
*
* This function processes the input line by line and skips `$` sequences that are likely
* part of money amounts (e.g., `$5`, `$100.99`) or code-like tokens (e.g., `var$`, `$var`).
* Valid LaTeX inline math is replaced with a placeholder like `<<LATEX_0>>`, and the
* actual LaTeX content is stored in the provided `latexExpressions` array.
*
* @param content - The input text potentially containing LaTeX expressions.
* @param latexExpressions - An array used to collect extracted LaTeX expressions.
* @returns The processed string with LaTeX replaced by placeholders.
*/
export function maskInlineLaTeX(content: string, latexExpressions: string[]): string {
if (!content.includes('$')) {
return content;
}
return content
.split('\n')
.map((line) => {
if (line.indexOf('$') == -1) {
return line;
}
let processedLine = '';
let currentPosition = 0;
while (currentPosition < line.length) {
const openDollarIndex = line.indexOf('$', currentPosition);
if (openDollarIndex == -1) {
processedLine += line.slice(currentPosition);
break;
}
// Is there a next $-sign?
const closeDollarIndex = line.indexOf('$', openDollarIndex + 1);
if (closeDollarIndex == -1) {
processedLine += line.slice(currentPosition);
break;
}
const charBeforeOpen = openDollarIndex > 0 ? line[openDollarIndex - 1] : '';
const charAfterOpen = line[openDollarIndex + 1];
const charBeforeClose =
openDollarIndex + 1 < closeDollarIndex ? line[closeDollarIndex - 1] : '';
const charAfterClose = closeDollarIndex + 1 < line.length ? line[closeDollarIndex + 1] : '';
let shouldSkipAsNonLatex = false;
if (closeDollarIndex == openDollarIndex + 1) {
// No content between the opening and the closing '$'
shouldSkipAsNonLatex = true;
}
if (/[A-Za-z0-9_$-]/.test(charBeforeOpen)) {
// Character, digit, $, _ or - before first '$', no TeX.
shouldSkipAsNonLatex = true;
}
if (
/[0-9]/.test(charAfterOpen) &&
(/[A-Za-z0-9_$-]/.test(charAfterClose) || ' ' == charBeforeClose)
) {
// First $ seems to belong to an amount.
shouldSkipAsNonLatex = true;
}
if (shouldSkipAsNonLatex) {
processedLine += line.slice(currentPosition, openDollarIndex + 1);
currentPosition = openDollarIndex + 1;
continue;
}
// Treat as LaTeX
processedLine += line.slice(currentPosition, openDollarIndex);
const latexContent = line.slice(openDollarIndex, closeDollarIndex + 1);
latexExpressions.push(latexContent);
processedLine += `<<LATEX_${latexExpressions.length - 1}>>`;
currentPosition = closeDollarIndex + 1;
}
return processedLine;
})
.join('\n');
}
function escapeBrackets(text: string): string {
return text.replace(
LATEX_MATH_AND_CODE_PATTERN,
(
match: string,
codeBlock: string | undefined,
squareBracket: string | undefined,
roundBracket: string | undefined
): string => {
if (codeBlock != null) {
return codeBlock;
} else if (squareBracket != null) {
return `$$${squareBracket}$$`;
} else if (roundBracket != null) {
return `$${roundBracket}$`;
}
return match;
}
);
}
// Apply the mhchem escape patterns from MHCHEM_PATTERN_MAP (for \\ce{...} and \\pu{...})
function escapeMhchem(text: string): string {
return MHCHEM_PATTERN_MAP.reduce((result, [pattern, replacement]) => {
return result.replace(pattern, replacement);
}, text);
}
const doEscapeMhchem = false;
/**
* Preprocesses markdown content to safely handle LaTeX math expressions while protecting
* against false positives (e.g., dollar amounts like $5.99) and ensuring proper rendering.
*
* This function:
* - Protects code blocks (```) and inline code (`...`)
* - Safeguards block and inline LaTeX: \(...\), \[...\], $$...$$, and selective $...$
* - Escapes standalone dollar signs before numbers (e.g., $5 → \$5) to prevent misinterpretation
* - Restores protected LaTeX and code blocks after processing
* - Converts \(...\) → $...$ and \[...\] → $$...$$ for compatibility with math renderers
* - Applies additional escaping for brackets and mhchem syntax if needed
*
* @param content - The raw text (e.g., markdown) that may contain LaTeX or code blocks.
* @returns The preprocessed string with properly escaped and normalized LaTeX.
*
* @example
* preprocessLaTeX("Price: $10. The equation is \\(x^2\\).")
* // → "Price: \\$10. The equation is $x^2$."
*/
export function preprocessLaTeX(content: string): string {
// See also:
// https://github.com/danny-avila/LibreChat/blob/main/client/src/utils/latex.ts
// Step 0: Temporarily remove blockquote markers (>) to process LaTeX correctly
// Store the structure so we can restore it later
const blockquoteMarkers: Map<number, string> = new Map();
const lines = content.split('\n');
const processedLines = lines.map((line, index) => {
const match = line.match(/^(>\s*)/);
if (match) {
blockquoteMarkers.set(index, match[1]);
return line.slice(match[1].length);
}
return line;
});
content = processedLines.join('\n');
// Step 1: Protect code blocks
const codeBlocks: string[] = [];
content = content.replace(CODE_BLOCK_REGEXP, (match) => {
codeBlocks.push(match);
return `<<CODE_BLOCK_${codeBlocks.length - 1}>>`;
});
// Step 2: Protect existing LaTeX expressions
const latexExpressions: string[] = [];
// Match \S...\[...\] and protect them and insert a line-break.
content = content.replace(/([\S].*?)\\\[([\s\S]*?)\\\](.*)/g, (match, group1, group2, group3) => {
// Check if there are characters following the formula (display-formula in a table-cell?)
if (group1.endsWith('\\')) {
return match; // Backslash before \[, do nothing.
}
const hasSuffix = /\S/.test(group3);
let optBreak;
if (hasSuffix) {
latexExpressions.push(`\\(${group2.trim()}\\)`); // Convert into inline.
optBreak = '';
} else {
latexExpressions.push(`\\[${group2}\\]`);
optBreak = '\n';
}
return `${group1}${optBreak}<<LATEX_${latexExpressions.length - 1}>>${optBreak}${group3}`;
});
// Match \(...\), \[...\], $$...$$ and protect them
content = content.replace(
/(\$\$[\s\S]*?\$\$|(?<!\\)\\\[[\s\S]*?\\\]|(?<!\\)\\\(.*?\\\))/g,
(match) => {
latexExpressions.push(match);
return `<<LATEX_${latexExpressions.length - 1}>>`;
}
);
// Protect inline $...$ but NOT if it looks like money (e.g., $10, $3.99)
content = maskInlineLaTeX(content, latexExpressions);
// Step 3: Escape standalone $ before digits (currency like $5 → \$5)
// (Now that inline math is protected, this will only escape dollars not already protected)
content = content.replace(/\$(?=\d)/g, '\\$');
// Step 4: Restore protected LaTeX expressions (they are valid)
content = content.replace(/<<LATEX_(\d+)>>/g, (_, index) => {
let expr = latexExpressions[parseInt(index)];
const match = expr.match(LATEX_LINEBREAK_REGEXP);
if (match) {
// Katex: The $$-delimiters should be in their own line
// if there are \\-line-breaks.
const formula = match[1];
const prefix = formula.startsWith('\n') ? '' : '\n';
const suffix = formula.endsWith('\n') ? '' : '\n';
expr = '$$' + prefix + formula + suffix + '$$';
}
return expr;
});
// Step 5: Restore code blocks
content = content.replace(/<<CODE_BLOCK_(\d+)>>/g, (_, index) => {
return codeBlocks[parseInt(index)];
});
// Step 6: Apply additional escaping functions (brackets and mhchem)
content = escapeBrackets(content);
if (doEscapeMhchem && (content.includes('\\ce{') || content.includes('\\pu{'))) {
content = escapeMhchem(content);
}
// Final pass: Convert \(...\) → $...$, \[...\] → $$...$$
content = content
// Using the lookbehind pattern `(?<!\\)` we skip matches
// that are preceded by a backslash, e.g.
// `Definitions\\(also called macros)` (title of chapter 20 in The TeXbook).
.replace(/(?<!\\)\\\((.+?)\\\)/g, '$$$1$') // inline
.replace(
// Using the lookbehind pattern `(?<!\\)` we skip matches
// that are preceded by a backslash, e.g. `\\[4pt]`.
/(?<!\\)\\\[([\s\S]*?)\\\]/g, // display, see also PR #16599
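// The pattern has a single capture group (the formula); the next
// callback argument would be the match offset, not a string.
(_, formula: string) => {
return `$$${formula}$$`;
}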
);
// Step 7: Restore blockquote markers
if (blockquoteMarkers.size > 0) {
const finalLines = content.split('\n');
const restoredLines = finalLines.map((line, index) => {
const marker = blockquoteMarkers.get(index);
return marker ? marker + line : line;
});
content = restoredLines.join('\n');
}
return content;
}
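The mask, transform, restore sequence used by `preprocessLaTeX` can be shown in isolation. This is a simplified sketch, not the module's implementation: `escapeCurrencyOutsideMath` and the `<<SPAN_n>>` placeholder are hypothetical, and only inline code and `$$...$$` are protected here, whereas the real function also handles fenced code, `\(...\)`, `\[...\]` and selective `$...$`.

```typescript
// Hypothetical helper illustrating the placeholder technique.
function escapeCurrencyOutsideMath(text: string): string {
  const protectedSpans: string[] = [];

  // 1. Mask inline code and $$...$$ so the later pass cannot touch them.
  const masked = text.replace(/`[^`\n]*`|\$\$[\s\S]*?\$\$/g, (match) => {
    protectedSpans.push(match);
    return `<<SPAN_${protectedSpans.length - 1}>>`;
  });

  // 2. Escape every remaining '$' that is directly followed by a digit.
  const escaped = masked.replace(/\$(?=\d)/g, '\\$');

  // 3. Restore the protected spans by index. A function replacement is used
  //    so that '$' characters inside the spans are inserted literally.
  return escaped.replace(/<<SPAN_(\d+)>>/g, (_, index: string) => protectedSpans[Number(index)]);
}
```

For example, `escapeCurrencyOutsideMath('Pay $5 for `$1` and $$3x$$')` escapes only the first dollar sign, because the other two sit inside protected spans.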
@@ -1,3 +1,4 @@
/* eslint-disable no-irregular-whitespace */
// Math Formulas Content
export const MATH_FORMULAS_MD = String.raw`
# Mathematical Formulas and Expressions
@@ -150,6 +151,70 @@ $$\lim_{x \to 0} \frac{\sin x}{x} = 1$$
$$\lim_{n \to \infty} \left(1 + \frac{x}{n}\right)^n = e^x$$
## Further Bracket Styles and Amounts
- \( \mathrm{GL}_2(\mathbb{F}_7) \): Group of invertible matrices with entries in \(\mathbb{F}_7\).
- Some kernel of \(\mathrm{SL}_2(\mathbb{F}_7)\):
\[
\left\{ \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}, \begin{pmatrix} -1 & 0 \\ 0 & -1 \end{pmatrix} \right\} = \{\pm I\}
\]
- Algebra:
\[
x = \frac{-b \pm \sqrt{\,b^{2}-4ac\,}}{2a}
\]
- $100 and $12.99 are amounts, not LaTeX.
- I have $10, $3.99 and $x + y$ and $100x$. The amount is $2,000.
- Emma buys 2 cupcakes for $3 each and 1 cookie for $1.50. How much money does she spend in total?
- Maria has $20. She buys a notebook for $4.75 and a pack of pencils for $3.25. How much change does she receive?
- 1kg の質量は
\[
E = (1\ \text{kg}) \times (3.0 \times 10^8\ \text{m/s})^2 \approx 9.0 \times 10^{16}\ \text{J}
\]
21 TNT
- Algebra: \[
x = \frac{-b \pm \sqrt{\,b^{2}-4ac\,}}{2a}
\]
- Algebraic topology, Homotopy Groups of $\mathbb{S}^3$:
$$\pi_n(\mathbb{S}^3) = \begin{cases}
\mathbb{Z} & n = 3 \\
0 & n > 3, n \neq 4 \\
\mathbb{Z}_2 & n = 4 \\
\end{cases}$$
- Spacer preceded by backslash:
\[
\boxed{
\begin{aligned}
N_{\text{att}}^{\text{(MHA)}} &=
h \bigl[\, d_{\text{model}}\;d_{k} + d_{\text{model}}\;d_{v}\, \bigr] && (\text{Q,K,V })\\
&\quad+ h(d_{k}+d_{k}+d_{v}) && (\text{ Q,K,V}\\[4pt]
&\quad+ (h d_{v})\, d_{\text{model}} && (\text{ }W^{O})\\
&\quad+ d_{\text{model}} && (\text{ }b^{O})
\end{aligned}}
\]
## Formulas in a Table
| Area | Expression | Comment |
|------|------------|---------|
| **Algebra** | \[
x = \frac{-b \pm \sqrt{\,b^{2}-4ac\,}}{2a}
\] | Quadratic formula |
| | \[
(a+b)^{n} = \sum_{k=0}^{n}\binom{n}{k}\,a^{\,n-k}\,b^{\,k}
\] | Binomial theorem |
| | \(\displaystyle \prod_{k=1}^{n}k = n! \) | Factorial definition |
| **Geometry** | \( \mathbf{a}\cdot \mathbf{b} = \|\mathbf{a}\|\,\|\mathbf{b}\|\,\cos\theta \) | Dot product & angle |
## No math (but chemical)
Balanced chemical reaction with states:
\[
\ce{2H2(g) + O2(g) -> 2H2O(l)}
\]
The standard enthalpy change for the reaction is: $\Delta H^\circ = \pu{-572 kJ mol^{-1}}$.
---
*This document showcases various mathematical notation and formulas that can be rendered in markdown using LaTeX syntax.*
@@ -0,0 +1,13 @@
// Override KaTeX SCSS variables to disable ttf and woff fonts
// Only use woff2 format which is embedded in the bundle
$use-woff2: true;
$use-woff: false;
$use-ttf: false;
// Use Vite alias for font folder
$font-folder: 'katex-fonts';
// Import KaTeX SCSS with overridden variables
// Note: @import is deprecated but required because KaTeX uses @import internally
// The deprecation warnings are from KaTeX's code and cannot be avoided
@import 'katex/src/styles/katex.scss';
@@ -22,6 +22,9 @@ const config = {
}),
output: {
bundleStrategy: 'inline'
},
alias: {
$styles: 'src/styles'
}
},
@@ -18,6 +18,15 @@ const GUIDE_FOR_FRONTEND = `
const MAX_BUNDLE_SIZE = 2 * 1024 * 1024;
/**
* the maximum size of an embedded asset in bytes,
* e.g. maximum size of embedded font (see node_modules/katex/dist/fonts/*.woff2)
*/
const MAX_ASSET_SIZE = 32000;
/** public/index.html.gz minified flag */
const ENABLE_JS_MINIFICATION = true;
function llamaCppBuildPlugin() {
return {
name: 'llamacpp:build',
@@ -75,12 +84,28 @@ function llamaCppBuildPlugin() {
}
export default defineConfig({
build: {
chunkSizeWarningLimit: 3072
resolve: {
alias: {
'katex-fonts': resolve('node_modules/katex/dist/fonts')
}
},
build: {
assetsInlineLimit: MAX_ASSET_SIZE,
chunkSizeWarningLimit: 3072,
minify: ENABLE_JS_MINIFICATION
},
css: {
preprocessorOptions: {
scss: {
additionalData: `
$use-woff2: true;
$use-woff: false;
$use-ttf: false;
`
}
}
},
plugins: [tailwindcss(), sveltekit(), devtoolsJson(), llamaCppBuildPlugin()],
test: {
projects: [
{
@@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.26.0"
#define CPPHTTPLIB_VERSION_NUM "0x001A00"
#define CPPHTTPLIB_VERSION "0.27.0"
#define CPPHTTPLIB_VERSION_NUM "0x001B00"
/*
* Platform compatibility check
@@ -1052,6 +1052,9 @@ private:
ssize_t write_headers(Stream &strm, const Headers &headers);
std::string make_host_and_port_string(const std::string &host, int port,
bool is_ssl);
} // namespace detail
class Server {
@@ -1129,6 +1132,8 @@ public:
Server &
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
Server &set_trusted_proxies(const std::vector<std::string> &proxies);
Server &set_keep_alive_max_count(size_t count);
Server &set_keep_alive_timeout(time_t sec);
@@ -1167,6 +1172,9 @@ protected:
const std::function<void(Request &)> &setup_request);
std::atomic<socket_t> svr_sock_{INVALID_SOCKET};
std::vector<std::string> trusted_proxies_;
size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND;
time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND;
@@ -1719,8 +1727,6 @@ private:
const std::string &boundary, const UploadFormDataItems &items,
const FormDataProviderItems &provider_items) const;
std::string adjust_host_string(const std::string &host) const;
virtual bool
process_socket(const Socket &socket,
std::chrono::time_point<std::chrono::steady_clock> start_time,
@@ -1953,14 +1959,17 @@ public:
void update_certs(X509 *cert, EVP_PKEY *private_key,
X509_STORE *client_ca_cert_store = nullptr);
int ssl_last_error() const { return last_ssl_error_; }
private:
bool process_and_close_socket(socket_t sock) override;
STACK_OF(X509_NAME) * extract_ca_names_from_x509_store(X509_STORE *store);
SSL_CTX *ctx_;
std::mutex ctx_mutex_;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
int last_ssl_error_ = 0;
#endif
};
class SSLClient final : public ClientImpl {
@@ -4596,13 +4605,35 @@ inline bool zstd_decompressor::decompress(const char *data, size_t data_length,
}
#endif
inline bool is_prohibited_header_name(const std::string &name) {
using udl::operator""_t;
switch (str2tag(name)) {
case "REMOTE_ADDR"_t:
case "REMOTE_PORT"_t:
case "LOCAL_ADDR"_t:
case "LOCAL_PORT"_t: return true;
default: return false;
}
}
inline bool has_header(const Headers &headers, const std::string &key) {
if (is_prohibited_header_name(key)) { return false; }
return headers.find(key) != headers.end();
}
inline const char *get_header_value(const Headers &headers,
const std::string &key, const char *def,
size_t id) {
if (is_prohibited_header_name(key)) {
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
std::string msg = "Prohibited header name '" + key + "' is specified.";
throw std::invalid_argument(msg);
#else
return "";
#endif
}
auto rng = headers.equal_range(key);
auto it = rng.first;
std::advance(it, static_cast<ssize_t>(id));
@@ -7261,6 +7292,30 @@ inline bool RegexMatcher::match(Request &request) const {
return std::regex_match(request.path, request.matches, regex_);
}
inline std::string make_host_and_port_string(const std::string &host, int port,
bool is_ssl) {
std::string result;
// Enclose IPv6 address in brackets (but not if already enclosed)
if (host.find(':') == std::string::npos ||
(!host.empty() && host[0] == '[')) {
// IPv4, hostname, or already bracketed IPv6
result = host;
} else {
// IPv6 address without brackets
result = "[" + host + "]";
}
// Append port if not default
if ((!is_ssl && port == 80) || (is_ssl && port == 443)) {
; // do nothing
} else {
result += ":" + std::to_string(port);
}
return result;
}
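For readers following the web UI side of this diff, the host formatting rule above can be restated in TypeScript. This is an illustrative port under the same rules (bracket bare IPv6 literals, omit the scheme's default port); `formatHostAndPort` is a hypothetical name and is not part of cpp-httplib.

```typescript
// Illustrative TypeScript restatement of the C++ helper above.
function formatHostAndPort(host: string, port: number, isSsl: boolean): string {
  // A ':' in the host means an IPv6 literal; add brackets unless already present.
  const needsBrackets = host.indexOf(':') !== -1 && host.charAt(0) !== '[';
  const base = needsBrackets ? `[${host}]` : host;

  // Port 80 (plain) and 443 (TLS) are defaults and are left out.
  const isDefaultPort = (!isSsl && port === 80) || (isSsl && port === 443);
  return isDefaultPort ? base : `${base}:${port}`;
}
```

Note that port 443 over plain HTTP is not a default, so it is kept in the result.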
} // namespace detail
// HTTP server implementation
@@ -7473,6 +7528,12 @@ inline Server &Server::set_header_writer(
return *this;
}
inline Server &
Server::set_trusted_proxies(const std::vector<std::string> &proxies) {
trusted_proxies_ = proxies;
return *this;
}
inline Server &Server::set_keep_alive_max_count(size_t count) {
keep_alive_max_count_ = count;
return *this;
@@ -8261,6 +8322,40 @@ inline bool Server::dispatch_request_for_content_reader(
return false;
}
inline std::string
get_client_ip(const std::string &x_forwarded_for,
const std::vector<std::string> &trusted_proxies) {
// X-Forwarded-For is a de facto standard comma-separated list
// (RFC 7239 standardizes its successor, the Forwarded header)
std::vector<std::string> ip_list;
detail::split(x_forwarded_for.data(),
x_forwarded_for.data() + x_forwarded_for.size(), ',',
[&](const char *b, const char *e) {
auto r = detail::trim(b, e, 0, static_cast<size_t>(e - b));
ip_list.emplace_back(std::string(b + r.first, b + r.second));
});
for (size_t i = 0; i < ip_list.size(); ++i) {
auto ip = ip_list[i];
auto is_trusted_proxy =
std::any_of(trusted_proxies.begin(), trusted_proxies.end(),
[&](const std::string &proxy) { return ip == proxy; });
if (is_trusted_proxy) {
if (i == 0) {
// If the trusted proxy is the first IP, there's no preceding client IP
return ip;
} else {
// Return the IP immediately before the trusted proxy
return ip_list[i - 1];
}
}
}
// If no trusted proxy is found, return the first IP in the list
return ip_list.front();
}
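`get_client_ip` ends with `ip_list.front()`, which assumes the split produced at least one entry. A minimal sketch of the same selection rule with an explicit guard for an empty list follows; the function name `resolve_client_ip`, the pre-split `forwarded_chain` argument, and the `fallback` parameter are assumptions made for illustration and are not part of the library.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical variant: takes the already-split X-Forwarded-For entries.
inline std::string
resolve_client_ip(const std::vector<std::string> &forwarded_chain,
                  const std::vector<std::string> &trusted_proxies,
                  const std::string &fallback) {
  // Guard: with no entries there is nothing to pick, so use the fallback
  // (for example the socket's peer address).
  if (forwarded_chain.empty()) { return fallback; }

  for (std::size_t i = 0; i < forwarded_chain.size(); ++i) {
    const bool is_trusted =
        std::find(trusted_proxies.begin(), trusted_proxies.end(),
                  forwarded_chain[i]) != trusted_proxies.end();
    if (is_trusted) {
      // Same rule as the hunk: the entry just before the first trusted
      // proxy, or the proxy itself when it comes first.
      return i == 0 ? forwarded_chain[0] : forwarded_chain[i - 1];
    }
  }
  // No trusted proxy in the chain: first entry, as in the hunk.
  return forwarded_chain.front();
}
```

Matching here is exact string equality, as in the hunk, so CIDR ranges or differently formatted IPv6 addresses would not match.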
inline bool
Server::process_request(Stream &strm, const std::string &remote_addr,
int remote_port, const std::string &local_addr,
@@ -8324,15 +8419,16 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
connection_closed = true;
}
req.remote_addr = remote_addr;
if (!trusted_proxies_.empty() && req.has_header("X-Forwarded-For")) {
auto x_forwarded_for = req.get_header_value("X-Forwarded-For");
req.remote_addr = get_client_ip(x_forwarded_for, trusted_proxies_);
} else {
req.remote_addr = remote_addr;
}
req.remote_port = remote_port;
req.set_header("REMOTE_ADDR", req.remote_addr);
req.set_header("REMOTE_PORT", std::to_string(req.remote_port));
req.local_addr = local_addr;
req.local_port = local_port;
req.set_header("LOCAL_ADDR", req.local_addr);
req.set_header("LOCAL_PORT", std::to_string(req.local_port));
if (req.has_header("Accept")) {
const auto &accept_header = req.get_header_value("Accept");
@@ -8522,7 +8618,7 @@ inline ClientImpl::ClientImpl(const std::string &host, int port,
const std::string &client_cert_path,
const std::string &client_key_path)
: host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port),
host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)),
host_and_port_(detail::make_host_and_port_string(host_, port, is_ssl())),
client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
inline ClientImpl::~ClientImpl() {
@@ -8703,8 +8799,9 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) {
{
std::lock_guard<std::mutex> guard(socket_mutex_);
// Set this to false immediately - if it ever gets set to true by the end of
// the request, we know another thread instructed us to close the socket.
// Set this to false immediately - if it ever gets set to true by the end
// of the request, we know another thread instructed us to close the
// socket.
socket_should_be_closed_when_request_is_done_ = false;
auto is_alive = false;
@@ -8720,10 +8817,10 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) {
#endif
if (!is_alive) {
// Attempt to avoid sigpipe by shutting down non-gracefully if it seems
// like the other side has already closed the connection. Also, there
// cannot be any requests in flight from other threads since we locked
// request_mutex_, so it is safe to close everything immediately.
// Attempt to avoid sigpipe by shutting down non-gracefully if it
// seems like the other side has already closed the connection. Also,
// there cannot be any requests in flight from other threads since we
// locked request_mutex_, so it is safe to close everything immediately.
const bool shutdown_gracefully = false;
shutdown_ssl(socket_, shutdown_gracefully);
shutdown_socket(socket_);
@@ -9027,7 +9124,8 @@ inline bool ClientImpl::create_redirect_client(
}
}
// New method for robust client setup (based on basic_manual_redirect.cpp logic)
// New method for robust client setup (based on basic_manual_redirect.cpp
// logic)
template <typename ClientType>
inline void ClientImpl::setup_redirect_client(ClientType &client) {
// Copy basic settings first
@@ -9131,18 +9229,8 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req,
// curl behavior)
if (address_family_ == AF_UNIX) {
req.set_header("Host", "localhost");
} else if (is_ssl()) {
if (port_ == 443) {
req.set_header("Host", host_);
} else {
req.set_header("Host", host_and_port_);
}
} else {
if (port_ == 80) {
req.set_header("Host", host_);
} else {
req.set_header("Host", host_and_port_);
}
req.set_header("Host", host_and_port_);
}
}
@@ -9409,12 +9497,6 @@ inline Result ClientImpl::send_with_content_provider(
#endif
}
inline std::string
ClientImpl::adjust_host_string(const std::string &host) const {
if (host.find(':') != std::string::npos) { return "[" + host + "]"; }
return host;
}
inline void ClientImpl::output_log(const Request &req,
const Response &res) const {
if (logger_) {
@@ -9538,8 +9620,8 @@ inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider(
const FormDataProviderItems &provider_items) const {
size_t cur_item = 0;
size_t cur_start = 0;
// cur_item and cur_start are copied to within the std::function and maintain
// state between successive calls
// cur_item and cur_start are copied to within the std::function and
// maintain state between successive calls
return [&, cur_item, cur_start](size_t offset,
DataSink &sink) mutable -> bool {
if (!offset && !items.empty()) {
@@ -10251,8 +10333,8 @@ inline void ClientImpl::stop() {
// If there is anything ongoing right now, the ONLY thread-safe thing we can
// do is to shutdown_socket, so that threads using this socket suddenly
// discover they can't read/write any more and error out. Everything else
// (closing the socket, shutting ssl down) is unsafe because these actions are
// not thread-safe.
// (closing the socket, shutting ssl down) is unsafe because these actions
// are not thread-safe.
if (socket_requests_in_flight_ > 0) {
shutdown_socket(socket_);
@@ -10705,6 +10787,19 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path,
SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path,
client_ca_cert_dir_path);
// Set client CA list to be sent to clients during TLS handshake
if (client_ca_cert_file_path) {
auto ca_list = SSL_load_client_CA_file(client_ca_cert_file_path);
if (ca_list != nullptr) {
SSL_CTX_set_client_CA_list(ctx_, ca_list);
} else {
// Failed to load client CA list, but we continue since
// SSL_CTX_load_verify_locations already succeeded and
// certificate verification will still work
last_ssl_error_ = static_cast<int>(ERR_get_error());
}
}
SSL_CTX_set_verify(
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
}
@@ -10729,6 +10824,15 @@ inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key,
} else if (client_ca_cert_store) {
SSL_CTX_set_cert_store(ctx_, client_ca_cert_store);
// Extract CA names from the store and set them as the client CA list
auto ca_list = extract_ca_names_from_x509_store(client_ca_cert_store);
if (ca_list) {
SSL_CTX_set_client_CA_list(ctx_, ca_list);
} else {
// Failed to extract CA names, record the error
last_ssl_error_ = static_cast<int>(ERR_get_error());
}
SSL_CTX_set_verify(
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
}
@@ -10809,6 +10913,44 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) {
return ret;
}
inline STACK_OF(X509_NAME) * SSLServer::extract_ca_names_from_x509_store(
X509_STORE *store) {
if (!store) { return nullptr; }
auto ca_list = sk_X509_NAME_new_null();
if (!ca_list) { return nullptr; }
// Get all objects from the store
auto objs = X509_STORE_get0_objects(store);
if (!objs) {
sk_X509_NAME_free(ca_list);
return nullptr;
}
// Iterate through objects and extract certificate subject names
for (int i = 0; i < sk_X509_OBJECT_num(objs); i++) {
auto obj = sk_X509_OBJECT_value(objs, i);
if (X509_OBJECT_get_type(obj) == X509_LU_X509) {
auto cert = X509_OBJECT_get0_X509(obj);
if (cert) {
auto subject = X509_get_subject_name(cert);
if (subject) {
auto name_dup = X509_NAME_dup(subject);
if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); }
}
}
}
}
// If no names were extracted, free the list and return nullptr
if (sk_X509_NAME_num(ca_list) == 0) {
sk_X509_NAME_free(ca_list);
return nullptr;
}
return ca_list;
}
// SSL HTTP client implementation
inline SSLClient::SSLClient(const std::string &host)
: SSLClient(host, 443, std::string(), std::string()) {}
@@ -10889,7 +11031,8 @@ inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) {
if (ca_cert_store) {
if (ctx_) {
if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) {
// Free memory allocated for old cert and use new store `ca_cert_store`
// Free memory allocated for old cert and use new store
// `ca_cert_store`
SSL_CTX_set_cert_store(ctx_, ca_cert_store);
ca_cert_store_ = ca_cert_store;
}
@@ -10911,10 +11054,15 @@ inline long SSLClient::get_openssl_verify_result() const {
inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; }
inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) {
return is_valid() && ClientImpl::create_and_connect_socket(socket, error);
if (!is_valid()) {
error = Error::SSLConnection;
return false;
}
return ClientImpl::create_and_connect_socket(socket, error);
}
// Assumes that socket_mutex_ is locked and that there are no requests in flight
// Assumes that socket_mutex_ is locked and that there are no requests in
// flight
inline bool SSLClient::connect_with_proxy(
Socket &socket,
std::chrono::time_point<std::chrono::steady_clock> start_time,
@@ -11128,6 +11276,11 @@ inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) {
return true;
}
if (ctx_ == nullptr) {
error = Error::SSLConnection;
last_openssl_error_ = ERR_get_error();
}
shutdown_socket(socket);
close_socket(socket);
return false;
@@ -11221,21 +11374,22 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const {
for (decltype(count) i = 0; i < count && !dsn_matched; i++) {
auto val = sk_GENERAL_NAME_value(alt_names, i);
if (val->type == type) {
auto name =
reinterpret_cast<const char *>(ASN1_STRING_get0_data(val->d.ia5));
auto name_len = static_cast<size_t>(ASN1_STRING_length(val->d.ia5));
if (!val || val->type != type) { continue; }
switch (type) {
case GEN_DNS: dsn_matched = check_host_name(name, name_len); break;
auto name =
reinterpret_cast<const char *>(ASN1_STRING_get0_data(val->d.ia5));
if (name == nullptr) { continue; }
case GEN_IPADD:
if (!memcmp(&addr6, name, addr_len) ||
!memcmp(&addr, name, addr_len)) {
ip_matched = true;
}
break;
auto name_len = static_cast<size_t>(ASN1_STRING_length(val->d.ia5));
switch (type) {
case GEN_DNS: dsn_matched = check_host_name(name, name_len); break;
case GEN_IPADD:
if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) {
ip_matched = true;
}
break;
}
}
+9 -2
@@ -192,18 +192,25 @@ class chat_template {
};
};
const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
const auto contains_arg_needle = [&](const std::string & out_str) {
return contains(out_str, "<parameter=argument_needle>")
|| contains(out_str, "\"argument_needle\":")
|| contains(out_str, "'argument_needle':")
|| contains(out_str, ">argument_needle<")
|| contains(out_str, "<parameter name=\"argument_needle\">");
};
// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
}), {}, false);
auto tool_call_renders_str_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
auto tool_call_renders_str_arguments = contains_arg_needle(out);
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
}), {}, false);
auto tool_call_renders_obj_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
auto tool_call_renders_obj_arguments = contains_arg_needle(out);
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
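The capability probe above renders a dummy tool call and then searches the output for any of several spellings of the argument name. A minimal sketch of that substring check follows, using a hypothetical helper `mentions_argument` built on `std::string::find`; the pattern list mirrors the five literals in the hunk.

```cpp
#include <string>
#include <vector>

// Hypothetical helper: true if `rendered` contains the argument name in any
// of the five spellings the probe looks for.
inline bool mentions_argument(const std::string &rendered,
                              const std::string &arg_name) {
  const std::vector<std::string> patterns = {
      "<parameter=" + arg_name + ">",          // <parameter=name>
      "\"" + arg_name + "\":",                 // "name":
      "'" + arg_name + "':",                   // 'name':
      ">" + arg_name + "<",                    // >name<
      "<parameter name=\"" + arg_name + "\">", // <parameter name="name">
  };
  for (const auto &pattern : patterns) {
    if (rendered.find(pattern) != std::string::npos) { return true; }
  }
  return false;
}
```

Keeping the patterns in one place is what lets the string-arguments and object-arguments probes share a single check, as the `contains_arg_needle` lambda does.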
+5 -7
@@ -2205,7 +2205,7 @@ private:
auto value = parseValue();
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
while (it != end && consumeSpaces() && peekSymbols({ "[", ".", "(" })) {
if (!consumeToken("[").empty()) {
std::shared_ptr<Expression> index;
auto slice_loc = get_location();
@@ -2250,15 +2250,13 @@ private:
auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
}
} else if (peekSymbols({ "(" })) {
auto callParams = parseCallArgs();
value = std::make_shared<CallExpr>(get_location(), std::move(value), std::move(callParams));
}
consumeSpaces();
}
if (peekSymbols({ "(" })) {
auto location = get_location();
auto callParams = parseCallArgs();
value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
}
return value;
}
@@ -2738,7 +2736,7 @@ inline std::shared_ptr<Context> Context::builtins() {
globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
throw std::runtime_error(args.at("message").get<std::string>());
}));
globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
globals.set("tojson", simple_function("tojson", { "value", "indent", "ensure_ascii" }, [](const std::shared_ptr<Context> &, Value & args) {
return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* to_json= */ true));
}));
globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {