mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 17:47:40 +02:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2759ccdb4a | |||
| c5023daf60 | |||
| e7da30b584 | |||
| ed8aa63320 | |||
| 48bd26501b | |||
| 622cd010ff | |||
| 070ff4d535 | |||
| bf7b0c9725 | |||
| fcfce040e8 | |||
| ee3a5a10ad | |||
| 7e994168b1 | |||
| bcfa87622a | |||
| a2054e3a8f | |||
| dd52868050 | |||
| 6b9a52422b | |||
| 2f966b8ed8 | |||
| cd5e3b5754 | |||
| 87c9efc3b2 | |||
| 76af40aaaa |
@@ -4,49 +4,49 @@ on:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
ubuntu-24-riscv64-cpu-cross:
|
||||
runs-on: ubuntu-24.04
|
||||
# ubuntu-24-riscv64-cpu-cross:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Riscv
|
||||
run: |
|
||||
sudo dpkg --add-architecture riscv64
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - name: Setup Riscv
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture riscv64
|
||||
|
||||
# Add arch-specific repositories for non-amd64 architectures
|
||||
cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
|
||||
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
|
||||
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
|
||||
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
|
||||
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
|
||||
EOF
|
||||
# # Add arch-specific repositories for non-amd64 architectures
|
||||
# cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
|
||||
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
|
||||
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
|
||||
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
|
||||
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
|
||||
# EOF
|
||||
|
||||
sudo apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
# sudo apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
gcc-14-riscv64-linux-gnu \
|
||||
g++-14-riscv64-linux-gnu
|
||||
# sudo apt-get install -y --no-install-recommends \
|
||||
# build-essential \
|
||||
# gcc-14-riscv64-linux-gnu \
|
||||
# g++-14-riscv64-linux-gnu
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_CURL=OFF \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENMP=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||
-DLLAMA_BUILD_TOOLS=ON \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=riscv64 \
|
||||
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
|
||||
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
-DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
# - name: Build
|
||||
# run: |
|
||||
# cmake -B build -DLLAMA_CURL=OFF \
|
||||
# -DCMAKE_BUILD_TYPE=Release \
|
||||
# -DGGML_OPENMP=OFF \
|
||||
# -DLLAMA_BUILD_EXAMPLES=ON \
|
||||
# -DLLAMA_BUILD_TOOLS=ON \
|
||||
# -DLLAMA_BUILD_TESTS=OFF \
|
||||
# -DCMAKE_SYSTEM_NAME=Linux \
|
||||
# -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
|
||||
# -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
|
||||
# -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
|
||||
# -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
# -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
|
||||
# -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
# -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
# -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
# cmake --build build --config Release -j $(nproc)
|
||||
|
||||
# ubuntu-24-riscv64-vulkan-cross:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
@@ -2768,6 +2768,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.image.emplace_back(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MTMD}));
|
||||
add_opt(common_arg(
|
||||
{"--image-min-tokens"}, "N",
|
||||
"minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)",
|
||||
[](common_params & params, int value) {
|
||||
params.image_min_tokens = value;
|
||||
}
|
||||
).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MIN_TOKENS"));
|
||||
add_opt(common_arg(
|
||||
{"--image-max-tokens"}, "N",
|
||||
"maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)",
|
||||
[](common_params & params, int value) {
|
||||
params.image_max_tokens = value;
|
||||
}
|
||||
).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MAX_TOKENS"));
|
||||
if (llama_supports_rpc()) {
|
||||
add_opt(common_arg(
|
||||
{"--rpc"}, "SERVERS",
|
||||
|
||||
+17
-2
@@ -313,7 +313,6 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
|
||||
}
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = msg.reasoning_content;
|
||||
jmsg["thinking"] = msg.reasoning_content; // gpt-oss
|
||||
}
|
||||
if (!msg.tool_name.empty()) {
|
||||
jmsg["name"] = msg.tool_name;
|
||||
@@ -1810,7 +1809,23 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
||||
|
||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
auto prompt = apply(tmpl, inputs);
|
||||
|
||||
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
|
||||
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
|
||||
|
||||
if (has_reasoning_content && has_tool_calls) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["thinking"] = msg.at("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
|
||||
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
|
||||
// Check if we need to replace the return token with end token during
|
||||
// inference and without generation prompt. For more details see:
|
||||
|
||||
@@ -406,6 +406,8 @@ struct common_params {
|
||||
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
||||
bool no_mmproj = false; // explicitly disable multimodal model
|
||||
std::vector<std::string> image; // path to image file(s)
|
||||
int image_min_tokens = -1;
|
||||
int image_max_tokens = -1;
|
||||
|
||||
// finetune
|
||||
struct lr_opt lr;
|
||||
|
||||
@@ -9802,6 +9802,113 @@ class CogVLMModel(LlamaModel):
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("JanusForConditionalGeneration")
|
||||
class JanusProModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA # reuse Llama arch
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Skip vision, aligner, and generation tensors
|
||||
skip_prefixes = (
|
||||
'model.vision_model.',
|
||||
'model.aligner.',
|
||||
'model.vqmodel.',
|
||||
'model.generation_embeddings.',
|
||||
'model.generation_aligner.',
|
||||
'model.generation_head.',
|
||||
)
|
||||
if name.startswith(skip_prefixes):
|
||||
return []
|
||||
|
||||
if name.startswith('model.language_model.'):
|
||||
name = name.replace('model.language_model.', 'model.')
|
||||
elif name.startswith('language_model.'):
|
||||
name = name.replace('language_model.', '')
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("JanusForConditionalGeneration")
|
||||
class JanusProVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
if "intermediate_size" not in self.hparams_vision:
|
||||
mlp_ratio = self.hparams_vision.get("mlp_ratio")
|
||||
hidden_size = self.hparams_vision.get("hidden_size")
|
||||
if mlp_ratio is not None and hidden_size is not None:
|
||||
self.hparams_vision["intermediate_size"] = int(round(hidden_size * mlp_ratio))
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
assert self.hparams_vision is not None
|
||||
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JANUS_PRO)
|
||||
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
|
||||
|
||||
hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
|
||||
if hidden_act == "gelu":
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
elif hidden_act == "silu":
|
||||
self.gguf_writer.add_vision_use_silu(True)
|
||||
|
||||
def _map_aligner_tensor(self, data_torch: Tensor, name: str) -> Iterable[tuple[str, Tensor]]:
|
||||
"""Map aligner tensors to projector format"""
|
||||
suffix = ".bias" if name.endswith(".bias") else ".weight"
|
||||
|
||||
if name.startswith("model.aligner."):
|
||||
local_name = name[len("model.aligner."):]
|
||||
elif name.startswith("aligner."):
|
||||
local_name = name[len("aligner."):]
|
||||
else:
|
||||
raise ValueError(f"Unsupported Janus aligner prefix: {name}")
|
||||
|
||||
if local_name.startswith("fc1."):
|
||||
mm_index = 0
|
||||
elif local_name.startswith("hidden_layers."):
|
||||
parts = local_name.split(".", 2)
|
||||
if len(parts) < 3:
|
||||
raise ValueError(f"Unexpected Janus aligner tensor name: {name}")
|
||||
mm_index = int(parts[1]) + 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported Janus aligner tensor: {name}")
|
||||
|
||||
tensor_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, mm_index, suffix=suffix)
|
||||
return [(tensor_name, data_torch)]
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
# Skip language model tensors as they will be handled by `JanusProModel`
|
||||
if name.startswith(('model.language_model.', 'language_model.')):
|
||||
return []
|
||||
|
||||
# Skip generation-related components
|
||||
skip_generation_prefixes = (
|
||||
'model.vqmodel.',
|
||||
'vqmodel.',
|
||||
'model.generation_embeddings.',
|
||||
'generation_embeddings.',
|
||||
'model.generation_aligner.',
|
||||
'generation_aligner.',
|
||||
'model.generation_head.',
|
||||
'generation_head.',
|
||||
)
|
||||
if name.startswith(skip_generation_prefixes):
|
||||
return []
|
||||
|
||||
# Handle aligner tensors
|
||||
if name.startswith(('model.aligner.', 'aligner.')):
|
||||
return list(self._map_aligner_tensor(data_torch, name))
|
||||
|
||||
# Handle vision tensors
|
||||
if name.startswith(('model.vision_model.', 'vision_model.')):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
||||
@@ -138,6 +138,9 @@ if model_path is None:
|
||||
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
|
||||
)
|
||||
|
||||
|
||||
print("Loading model and tokenizer using AutoTokenizer:", model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
print("Model type: ", config.model_type)
|
||||
@@ -147,10 +150,6 @@ print("Number of layers: ", config.num_hidden_layers)
|
||||
print("BOS token id: ", config.bos_token_id)
|
||||
print("EOS token id: ", config.eos_token_id)
|
||||
|
||||
print("Loading model and tokenizer using AutoTokenizer:", model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
unreleased_module_path = (
|
||||
@@ -171,7 +170,7 @@ if unreleased_model_name:
|
||||
exit(1)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
|
||||
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
|
||||
)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
|
||||
@@ -700,7 +700,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
for (; ib + 1 < nb; ib += 2) {
|
||||
|
||||
// Compute combined scale for the block 0 and 1
|
||||
const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
|
||||
const float ft0 = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);
|
||||
const __m128 d_0_1 = (__m128)(v4f32){ft0, ft0, ft0, ft0};
|
||||
|
||||
const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
|
||||
|
||||
@@ -714,11 +715,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
bx_1 = __lsx_vsub_b(bx_1, off);
|
||||
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
|
||||
|
||||
//_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
|
||||
//_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
|
||||
|
||||
// Compute combined scale for the block 2 and 3
|
||||
const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );
|
||||
const float ft1 = GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d);
|
||||
const __m128 d_2_3 = (__m128)(v4f32){ft1, ft1, ft1, ft1};
|
||||
|
||||
const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
|
||||
|
||||
|
||||
@@ -500,13 +500,15 @@ inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(__loongarch_asx)
|
||||
#if defined(__loongarch_sx)
|
||||
/* float type data load instructions */
|
||||
static __m128 __lsx_vreplfr2vr_s(const float val) {
|
||||
v4f32 res = {val, val, val, val};
|
||||
return (__m128)res;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__loongarch_asx)
|
||||
static __m256 __lasx_xvreplfr2vr_s(const float val) {
|
||||
v8f32 res = {val, val, val, val, val, val, val, val};
|
||||
return (__m256)res;
|
||||
|
||||
@@ -956,7 +956,7 @@ do { \
|
||||
|
||||
#define GGML_F32Cx8 __m256
|
||||
#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
|
||||
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
|
||||
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
|
||||
|
||||
static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
|
||||
__m256i a;
|
||||
@@ -999,34 +999,34 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
|
||||
|
||||
#define GGML_F32x4 __m128
|
||||
#define GGML_F32x4_ZERO (__m128)__lsx_vldi(0)
|
||||
#define GGML_F32x4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
|
||||
#define GGML_F32x4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
|
||||
#define GGML_F32x4_LOAD(x) (__m128)__lsx_vld((x), 0)
|
||||
#define GGML_F32x4_STORE(x, y) __lsx_vst(y, x, 0)
|
||||
#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
|
||||
#define GGML_F32x4_ADD __lsx_vfadd_s
|
||||
#define GGML_F32x4_MUL __lsx_vfmul_s
|
||||
#define GGML_F32x4_REDUCE(res, x) \
|
||||
{ \
|
||||
int offset = GGML_F32_ARR >> 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
|
||||
} \
|
||||
offset >>= 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
|
||||
} \
|
||||
offset >>= 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
|
||||
} \
|
||||
__m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \
|
||||
tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \
|
||||
tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
|
||||
const __m128 t0 = (__m128)__lsx_vshuf4i_w(tmp, 0x88); \
|
||||
tmp = __lsx_vsrli_d((__m128i) t0, 32); \
|
||||
tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \
|
||||
tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
|
||||
res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
|
||||
|
||||
#define GGML_F32x4_REDUCE(res, x) \
|
||||
{ \
|
||||
int offset = GGML_F32_ARR >> 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
|
||||
} \
|
||||
offset >>= 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
|
||||
} \
|
||||
offset >>= 1; \
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
|
||||
} \
|
||||
__m128i t0 = __lsx_vpickev_w((__m128i)x[0], (__m128i)x[0]); \
|
||||
__m128i t1 = __lsx_vpickod_w((__m128i)x[0], (__m128i)x[0]); \
|
||||
__m128 t2 = __lsx_vfadd_s((__m128)t0, (__m128)t1); \
|
||||
__m128i t3 = __lsx_vpickev_w((__m128i)t2, (__m128i)t2); \
|
||||
__m128i t4 = __lsx_vpickod_w((__m128i)t2, (__m128i)t2); \
|
||||
__m128 t5 = __lsx_vfadd_s((__m128)t3, (__m128)t4); \
|
||||
res = (ggml_float) ((v4f32)t5)[0]; \
|
||||
}
|
||||
|
||||
#define GGML_F32_VEC GGML_F32x4
|
||||
@@ -1068,7 +1068,7 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
||||
|
||||
#define GGML_F32Cx4 __m128
|
||||
#define GGML_F32Cx4_ZERO (__m128)__lsx_vldi(0)
|
||||
#define GGML_F32Cx4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
|
||||
#define GGML_F32Cx4_SET1(x) (__m128)__lsx_vreplfr2vr_s((x))
|
||||
#define GGML_F32Cx4_LOAD(x) (__m128)__lsx_f16x4_load(x)
|
||||
#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
|
||||
#define GGML_F32Cx4_FMA GGML_F32x4_FMA
|
||||
|
||||
@@ -14,6 +14,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
|
||||
} break;
|
||||
case 72: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 72, 72>(ctx, dst);
|
||||
} break;
|
||||
case 80: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
// nbatch_K == number of K columns to load in parallel for KQ calculation
|
||||
|
||||
// TODO optimize kernel parameters for FP16 NVIDIA (P100)
|
||||
// TODO optimize kernel parameters for head sizes 40, 80, 96, 112
|
||||
// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112
|
||||
|
||||
// The ROCm compiler cannot handle templating in __launch_bounds__.
|
||||
// As a workaround, define a macro to package the kernel parameters as uint32_t:
|
||||
@@ -32,6 +32,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40)
|
||||
@@ -80,6 +86,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
||||
@@ -130,6 +142,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
||||
@@ -185,6 +204,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
||||
@@ -723,7 +749,7 @@ static __global__ void flash_attn_tile(
|
||||
|
||||
if (
|
||||
#ifdef GGML_USE_WMMA_FATTN
|
||||
(ncols2 != 1 && DV != 40 && DV != 512) ||
|
||||
(ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
|
||||
#endif // GGML_USE_WMMA_FATTN
|
||||
(use_logit_softcap && !(DV == 128 || DV == 256))
|
||||
) {
|
||||
@@ -1198,6 +1224,7 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
|
||||
extern DECL_FATTN_TILE_CASE( 40, 40);
|
||||
extern DECL_FATTN_TILE_CASE( 64, 64);
|
||||
extern DECL_FATTN_TILE_CASE( 72, 72);
|
||||
extern DECL_FATTN_TILE_CASE( 80, 80);
|
||||
extern DECL_FATTN_TILE_CASE( 96, 96);
|
||||
extern DECL_FATTN_TILE_CASE(112, 112);
|
||||
|
||||
@@ -223,6 +223,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
switch (K->ne[0]) {
|
||||
case 40:
|
||||
case 64:
|
||||
case 72:
|
||||
case 80:
|
||||
case 96:
|
||||
case 128:
|
||||
@@ -275,7 +276,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// If Turing tensor cores available, use them:
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
@@ -301,7 +302,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
}
|
||||
|
||||
// Use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
|
||||
if (can_use_vector_kernel && Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
||||
@@ -2115,6 +2115,14 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
|
||||
|
||||
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
|
||||
ggml_backend_buft_is_cuda_split(src1->buffer->buft);
|
||||
|
||||
//TODO: add support for fusion for split buffers
|
||||
if (split) {
|
||||
return false;
|
||||
}
|
||||
|
||||
//we only support fusion for ncols_dst = 1
|
||||
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
|
||||
return false;
|
||||
@@ -2154,6 +2162,15 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
|
||||
ggml_backend_buft_is_cuda_split(src1->buffer->buft);
|
||||
|
||||
//TODO: add support for fusion for split buffers
|
||||
if (split) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return use_mul_mat_vec_q;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(72, 72);
|
||||
@@ -3,7 +3,7 @@
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
|
||||
|
||||
@@ -81,6 +81,8 @@ for ncols in [8, 16, 32, 64]:
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
if head_size_kq == 40:
|
||||
continue
|
||||
if head_size_kq == 72:
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 == 16:
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 != 16:
|
||||
|
||||
@@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
if (op->src[0]->ne[0] != 32 &&
|
||||
op->src[0]->ne[0] != 40 &&
|
||||
op->src[0]->ne[0] != 64 &&
|
||||
op->src[0]->ne[0] != 72 &&
|
||||
op->src[0]->ne[0] != 80 &&
|
||||
op->src[0]->ne[0] != 96 &&
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
|
||||
@@ -5362,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, hal
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
|
||||
@@ -5374,6 +5375,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
|
||||
@@ -5387,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
|
||||
@@ -5400,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
|
||||
@@ -5412,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
|
||||
@@ -5424,6 +5429,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
|
||||
@@ -5436,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
|
||||
@@ -5448,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
|
||||
|
||||
@@ -8399,6 +8399,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
const int is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
@@ -8489,9 +8490,14 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
|
||||
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
|
||||
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
|
||||
// both mrope and vision kernels have sections
|
||||
if (is_mrope || is_vision) {
|
||||
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, §ions));
|
||||
}
|
||||
// only mrope has is_imrope
|
||||
if (is_mrope && !is_vision) {
|
||||
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope));
|
||||
}
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
@@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32(
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
int4 sections,
|
||||
int is_imrope
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
@@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32(
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
theta_base = pos[i2];
|
||||
}
|
||||
else if (sector >= sections.s0 && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne2 * 1];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 3];
|
||||
if (is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
|
||||
theta_base = (float) pos[i2 + ne02 * 1];
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
|
||||
theta_base = (float) pos[i2 + ne02 * 2];
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
|
||||
theta_base = (float) pos[i2 + ne02 * 0];
|
||||
} else { // e
|
||||
theta_base = (float) pos[i2 + ne02 * 3];
|
||||
}
|
||||
} else {
|
||||
if (sector < sections.s0) {
|
||||
theta_base = pos[i2];
|
||||
}
|
||||
else if (sector >= sections.s0 && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne2 * 1];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 3];
|
||||
}
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
@@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16(
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
int4 sections,
|
||||
int is_imrope
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
@@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16(
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
theta_base = pos[i2];
|
||||
}
|
||||
else if (sector >= sections.s0 && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne2 * 1];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 3];
|
||||
if (is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
|
||||
theta_base = (float) pos[i2 + ne02 * 1];
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
|
||||
theta_base = (float) pos[i2 + ne02 * 2];
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
|
||||
theta_base = (float) pos[i2 + ne02 * 0];
|
||||
} else { // e
|
||||
theta_base = (float) pos[i2 + ne02 * 3];
|
||||
}
|
||||
} else {
|
||||
if (sector < sections.s0) {
|
||||
theta_base = pos[i2];
|
||||
}
|
||||
else if (sector >= sections.s0 && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne2 * 1];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 3];
|
||||
}
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
@@ -2,26 +2,43 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
#define GGML_ASSERT_TENSOR_FITS_INT(t) \
|
||||
GGML_ASSERT((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX)
|
||||
|
||||
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float * src0_dd = (const float *) dst->src[0]->data;
|
||||
float * dst_dd = (float *) dst->data;
|
||||
|
||||
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
|
||||
const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
|
||||
ne03 = dst->src[0]->ne[3];
|
||||
GGML_ASSERT_TENSOR_FITS_INT(dst);
|
||||
GGML_ASSERT_TENSOR_FITS_INT(dst->src[0]);
|
||||
|
||||
const int nr0 = (int) (ne00 / ne0);
|
||||
const int nr1 = (int) (ne01 / ne1);
|
||||
const int nr2 = (int) (ne02 / ne2);
|
||||
const int nr3 = (int) (ne03 / ne3);
|
||||
const int ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
|
||||
const int ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
|
||||
ne03 = dst->src[0]->ne[3];
|
||||
|
||||
const size_t total = ne0 * ne1 * ne2 * ne3;
|
||||
const int BLOCK_SIZE = 256;
|
||||
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int nr0 = ne00 / ne0;
|
||||
const int nr1 = ne01 / ne1;
|
||||
const int nr2 = ne02 / ne2;
|
||||
const int nr3 = ne03 / ne3;
|
||||
|
||||
const int nb0 = dst->src[0]->nb[0];
|
||||
const int nb1 = dst->src[0]->nb[1];
|
||||
const int nb2 = dst->src[0]->nb[2];
|
||||
const int nb3 = dst->src[0]->nb[3];
|
||||
|
||||
const char * base = (const char *) src0_dd;
|
||||
|
||||
const size_t total = (size_t) ne0 * ne1 * ne2 * ne3;
|
||||
constexpr int BLOCK_SIZE = 256;
|
||||
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
const float inv_ne0 = 1.0f / ne0;
|
||||
const float inv_ne_01 = 1.0f / (ne0 * ne1);
|
||||
const float inv_ne_012 = 1.0f / (ne0 * ne1 * ne2);
|
||||
const int repeat_count = nr0 * nr1 * nr2 * nr3;
|
||||
|
||||
queue_ptr stream = ctx.stream();
|
||||
|
||||
@@ -33,24 +50,27 @@ void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
return;
|
||||
}
|
||||
|
||||
const int i0 = i % ne0;
|
||||
const int i1 = (i / ne0) % ne1;
|
||||
const int i2 = (i / (ne0 * ne1)) % ne2;
|
||||
const int i3 = i / (ne0 * ne1 * ne2);
|
||||
const int i3 = (int) (i * inv_ne_012);
|
||||
const int i2 = (int) (i * inv_ne_01) - i3 * ne2;
|
||||
const int i1 = (int) (i * inv_ne0) - (int) (i * inv_ne_01) * ne1;
|
||||
const int i0 = i - (int) (i * inv_ne0) * ne0;
|
||||
|
||||
int j0 = 0, j1 = 0, j2 = 0, j3 = 0;
|
||||
float acc = 0.0f;
|
||||
|
||||
for (int j3 = 0; j3 < nr3; ++j3) {
|
||||
for (int j2 = 0; j2 < nr2; ++j2) {
|
||||
for (int j1 = 0; j1 < nr1; ++j1) {
|
||||
for (int j0 = 0; j0 < nr0; ++j0) {
|
||||
acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
|
||||
(i3 + j3 * ne3) * ne00 * ne01 * ne02];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < repeat_count; ++j) {
|
||||
const float * ptr = (const float *) (base + (i0 + j0 * ne0) * nb0 + (i1 + j1 * ne1) * nb1 +
|
||||
(i2 + j2 * ne2) * nb2 + (i3 + j3 * ne3) * nb3);
|
||||
acc += *ptr;
|
||||
|
||||
int carry = (++j0 >= nr0);
|
||||
j0 -= carry * nr0;
|
||||
carry = (carry && (++j1 >= nr1));
|
||||
j1 -= carry * nr1;
|
||||
carry = (carry && (++j2 >= nr2));
|
||||
j2 -= carry * nr2;
|
||||
j3 += carry;
|
||||
}
|
||||
dst_dd[i] = acc;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3186,6 +3186,7 @@ class VisionProjectorType:
|
||||
KIMIVL = "kimivl"
|
||||
LIGHTONOCR = "lightonocr"
|
||||
COGVLM = "cogvlm"
|
||||
JANUS_PRO = "janus_pro"
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
||||
@@ -1183,6 +1183,7 @@ class TensorNameMap:
|
||||
"model.mm_projector.mlp.mlp.{bid}",
|
||||
"vision_model.vision_adapter.mlp.fc{bid}", # llama 4
|
||||
"mlp1.{bid}", # InternVL
|
||||
"model.aligner.fc1.hidden_layers.{bid}", # Janus Pro
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_PEG: (
|
||||
@@ -1291,6 +1292,7 @@ class TensorNameMap:
|
||||
"model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1
|
||||
"vpm.encoder.layers.{bid}.self_attn.out_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.projection_layer", # Janus Pro
|
||||
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
|
||||
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
|
||||
|
||||
+4
-3
@@ -461,7 +461,10 @@ extern "C" {
|
||||
LLAMA_API bool llama_supports_gpu_offload(void);
|
||||
LLAMA_API bool llama_supports_rpc (void);
|
||||
|
||||
// NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
|
||||
// In some cases the requested values via llama_context_params may differ from the actual values used by the context
|
||||
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||
@@ -585,7 +588,7 @@ extern "C" {
|
||||
LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
|
||||
|
||||
// Manually free a LoRA adapter
|
||||
// Note: loaded adapters will be free when the associated model is deleted
|
||||
// NOTE: loaded adapters will be free when the associated model is deleted
|
||||
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
|
||||
|
||||
// Get the invocation tokens if the current lora is an alora
|
||||
@@ -1111,8 +1114,6 @@ extern "C" {
|
||||
// // sample from the logits of the last token in the batch
|
||||
// const llama_token id = llama_sampler_sample(smpl, ctx, -1);
|
||||
//
|
||||
// // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
|
||||
// llama_sampler_accept(smpl, id);
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
|
||||
+27
-10
@@ -112,11 +112,24 @@ llama_context::llama_context(
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
if (cparams.kv_unified) {
|
||||
cparams.n_ctx_seq = cparams.n_ctx;
|
||||
} else {
|
||||
cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
|
||||
if (cparams.n_ctx_seq == 0) {
|
||||
throw std::runtime_error("n_ctx_seq == 0");
|
||||
}
|
||||
|
||||
if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
|
||||
cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
|
||||
LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
|
||||
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
|
||||
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
|
||||
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
||||
@@ -125,14 +138,14 @@ llama_context::llama_context(
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
if (n_ctx_per_seq < hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
||||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
||||
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
if (n_ctx_per_seq > hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
||||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
||||
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
if (!hparams.vocab_only) {
|
||||
@@ -453,8 +466,8 @@ uint32_t llama_context::n_ctx() const {
|
||||
return cparams.n_ctx;
|
||||
}
|
||||
|
||||
uint32_t llama_context::n_ctx_per_seq() const {
|
||||
return cparams.n_ctx / cparams.n_seq_max;
|
||||
uint32_t llama_context::n_ctx_seq() const {
|
||||
return cparams.n_ctx_seq;
|
||||
}
|
||||
|
||||
uint32_t llama_context::n_batch() const {
|
||||
@@ -2383,6 +2396,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
|
||||
return ctx->n_ctx();
|
||||
}
|
||||
|
||||
uint32_t llama_n_ctx_seq(const llama_context * ctx) {
|
||||
return ctx->n_ctx_seq();
|
||||
}
|
||||
|
||||
uint32_t llama_n_batch(const llama_context * ctx) {
|
||||
return ctx->n_batch();
|
||||
}
|
||||
|
||||
+5
-5
@@ -43,11 +43,11 @@ struct llama_context {
|
||||
|
||||
ggml_backend_sched_t get_sched() const;
|
||||
|
||||
uint32_t n_ctx() const;
|
||||
uint32_t n_ctx_per_seq() const;
|
||||
uint32_t n_batch() const;
|
||||
uint32_t n_ubatch() const;
|
||||
uint32_t n_seq_max() const;
|
||||
uint32_t n_ctx() const;
|
||||
uint32_t n_ctx_seq() const;
|
||||
uint32_t n_batch() const;
|
||||
uint32_t n_ubatch() const;
|
||||
uint32_t n_seq_max() const;
|
||||
|
||||
uint32_t n_threads() const;
|
||||
uint32_t n_threads_batch() const;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
struct llama_cparams {
|
||||
uint32_t n_ctx; // context size used during inference
|
||||
uint32_t n_ctx_seq; // context for a single sequence
|
||||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
|
||||
+4
-10
@@ -6712,14 +6712,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
|
||||
}
|
||||
|
||||
ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
|
||||
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
const uint32_t n_ctx_seq = cparams.n_ctx_seq;
|
||||
|
||||
// choose long/short freq factors based on the context size
|
||||
if (layers[il].rope_freqs != nullptr) {
|
||||
return layers[il].rope_freqs;
|
||||
}
|
||||
|
||||
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
|
||||
if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
|
||||
return layers[il].rope_long;
|
||||
}
|
||||
|
||||
@@ -6795,12 +6795,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
/* filter_attn */ std::move(filter_attn),
|
||||
/* filter_recr */ std::move(filter_recr));
|
||||
} else {
|
||||
uint32_t n_ctx_per_stream = cparams.n_ctx;
|
||||
|
||||
if (!cparams.kv_unified) {
|
||||
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
|
||||
}
|
||||
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA3N) {
|
||||
@@ -6824,7 +6818,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.kv_unified,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
1,
|
||||
@@ -6840,7 +6834,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
cparams.kv_unified,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
1,
|
||||
hparams.n_swa,
|
||||
|
||||
@@ -1454,6 +1454,8 @@ struct test_case {
|
||||
ggml_context_ptr ctx(ggml_init(params)); // smart ptr
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);
|
||||
|
||||
ggml_tensor * out = build_graph(ctx.get());
|
||||
current_op_name = op_desc(out);
|
||||
|
||||
@@ -7225,8 +7227,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
|
||||
}
|
||||
|
||||
for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
|
||||
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
|
||||
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
|
||||
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
|
||||
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
|
||||
@@ -7569,6 +7571,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
if (mode == MODE_SUPPORT) {
|
||||
auto test_cases = make_test_cases_eval();
|
||||
filter_test_cases(test_cases, params_filter);
|
||||
|
||||
// Filter out fusion cases
|
||||
test_cases.erase(
|
||||
std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr<test_case> & tc) {
|
||||
return tc->run_whole_graph();
|
||||
}),
|
||||
test_cases.end()
|
||||
);
|
||||
|
||||
for (auto & test : test_cases) {
|
||||
test->eval_support(backend, op_names_filter, output_printer);
|
||||
}
|
||||
@@ -7619,6 +7630,14 @@ static void show_test_coverage() {
|
||||
all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));
|
||||
}
|
||||
auto test_cases = make_test_cases_eval();
|
||||
// Filter out fusion cases
|
||||
test_cases.erase(
|
||||
std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr<test_case> & tc) {
|
||||
return tc->run_whole_graph();
|
||||
}),
|
||||
test_cases.end()
|
||||
);
|
||||
|
||||
std::set<std::string> tested_ops;
|
||||
|
||||
ggml_init_params params = {
|
||||
|
||||
@@ -131,7 +131,14 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
batch = llama_batch_get_one(&token, 1);
|
||||
if (llama_decode(ctx.get(), batch)) {
|
||||
|
||||
int ret = llama_decode(ctx.get(), batch);
|
||||
if (ret == 1 && i > 0) {
|
||||
LOG_INF("Context full, stopping generation.\n");
|
||||
break;
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
|
||||
failed.store(true);
|
||||
return;
|
||||
|
||||
@@ -155,6 +155,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_KIMIVL,
|
||||
PROJECTOR_TYPE_LIGHTONOCR,
|
||||
PROJECTOR_TYPE_COGVLM,
|
||||
PROJECTOR_TYPE_JANUS_PRO,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -180,6 +181,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
|
||||
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
|
||||
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
|
||||
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
||||
+238
-47
@@ -6,7 +6,6 @@
|
||||
#include "clip-impl.h"
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "gguf.h"
|
||||
@@ -17,17 +16,15 @@
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <cinttypes>
|
||||
#include <limits>
|
||||
#include <array>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
|
||||
// TODO: allow to pass callback from user code
|
||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||
|
||||
enum ffn_op_type {
|
||||
@@ -172,8 +169,8 @@ struct clip_hparams {
|
||||
int32_t n_layer;
|
||||
// idefics3
|
||||
int32_t image_longest_edge = 0;
|
||||
int32_t image_min_pixels = 0;
|
||||
int32_t image_max_pixels = 0;
|
||||
int32_t image_min_pixels = -1;
|
||||
int32_t image_max_pixels = -1;
|
||||
int32_t n_merge = 0; // number of patch merges **per-side**
|
||||
|
||||
float image_mean[3];
|
||||
@@ -206,11 +203,15 @@ struct clip_hparams {
|
||||
int minicpmv_version = 0;
|
||||
int32_t minicpmv_query_num = 0; // MiniCPM-V query number
|
||||
|
||||
// custom value provided by user, can be undefined if not set
|
||||
int32_t custom_image_min_tokens = -1;
|
||||
int32_t custom_image_max_tokens = -1;
|
||||
|
||||
void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
|
||||
const int cur_merge = n_merge == 0 ? 1 : n_merge;
|
||||
const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
|
||||
image_min_pixels = n_tokens_min * patch_area;
|
||||
image_max_pixels = n_tokens_max * patch_area;
|
||||
image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area;
|
||||
image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area;
|
||||
warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
|
||||
}
|
||||
|
||||
@@ -219,6 +220,7 @@ struct clip_hparams {
|
||||
GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
|
||||
const int cur_merge = n_merge == 0 ? 1 : n_merge;
|
||||
warmup_image_size = n_tok_per_side * patch_size * cur_merge;
|
||||
// TODO: support warmup size for custom token numbers
|
||||
}
|
||||
};
|
||||
|
||||
@@ -426,12 +428,14 @@ struct clip_ctx {
|
||||
|
||||
int max_nodes = 8192;
|
||||
ggml_backend_sched_ptr sched;
|
||||
clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
|
||||
|
||||
// for debugging
|
||||
bool debug_graph = false;
|
||||
std::vector<ggml_tensor *> debug_print_tensors;
|
||||
|
||||
clip_ctx(clip_context_params & ctx_params) {
|
||||
flash_attn_type = ctx_params.flash_attn_type;
|
||||
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
|
||||
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
if (!backend_cpu) {
|
||||
@@ -460,6 +464,13 @@ struct clip_ctx {
|
||||
LOG_INF("%s: CLIP using CPU backend\n", __func__);
|
||||
}
|
||||
|
||||
if (ctx_params.image_min_tokens > 0) {
|
||||
model.hparams.custom_image_min_tokens = ctx_params.image_min_tokens;
|
||||
}
|
||||
if (ctx_params.image_max_tokens > 0) {
|
||||
model.hparams.custom_image_max_tokens = ctx_params.image_max_tokens;
|
||||
}
|
||||
|
||||
backend_ptrs.push_back(backend_cpu);
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
|
||||
|
||||
@@ -589,6 +600,15 @@ struct clip_graph {
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.mm_2_b);
|
||||
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_JANUS_PRO) {
|
||||
cur = build_ffn(cur,
|
||||
model.mm_0_w, model.mm_0_b,
|
||||
nullptr, nullptr,
|
||||
model.mm_1_w, model.mm_1_b,
|
||||
hparams.ffn_op,
|
||||
-1);
|
||||
|
||||
} else {
|
||||
GGML_ABORT("SigLIP: Unsupported projector type");
|
||||
}
|
||||
@@ -753,6 +773,15 @@ struct clip_graph {
|
||||
ggml_set_name(window_mask, "window_mask");
|
||||
ggml_set_input(window_mask);
|
||||
|
||||
// if flash attn is used, we need to pad the mask and cast to f16
|
||||
if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
|
||||
int n_pad = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD) - window_mask->ne[1];
|
||||
if (n_pad > 0) {
|
||||
window_mask = ggml_pad(ctx0, window_mask, 0, n_pad, 0, 0);
|
||||
}
|
||||
window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
|
||||
}
|
||||
|
||||
// inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
|
||||
@@ -1730,7 +1759,6 @@ struct clip_graph {
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
// whisper encoder with custom projector
|
||||
ggml_cgraph * build_whisper_enc() {
|
||||
const int n_frames = img.nx;
|
||||
@@ -2260,17 +2288,25 @@ private:
|
||||
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
|
||||
//cb(k, "k", il);
|
||||
|
||||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
|
||||
v = ggml_cont(ctx0, v);
|
||||
//cb(k, "v", il);
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
// TODO @ngxson : support flash attention
|
||||
{
|
||||
if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
|
||||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
|
||||
|
||||
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
|
||||
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
|
||||
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
|
||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
||||
|
||||
} else {
|
||||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
|
||||
v = ggml_cont(ctx0, v);
|
||||
|
||||
const auto n_tokens = q->ne[1];
|
||||
const auto n_head = q->ne[2];
|
||||
// const auto n_kv = k->ne[1]; // for flash attention
|
||||
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
// F32 may not needed for vision encoders?
|
||||
@@ -2450,6 +2486,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
res = graph.build_kimivl();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
{
|
||||
res = graph.build_siglip();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
{
|
||||
res = graph.build_cogvlm();
|
||||
@@ -2758,6 +2798,12 @@ struct clip_model_loader {
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/16842#issuecomment-3475144858
|
||||
hparams.set_limit_image_tokens(8, 2048);
|
||||
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
|
||||
const int warn_min_pixels = 1024 * hparams.n_merge * hparams.n_merge * hparams.patch_size * hparams.patch_size;
|
||||
if (hparams.image_min_pixels < warn_min_pixels) {
|
||||
LOG_WRN("%s: Qwen-VL models require at minimum 1024 image tokens to function correctly on grounding tasks\n", __func__);
|
||||
LOG_WRN("%s: if you encounter problems with accuracy, try adding --image-min-tokens 1024\n", __func__);
|
||||
LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
{
|
||||
@@ -2782,6 +2828,13 @@ struct clip_model_loader {
|
||||
break;
|
||||
}
|
||||
|
||||
// sanity check
|
||||
{
|
||||
if (hparams.image_max_pixels < hparams.image_min_pixels) {
|
||||
throw std::runtime_error(string_format("%s: image_max_pixels (%d) is less than image_min_pixels (%d)\n", __func__, hparams.image_max_pixels, hparams.image_min_pixels));
|
||||
}
|
||||
}
|
||||
|
||||
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
|
||||
LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd);
|
||||
LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head);
|
||||
@@ -2798,10 +2851,10 @@ struct clip_model_loader {
|
||||
LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge);
|
||||
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
|
||||
if (hparams.image_min_pixels > 0) {
|
||||
LOG_INF("%s: image_min_pixels: %d\n", __func__, hparams.image_min_pixels);
|
||||
LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : "");
|
||||
}
|
||||
if (hparams.image_max_pixels > 0) {
|
||||
LOG_INF("%s: image_max_pixels: %d\n", __func__, hparams.image_max_pixels);
|
||||
LOG_INF("%s: image_max_pixels: %d%s\n", __func__, hparams.image_max_pixels, hparams.custom_image_max_tokens > 0 ? " (custom value)" : "");
|
||||
}
|
||||
} else if (is_audio) {
|
||||
LOG_INF("\n--- audio hparams ---\n");
|
||||
@@ -3151,6 +3204,13 @@ struct clip_model_loader {
|
||||
model.mm_boi = get_tensor(TN_TOK_BOI);
|
||||
model.mm_eoi = get_tensor(TN_TOK_EOI);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
{
|
||||
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
|
||||
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
|
||||
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown projector type");
|
||||
}
|
||||
@@ -3192,7 +3252,87 @@ struct clip_model_loader {
|
||||
}
|
||||
}
|
||||
|
||||
void alloc_compute_meta(clip_ctx & ctx_clip) {
|
||||
struct support_info_op {
|
||||
ggml_tensor * op;
|
||||
|
||||
// true if the op runs on the accelerated ctx_clip.backend
|
||||
bool is_accel = true;
|
||||
};
|
||||
|
||||
struct support_info_graph {
|
||||
// whether the clip_ctx.backend supports flash attention
|
||||
bool fattn = true;
|
||||
ggml_tensor * fattn_op = nullptr; // for debugging
|
||||
|
||||
std::vector<support_info_op> ops;
|
||||
};
|
||||
|
||||
static void warmup(clip_ctx & ctx_clip) {
|
||||
support_info_graph info;
|
||||
|
||||
if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
|
||||
// try to enable flash attention to see if it's supported
|
||||
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
|
||||
info = alloc_compute_meta(ctx_clip);
|
||||
if (!info.fattn && info.fattn_op) {
|
||||
auto op = info.fattn_op;
|
||||
LOG_WRN("%s: *****************************************************************\n", __func__);
|
||||
LOG_WRN("%s: WARNING: flash attention not supported by %s, memory usage will increase\n", __func__, ggml_backend_name(ctx_clip.backend));
|
||||
LOG_WRN("%s: op params: \n", __func__);
|
||||
static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) {
|
||||
LOG_WRN("%s: %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn,
|
||||
name, ggml_type_name(t->type),
|
||||
t->ne[0], t->ne[1], t->ne[2], t->ne[3],
|
||||
t->nb[0], t->nb[1], t->nb[2], t->nb[3]);
|
||||
};
|
||||
print_shape(__func__, " dst", op);
|
||||
print_shape(__func__, "src0", op->src[0]);
|
||||
print_shape(__func__, "src1", op->src[1]);
|
||||
print_shape(__func__, "src2", op->src[2]);
|
||||
LOG_WRN("%s: please report this on github as an issue\n", __func__);
|
||||
LOG_WRN("%s: *****************************************************************\n", __func__);
|
||||
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
|
||||
alloc_compute_meta(ctx_clip);
|
||||
}
|
||||
} else {
|
||||
info = alloc_compute_meta(ctx_clip);
|
||||
if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
|
||||
LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
LOG_INF("%s: flash attention is %s\n", __func__,
|
||||
(ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
|
||||
|
||||
// print ops that are not supported by the GPU backend (if there is one)
|
||||
if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
|
||||
std::vector<support_info_op> unsupported_ops;
|
||||
for (const auto & op : info.ops) {
|
||||
if (!op.is_accel) {
|
||||
unsupported_ops.push_back(op);
|
||||
}
|
||||
}
|
||||
if (!unsupported_ops.empty()) {
|
||||
LOG_WRN("%s: *****************************************************************\n", __func__);
|
||||
LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
|
||||
LOG_WRN("%s: the performance will be suboptimal \n", __func__);
|
||||
LOG_WRN("%s: list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend));
|
||||
for (const auto & op : unsupported_ops) {
|
||||
LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__,
|
||||
ggml_op_name(op.op->op),
|
||||
ggml_type_name(op.op->type),
|
||||
op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
|
||||
}
|
||||
LOG_WRN("%s: flash attention is %s\n", __func__,
|
||||
(ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
|
||||
LOG_WRN("%s: please report this on github as an issue\n", __func__);
|
||||
LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
|
||||
LOG_WRN("%s: *****************************************************************\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
|
||||
const auto & hparams = ctx_clip.model.hparams;
|
||||
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
|
||||
|
||||
@@ -3223,57 +3363,95 @@ struct clip_model_loader {
|
||||
size / 1024.0 / 1024.0);
|
||||
}
|
||||
}
|
||||
|
||||
const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
|
||||
const int n_nodes = ggml_graph_n_nodes(gf);
|
||||
|
||||
LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
|
||||
|
||||
support_info_graph res {
|
||||
/*.fattn = */ true,
|
||||
/*.fattn_op = */ nullptr,
|
||||
/*.ops = */ {},
|
||||
};
|
||||
|
||||
// check op support
|
||||
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
||||
ggml_tensor * node = ggml_graph_node(gf, i);
|
||||
res.ops.push_back({node, true});
|
||||
if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
|
||||
res.ops.back().is_accel = false;
|
||||
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
res.fattn = false;
|
||||
res.fattn_op = node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void get_bool(const std::string & key, bool & output, bool required = true) {
|
||||
void get_bool(const std::string & key, bool & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
if (required) throw std::runtime_error("Key not found: " + key);
|
||||
if (required) {
|
||||
throw std::runtime_error("Key not found: " + key);
|
||||
}
|
||||
return;
|
||||
}
|
||||
output = gguf_get_val_bool(ctx_gguf.get(), i);
|
||||
}
|
||||
|
||||
void get_i32(const std::string & key, int & output, bool required = true) {
|
||||
void get_i32(const std::string & key, int & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
if (required) throw std::runtime_error("Key not found: " + key);
|
||||
if (required) {
|
||||
throw std::runtime_error("Key not found: " + key);
|
||||
}
|
||||
return;
|
||||
}
|
||||
output = gguf_get_val_i32(ctx_gguf.get(), i);
|
||||
}
|
||||
|
||||
void get_u32(const std::string & key, int & output, bool required = true) {
|
||||
void get_u32(const std::string & key, int & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
if (required) throw std::runtime_error("Key not found: " + key);
|
||||
if (required) {
|
||||
throw std::runtime_error("Key not found: " + key);
|
||||
}
|
||||
return;
|
||||
}
|
||||
output = gguf_get_val_u32(ctx_gguf.get(), i);
|
||||
}
|
||||
|
||||
void get_f32(const std::string & key, float & output, bool required = true) {
|
||||
void get_f32(const std::string & key, float & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
if (required) throw std::runtime_error("Key not found: " + key);
|
||||
if (required) {
|
||||
throw std::runtime_error("Key not found: " + key);
|
||||
}
|
||||
return;
|
||||
}
|
||||
output = gguf_get_val_f32(ctx_gguf.get(), i);
|
||||
}
|
||||
|
||||
void get_string(const std::string & key, std::string & output, bool required = true) {
|
||||
void get_string(const std::string & key, std::string & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
if (required) throw std::runtime_error("Key not found: " + key);
|
||||
if (required) {
|
||||
throw std::runtime_error("Key not found: " + key);
|
||||
}
|
||||
return;
|
||||
}
|
||||
output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
|
||||
}
|
||||
|
||||
void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
|
||||
void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
if (required) throw std::runtime_error("Key not found: " + key);
|
||||
if (required) {
|
||||
throw std::runtime_error("Key not found: " + key);
|
||||
}
|
||||
return;
|
||||
}
|
||||
int n = gguf_get_arr_n(ctx_gguf.get(), i);
|
||||
@@ -3284,7 +3462,7 @@ struct clip_model_loader {
|
||||
}
|
||||
}
|
||||
|
||||
void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
|
||||
static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
|
||||
auto & hparams = model.hparams;
|
||||
for (int x = 1; x <= max_patches_per_side; x++) {
|
||||
for (int y = 1; y <= max_patches_per_side; y++) {
|
||||
@@ -3312,24 +3490,22 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
|
||||
ctx_vision = new clip_ctx(ctx_params);
|
||||
loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
|
||||
loader.load_tensors(*ctx_vision);
|
||||
loader.alloc_compute_meta(*ctx_vision);
|
||||
loader.warmup(*ctx_vision);
|
||||
}
|
||||
|
||||
if (loader.has_audio) {
|
||||
ctx_audio = new clip_ctx(ctx_params);
|
||||
loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
|
||||
loader.load_tensors(*ctx_audio);
|
||||
loader.alloc_compute_meta(*ctx_audio);
|
||||
loader.warmup(*ctx_audio);
|
||||
}
|
||||
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
|
||||
if (ctx_vision) {
|
||||
delete ctx_vision;
|
||||
}
|
||||
if (ctx_audio) {
|
||||
delete ctx_audio;
|
||||
}
|
||||
|
||||
delete ctx_vision;
|
||||
delete ctx_audio;
|
||||
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
@@ -3367,10 +3543,10 @@ void clip_image_size_free(struct clip_image_size * load_image_size) {
|
||||
}
|
||||
delete load_image_size;
|
||||
}
|
||||
void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
|
||||
void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
|
||||
void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
|
||||
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
|
||||
void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
|
||||
void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
|
||||
void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
|
||||
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }
|
||||
|
||||
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
|
||||
return batch->entries.size();
|
||||
@@ -4018,7 +4194,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
{
|
||||
// step 1: make a blank canvas which aligns to the grid
|
||||
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
|
||||
clip_image_u8 resized;
|
||||
const clip_image_size new_size = img_tool::calc_size_preserved_ratio(
|
||||
original_size,
|
||||
@@ -4096,10 +4272,22 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
res_imgs->entries.push_back(std::move(img_f32));
|
||||
} break;
|
||||
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
{
|
||||
// Janus Pro preprocessing: pad to square with gray(127), resize to 384x384
|
||||
const std::array<uint8_t, 3> pad_color = {127, 127, 127};
|
||||
clip_image_u8 resized_image;
|
||||
int sz = params.image_size;
|
||||
img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);
|
||||
clip_image_f32_ptr img_f32(clip_image_f32_init());
|
||||
normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
|
||||
res_imgs->entries.push_back(std::move(img_f32));
|
||||
} break;
|
||||
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
GGML_ASSERT(params.image_min_pixels && params.image_max_pixels);
|
||||
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
|
||||
clip_image_u8 resized_image;
|
||||
// the original pixtral model doesn't have n_merge
|
||||
const int cur_merge = params.n_merge == 0 ? 1 : params.n_merge;
|
||||
@@ -4133,7 +4321,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
{
|
||||
GGML_ASSERT(params.image_min_pixels && params.image_max_pixels);
|
||||
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
|
||||
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
|
||||
original_size,
|
||||
params.patch_size * params.n_merge,
|
||||
@@ -4272,6 +4460,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
switch (proj) {
|
||||
case PROJECTOR_TYPE_MLP:
|
||||
case PROJECTOR_TYPE_MLP_NORM:
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
{
|
||||
// do nothing
|
||||
} break;
|
||||
@@ -4782,6 +4971,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
{
|
||||
// do nothing
|
||||
@@ -4870,6 +5060,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->model.mm_model_mlp_3_w->ne[1];
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
return ctx->model.mm_1_b->ne[0];
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
// main path + deepstack paths
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
@@ -22,9 +23,18 @@ enum clip_modality {
|
||||
CLIP_MODALITY_AUDIO,
|
||||
};
|
||||
|
||||
enum clip_flash_attn_type {
|
||||
CLIP_FLASH_ATTN_TYPE_AUTO = -1,
|
||||
CLIP_FLASH_ATTN_TYPE_DISABLED = 0,
|
||||
CLIP_FLASH_ATTN_TYPE_ENABLED = 1,
|
||||
};
|
||||
|
||||
struct clip_context_params {
|
||||
bool use_gpu;
|
||||
enum ggml_log_level verbosity;
|
||||
enum clip_flash_attn_type flash_attn_type;
|
||||
int image_min_tokens;
|
||||
int image_max_tokens;
|
||||
};
|
||||
|
||||
struct clip_init_result {
|
||||
|
||||
@@ -132,10 +132,13 @@ struct mtmd_cli_context {
|
||||
void init_vision_context(common_params & params) {
|
||||
const char * clip_path = params.mmproj.path.c_str();
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
mparams.use_gpu = params.mmproj_use_gpu;
|
||||
mparams.print_timings = true;
|
||||
mparams.n_threads = params.cpuparams.n_threads;
|
||||
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.use_gpu = params.mmproj_use_gpu;
|
||||
mparams.print_timings = true;
|
||||
mparams.n_threads = params.cpuparams.n_threads;
|
||||
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params.flash_attn_type;
|
||||
mparams.image_min_tokens = params.image_min_tokens;
|
||||
mparams.image_max_tokens = params.image_max_tokens;
|
||||
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
|
||||
if (!ctx_vision.get()) {
|
||||
LOG_ERR("Failed to load vision model from %s\n", clip_path);
|
||||
|
||||
+20
-6
@@ -19,7 +19,6 @@
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
// represents raw image data, layout is RGBRGBRGB...
|
||||
@@ -92,6 +91,15 @@ const char * mtmd_default_marker() {
|
||||
return "<__media__>";
|
||||
}
|
||||
|
||||
static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
|
||||
switch (flash_attn_type) {
|
||||
case LLAMA_FLASH_ATTN_TYPE_AUTO: return CLIP_FLASH_ATTN_TYPE_AUTO;
|
||||
case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
|
||||
case LLAMA_FLASH_ATTN_TYPE_ENABLED: return CLIP_FLASH_ATTN_TYPE_ENABLED;
|
||||
}
|
||||
return CLIP_FLASH_ATTN_TYPE_AUTO;
|
||||
}
|
||||
|
||||
mtmd_context_params mtmd_context_params_default() {
|
||||
mtmd_context_params params;
|
||||
params.use_gpu = true;
|
||||
@@ -100,6 +108,9 @@ mtmd_context_params mtmd_context_params_default() {
|
||||
params.verbosity = GGML_LOG_LEVEL_INFO;
|
||||
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
|
||||
params.media_marker = mtmd_default_marker();
|
||||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
|
||||
params.image_min_tokens = -1;
|
||||
params.image_max_tokens = -1;
|
||||
return params;
|
||||
}
|
||||
|
||||
@@ -162,8 +173,13 @@ struct mtmd_context {
|
||||
}
|
||||
|
||||
clip_context_params ctx_clip_params;
|
||||
ctx_clip_params.use_gpu = ctx_params.use_gpu;
|
||||
ctx_clip_params.verbosity = ctx_params.verbosity;
|
||||
ctx_clip_params.use_gpu = ctx_params.use_gpu;
|
||||
ctx_clip_params.verbosity = ctx_params.verbosity;
|
||||
ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
|
||||
// custom image token limits
|
||||
ctx_clip_params.image_min_tokens = ctx_params.image_min_tokens;
|
||||
ctx_clip_params.image_max_tokens = ctx_params.image_max_tokens;
|
||||
|
||||
auto res = clip_init(mmproj_fname, ctx_clip_params);
|
||||
ctx_v = res.ctx_v;
|
||||
ctx_a = res.ctx_a;
|
||||
@@ -378,9 +394,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||
}
|
||||
|
||||
void mtmd_free(mtmd_context * ctx) {
|
||||
if (ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
struct mtmd_tokenizer {
|
||||
|
||||
@@ -82,6 +82,11 @@ struct mtmd_context_params {
|
||||
enum ggml_log_level verbosity;
|
||||
const char * image_marker; // deprecated, use media_marker instead
|
||||
const char * media_marker;
|
||||
enum llama_flash_attn_type flash_attn_type;
|
||||
|
||||
// limit number of image tokens, only for vision models with dynamic resolution
|
||||
int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
|
||||
int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
|
||||
};
|
||||
|
||||
MTMD_API const char * mtmd_default_marker(void);
|
||||
|
||||
Binary file not shown.
+80
-18
@@ -2407,7 +2407,7 @@ struct server_context {
|
||||
|
||||
params_dft.devices = params_base.speculative.devices;
|
||||
params_dft.model = params_base.speculative.model;
|
||||
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
|
||||
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
|
||||
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
|
||||
params_dft.n_parallel = 1;
|
||||
params_dft.cache_type_k = params_base.speculative.cache_type_k;
|
||||
@@ -2452,10 +2452,13 @@ struct server_context {
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
if (!mmproj_path.empty()) {
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params_base.flash_attn_type;
|
||||
mparams.image_min_tokens = params_base.image_min_tokens;
|
||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
|
||||
if (mctx == nullptr) {
|
||||
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
|
||||
@@ -2495,10 +2498,16 @@ struct server_context {
|
||||
}
|
||||
|
||||
void init() {
|
||||
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
|
||||
|
||||
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
|
||||
|
||||
const int n_ctx_train = llama_model_n_ctx_train(model);
|
||||
|
||||
int n_ctx_slot = llama_n_ctx_seq(ctx);
|
||||
if (n_ctx_slot > n_ctx_train) {
|
||||
SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
|
||||
n_ctx_slot = n_ctx_train;
|
||||
}
|
||||
|
||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||
server_slot slot;
|
||||
|
||||
@@ -2527,7 +2536,7 @@ struct server_context {
|
||||
}
|
||||
}
|
||||
|
||||
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
|
||||
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
|
||||
|
||||
slot.callback_on_release = [this](int) {
|
||||
queue_tasks.pop_deferred_task();
|
||||
@@ -2699,6 +2708,39 @@ struct server_context {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// return true if at least one slot has been purged
|
||||
// TODO: improve logic
|
||||
// - smarter decision which slot to purge (LRU or longest prompt?)
|
||||
// - move slot to level 2 cache instead of removing?
|
||||
// - instead of purging, try to store and resume later?
|
||||
bool try_purge_idle_slots() {
|
||||
bool res = false;
|
||||
|
||||
if (!params_base.kv_unified) {
|
||||
return res;
|
||||
}
|
||||
|
||||
for (auto & slot : slots) {
|
||||
if (slot.is_processing()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.prompt.n_tokens() > 0) {
|
||||
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
slot.prompt.tokens.clear();
|
||||
|
||||
res = true;
|
||||
|
||||
// purge slots one by one
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
bool launch_slot_with_task(server_slot & slot, server_task && task) {
|
||||
slot.reset();
|
||||
|
||||
@@ -3635,9 +3677,10 @@ struct server_context {
|
||||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
float alora_scale = -1.0f;
|
||||
float alora_scale = -1.0f;
|
||||
size_t alora_disabled_id = 0;
|
||||
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
// check if we can batch this slot with the previous one
|
||||
@@ -3914,8 +3957,11 @@ struct server_context {
|
||||
|
||||
// truncate any tokens that are beyond n_past for this slot
|
||||
const llama_pos p0 = slot.prompt.tokens.pos_next();
|
||||
|
||||
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
|
||||
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
|
||||
SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
|
||||
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
|
||||
// there is no common part left
|
||||
@@ -3924,8 +3970,6 @@ struct server_context {
|
||||
slot.prompt.tokens.clear();
|
||||
}
|
||||
|
||||
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
|
||||
|
||||
// check if we should process the image
|
||||
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
|
||||
// process the image
|
||||
@@ -4126,6 +4170,8 @@ struct server_context {
|
||||
std::string err;
|
||||
|
||||
if (n_batch == 1 && ret == 1) {
|
||||
// TODO: try to terminate only the largest active slot/sequence and continue with the rest
|
||||
// need to remove the tokens from the current batch too
|
||||
err = "Context size has been exceeded.";
|
||||
}
|
||||
|
||||
@@ -4141,17 +4187,23 @@ struct server_context {
|
||||
// TODO: handle ret == 2 (abort) when we start aborting
|
||||
|
||||
if (!err.empty()) {
|
||||
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
|
||||
SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
|
||||
|
||||
for (auto & slot : slots) {
|
||||
send_error(slot, err);
|
||||
slot.release();
|
||||
if (slot.is_processing()) {
|
||||
send_error(slot, err);
|
||||
slot.release();
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// retry with half the batch size to try to find a free slot in the KV cache
|
||||
n_batch /= 2;
|
||||
if (!try_purge_idle_slots()) {
|
||||
n_batch /= 2;
|
||||
}
|
||||
|
||||
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
|
||||
|
||||
@@ -4391,6 +4443,15 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// TODO: should we have a separate n_parallel parameter for the server?
|
||||
// https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
|
||||
if (params.n_parallel == 1 && params.kv_unified == false) {
|
||||
LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
|
||||
|
||||
params.n_parallel = 4;
|
||||
params.kv_unified = true;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// struct that contains llama context and inference
|
||||
@@ -4849,6 +4910,7 @@ int main(int argc, char ** argv) {
|
||||
json data = {
|
||||
{ "default_generation_settings", default_generation_settings_for_props },
|
||||
{ "total_slots", ctx_server.params_base.n_parallel },
|
||||
{ "model_alias", ctx_server.params_base.model_alias },
|
||||
{ "model_path", ctx_server.params_base.model.path },
|
||||
{ "modalities", json {
|
||||
{"vision", ctx_server.oai_parser_opt.allow_image},
|
||||
@@ -4944,7 +5006,7 @@ int main(int argc, char ** argv) {
|
||||
// Everything else, including multimodal completions.
|
||||
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
|
||||
const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
|
||||
tasks.reserve(inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
auto n_prompt_tokens = inputs[i].size();
|
||||
|
||||
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
|
||||
@pytest.mark.parametrize(
|
||||
"n_batch,batch_count,reuse_cache",
|
||||
[
|
||||
(64, 15, False),
|
||||
(64, 3, False),
|
||||
(64, 1, True),
|
||||
]
|
||||
)
|
||||
def test_return_progresssss(n_batch, batch_count, reuse_cache):
|
||||
def test_return_progress(n_batch, batch_count, reuse_cache):
|
||||
global server
|
||||
server.n_batch = n_batch
|
||||
server.n_ctx = 2048
|
||||
server.n_ctx = 256
|
||||
server.n_slots = 1
|
||||
server.start()
|
||||
def make_cmpl_request():
|
||||
return server.make_stream_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 10,
|
||||
"messages": [
|
||||
{"role": "user", "content": "This is a test" * 100},
|
||||
{"role": "user", "content": "This is a test" * 10},
|
||||
],
|
||||
"stream": True,
|
||||
"return_progress": True,
|
||||
|
||||
@@ -368,6 +368,37 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
||||
# assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"n_ctx,n_slots,n_predict_vals,expected_success",
|
||||
[
|
||||
(256, 4, [80, 40, 80, 80], [True, True, True, True]),
|
||||
(256, 4, [70, 70, 70, 70], [False, False, False, False]),
|
||||
(256, 4, [90, 90, 40, 90], [False, False, True, False]),
|
||||
(256, 4, [90, 90, 40, 75], [True, True, True, True]),
|
||||
],
|
||||
)
|
||||
def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
|
||||
global server
|
||||
server.n_slots = n_slots
|
||||
server.kv_unified = True
|
||||
server.n_ctx = n_ctx
|
||||
server.start()
|
||||
prompt = "A"
|
||||
tasks = []
|
||||
for n_predict in n_predict_vals:
|
||||
tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
|
||||
results = parallel_function_calls(tasks)
|
||||
for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
|
||||
if expect_ok:
|
||||
assert res.status_code == 200
|
||||
assert "content" in res.body
|
||||
if "timings" in res.body:
|
||||
assert res.body["timings"]["predicted_n"] == n_predict
|
||||
else:
|
||||
assert res.status_code == 500
|
||||
assert "content" not in res.body
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,n_predict,response_fields",
|
||||
[
|
||||
|
||||
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
|
||||
assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
|
||||
|
||||
|
||||
def test_infill_with_input_extra():
|
||||
@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Dad|excited|park)+", res.body["content"])
|
||||
assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_extra", [
|
||||
|
||||
@@ -78,6 +78,7 @@ class ServerProcess:
|
||||
server_embeddings: bool | None = False
|
||||
server_reranking: bool | None = False
|
||||
server_metrics: bool | None = False
|
||||
kv_unified: bool | None = False
|
||||
server_slots: bool | None = False
|
||||
pooling: str | None = None
|
||||
draft: int | None = None
|
||||
@@ -159,6 +160,8 @@ class ServerProcess:
|
||||
server_args.append("--reranking")
|
||||
if self.server_metrics:
|
||||
server_args.append("--metrics")
|
||||
if self.kv_unified:
|
||||
server_args.append("--kv-unified")
|
||||
if self.server_slots:
|
||||
server_args.append("--slots")
|
||||
else:
|
||||
|
||||
@@ -1212,7 +1212,7 @@ public:
|
||||
for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
|
||||
auto * chunk = tokens.map_idx_to_media[it->first].get();
|
||||
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
|
||||
map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
|
||||
map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1244,6 +1244,7 @@ public:
|
||||
}
|
||||
|
||||
void clear() {
|
||||
map_idx_to_media.clear();
|
||||
tokens.clear();
|
||||
}
|
||||
|
||||
|
||||
Generated
+361
@@ -59,6 +59,7 @@
|
||||
"prettier-plugin-tailwindcss": "^0.6.11",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"sass": "^1.93.3",
|
||||
"storybook": "^9.0.17",
|
||||
"svelte": "^5.0.0",
|
||||
"svelte-check": "^4.0.0",
|
||||
@@ -1176,6 +1177,330 @@
|
||||
"node": ">= 8"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher/-/watcher-2.5.1.tgz",
|
||||
"integrity": "sha512-dfUnCxiN9H4ap84DvD2ubjw+3vUNpstxa0TneY/Paat8a3R4uQZDLSvWjmznAY/DoahqTHl9V46HF/Zs3F29pg==",
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"detect-libc": "^1.0.3",
|
||||
"is-glob": "^4.0.3",
|
||||
"micromatch": "^4.0.5",
|
||||
"node-addon-api": "^7.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@parcel/watcher-android-arm64": "2.5.1",
|
||||
"@parcel/watcher-darwin-arm64": "2.5.1",
|
||||
"@parcel/watcher-darwin-x64": "2.5.1",
|
||||
"@parcel/watcher-freebsd-x64": "2.5.1",
|
||||
"@parcel/watcher-linux-arm-glibc": "2.5.1",
|
||||
"@parcel/watcher-linux-arm-musl": "2.5.1",
|
||||
"@parcel/watcher-linux-arm64-glibc": "2.5.1",
|
||||
"@parcel/watcher-linux-arm64-musl": "2.5.1",
|
||||
"@parcel/watcher-linux-x64-glibc": "2.5.1",
|
||||
"@parcel/watcher-linux-x64-musl": "2.5.1",
|
||||
"@parcel/watcher-win32-arm64": "2.5.1",
|
||||
"@parcel/watcher-win32-ia32": "2.5.1",
|
||||
"@parcel/watcher-win32-x64": "2.5.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-android-arm64": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-android-arm64/-/watcher-android-arm64-2.5.1.tgz",
|
||||
"integrity": "sha512-KF8+j9nNbUN8vzOFDpRMsaKBHZ/mcjEjMToVMJOhTozkDonQFFrRcfdLWn6yWKCmJKmdVxSgHiYvTCef4/qcBA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"android"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-darwin-arm64": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-darwin-arm64/-/watcher-darwin-arm64-2.5.1.tgz",
|
||||
"integrity": "sha512-eAzPv5osDmZyBhou8PoF4i6RQXAfeKL9tjb3QzYuccXFMQU0ruIc/POh30ePnaOyD1UXdlKguHBmsTs53tVoPw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-darwin-x64": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-darwin-x64/-/watcher-darwin-x64-2.5.1.tgz",
|
||||
"integrity": "sha512-1ZXDthrnNmwv10A0/3AJNZ9JGlzrF82i3gNQcWOzd7nJ8aj+ILyW1MTxVk35Db0u91oD5Nlk9MBiujMlwmeXZg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-freebsd-x64": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-freebsd-x64/-/watcher-freebsd-x64-2.5.1.tgz",
|
||||
"integrity": "sha512-SI4eljM7Flp9yPuKi8W0ird8TI/JK6CSxju3NojVI6BjHsTyK7zxA9urjVjEKJ5MBYC+bLmMcbAWlZ+rFkLpJQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"freebsd"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-linux-arm-glibc": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-arm-glibc/-/watcher-linux-arm-glibc-2.5.1.tgz",
|
||||
"integrity": "sha512-RCdZlEyTs8geyBkkcnPWvtXLY44BCeZKmGYRtSgtwwnHR4dxfHRG3gR99XdMEdQ7KeiDdasJwwvNSF5jKtDwdA==",
|
||||
"cpu": [
|
||||
"arm"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-linux-arm-musl": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-arm-musl/-/watcher-linux-arm-musl-2.5.1.tgz",
|
||||
"integrity": "sha512-6E+m/Mm1t1yhB8X412stiKFG3XykmgdIOqhjWj+VL8oHkKABfu/gjFj8DvLrYVHSBNC+/u5PeNrujiSQ1zwd1Q==",
|
||||
"cpu": [
|
||||
"arm"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-linux-arm64-glibc": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-arm64-glibc/-/watcher-linux-arm64-glibc-2.5.1.tgz",
|
||||
"integrity": "sha512-LrGp+f02yU3BN9A+DGuY3v3bmnFUggAITBGriZHUREfNEzZh/GO06FF5u2kx8x+GBEUYfyTGamol4j3m9ANe8w==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-linux-arm64-musl": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-arm64-musl/-/watcher-linux-arm64-musl-2.5.1.tgz",
|
||||
"integrity": "sha512-cFOjABi92pMYRXS7AcQv9/M1YuKRw8SZniCDw0ssQb/noPkRzA+HBDkwmyOJYp5wXcsTrhxO0zq1U11cK9jsFg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-linux-x64-glibc": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-x64-glibc/-/watcher-linux-x64-glibc-2.5.1.tgz",
|
||||
"integrity": "sha512-GcESn8NZySmfwlTsIur+49yDqSny2IhPeZfXunQi48DMugKeZ7uy1FX83pO0X22sHntJ4Ub+9k34XQCX+oHt2A==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-linux-x64-musl": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-linux-x64-musl/-/watcher-linux-x64-musl-2.5.1.tgz",
|
||||
"integrity": "sha512-n0E2EQbatQ3bXhcH2D1XIAANAcTZkQICBPVaxMeaCVBtOpBZpWJuf7LwyWPSBDITb7In8mqQgJ7gH8CILCURXg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-win32-arm64": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-win32-arm64/-/watcher-win32-arm64-2.5.1.tgz",
|
||||
"integrity": "sha512-RFzklRvmc3PkjKjry3hLF9wD7ppR4AKcWNzH7kXR7GUe0Igb3Nz8fyPwtZCSquGrhU5HhUNDr/mKBqj7tqA2Vw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-win32-ia32": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-win32-ia32/-/watcher-win32-ia32-2.5.1.tgz",
|
||||
"integrity": "sha512-c2KkcVN+NJmuA7CGlaGD1qJh1cLfDnQsHjE89E60vUEMlqduHGCdCLJCID5geFVM0dOtA3ZiIO8BoEQmzQVfpQ==",
|
||||
"cpu": [
|
||||
"ia32"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher-win32-x64": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@parcel/watcher-win32-x64/-/watcher-win32-x64-2.5.1.tgz",
|
||||
"integrity": "sha512-9lHBdJITeNR++EvSQVUcaZoWupyHfXe1jZvGZ06O/5MflPcuPLtEphScIBL+AiCWBO46tDSHzWyD0uDmmZqsgA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/parcel"
|
||||
}
|
||||
},
|
||||
"node_modules/@parcel/watcher/node_modules/detect-libc": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-1.0.3.tgz",
|
||||
"integrity": "sha512-pGjwhsmsp4kL2RTz08wcOlGN83otlqHeD/Z5T8GXZB+/YcpQ/dgo+lbU8ZsGxV0HIvqqxo9l7mqYwyYMD9bKDg==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"bin": {
|
||||
"detect-libc": "bin/detect-libc.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=0.10"
|
||||
}
|
||||
},
|
||||
"node_modules/@playwright/test": {
|
||||
"version": "1.54.1",
|
||||
"resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.54.1.tgz",
|
||||
@@ -4697,6 +5022,13 @@
|
||||
"node": ">= 4"
|
||||
}
|
||||
},
|
||||
"node_modules/immutable": {
|
||||
"version": "5.1.4",
|
||||
"resolved": "https://registry.npmjs.org/immutable/-/immutable-5.1.4.tgz",
|
||||
"integrity": "sha512-p6u1bG3YSnINT5RQmx/yRZBpenIl30kVxkTLDyHLIMk0gict704Q9n+thfDI7lTRm9vXdDYutVzXhzcThxTnXA==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/import-fresh": {
|
||||
"version": "3.3.1",
|
||||
"resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz",
|
||||
@@ -6462,6 +6794,14 @@
|
||||
"tslib": "^2.0.3"
|
||||
}
|
||||
},
|
||||
"node_modules/node-addon-api": {
|
||||
"version": "7.1.1",
|
||||
"resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-7.1.1.tgz",
|
||||
"integrity": "sha512-5m3bsyrjFWE1xf7nz7YXdN4udnVtXK6/Yfgn5qnahL6bCkf2yKt4k3nuTKAtT4r3IG8JNR2ncsIMdZuAzJjHQQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/object-inspect": {
|
||||
"version": "1.13.4",
|
||||
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
|
||||
@@ -7484,6 +7824,27 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/sass": {
|
||||
"version": "1.93.3",
|
||||
"resolved": "https://registry.npmjs.org/sass/-/sass-1.93.3.tgz",
|
||||
"integrity": "sha512-elOcIZRTM76dvxNAjqYrucTSI0teAF/L2Lv0s6f6b7FOwcwIuA357bIE871580AjHJuSvLIRUosgV+lIWx6Rgg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"chokidar": "^4.0.0",
|
||||
"immutable": "^5.0.2",
|
||||
"source-map-js": ">=0.6.2 <2.0.0"
|
||||
},
|
||||
"bin": {
|
||||
"sass": "sass.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@parcel/watcher": "^2.4.1"
|
||||
}
|
||||
},
|
||||
"node_modules/scheduler": {
|
||||
"version": "0.26.0",
|
||||
"resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.26.0.tgz",
|
||||
|
||||
@@ -61,6 +61,7 @@
|
||||
"prettier-plugin-tailwindcss": "^0.6.11",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"sass": "^1.93.3",
|
||||
"storybook": "^9.0.17",
|
||||
"svelte": "^5.0.0",
|
||||
"svelte-check": "^4.0.0",
|
||||
|
||||
+9
@@ -134,6 +134,15 @@
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (open) {
|
||||
pdfImages = [];
|
||||
pdfImagesLoading = false;
|
||||
pdfImagesError = null;
|
||||
pdfViewMode = 'pages';
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (open && isPdf && pdfViewMode === 'pages') {
|
||||
loadPdfImages();
|
||||
|
||||
@@ -8,8 +8,9 @@
|
||||
import rehypeKatex from 'rehype-katex';
|
||||
import rehypeStringify from 'rehype-stringify';
|
||||
import { copyCodeToClipboard } from '$lib/utils/copy';
|
||||
import { preprocessLaTeX } from '$lib/utils/latex-protection';
|
||||
import { browser } from '$app/environment';
|
||||
import 'katex/dist/katex.min.css';
|
||||
import '$styles/katex-custom.scss';
|
||||
|
||||
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||
@@ -176,19 +177,9 @@
|
||||
return mutated ? tempDiv.innerHTML : html;
|
||||
}
|
||||
|
||||
function normalizeMathDelimiters(text: string): string {
|
||||
return text
|
||||
.replace(/(^|[^\\])\\\[((?:\\.|[\s\S])*?)\\\]/g, (_, prefix: string, content: string) => {
|
||||
return `${prefix}$$${content}$$`;
|
||||
})
|
||||
.replace(/(^|[^\\])\\\(((?:\\.|[\s\S])*?)\\\)/g, (_, prefix: string, content: string) => {
|
||||
return `${prefix}$${content}$`;
|
||||
});
|
||||
}
|
||||
|
||||
async function processMarkdown(text: string): Promise<string> {
|
||||
try {
|
||||
const normalized = normalizeMathDelimiters(text);
|
||||
let normalized = preprocessLaTeX(text);
|
||||
const result = await processor().process(normalized);
|
||||
const html = String(result);
|
||||
const enhancedLinks = enhanceLinks(html);
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Matches common Markdown code blocks to exclude them from further processing (e.g. LaTeX).
|
||||
* - Fenced: ```...```
|
||||
* - Inline: `...` (does NOT support nested backticks or multi-backtick syntax)
|
||||
*
|
||||
* Note: This pattern does not handle advanced cases like:
|
||||
* `` `code with `backticks` `` or \\``...\\``
|
||||
*/
|
||||
export const CODE_BLOCK_REGEXP = /(```[\s\S]*?```|`[^`\n]+`)/g;
|
||||
|
||||
/**
|
||||
* Matches LaTeX math delimiters \(...\) and \[...\] only when not preceded by a backslash (i.e., not escaped),
|
||||
* while also capturing code blocks (```, `...`) so they can be skipped during processing.
|
||||
*
|
||||
* Uses negative lookbehind `(?<!\\)` to avoid matching \\( or \\[.
|
||||
* Using the look‑behind pattern `(?<!\\)` we skip matches
|
||||
* that are preceded by a backslash, e.g.
|
||||
* `Definitions\\(also called macros)` (title of chapter 20 in The TeXbook)
|
||||
* or `\\[4pt]` (LaTeX line-break).
|
||||
*
|
||||
* group 1: code-block
|
||||
* group 2: square-bracket
|
||||
* group 3: round-bracket
|
||||
*/
|
||||
export const LATEX_MATH_AND_CODE_PATTERN =
|
||||
/(```[\S\s]*?```|`.*?`)|(?<!\\)\\\[([\S\s]*?[^\\])\\]|(?<!\\)\\\((.*?)\\\)/g;
|
||||
|
||||
/** Regex to capture the content of a $$...\\\\...$$ block (display-formula with line-break) */
|
||||
export const LATEX_LINEBREAK_REGEXP = /\$\$([\s\S]*?\\\\[\s\S]*?)\$\$/;
|
||||
|
||||
/** map from mchem-regexp to replacement */
|
||||
export const MHCHEM_PATTERN_MAP: readonly [RegExp, string][] = [
|
||||
[/(\s)\$\\ce{/g, '$1$\\\\ce{'],
|
||||
[/(\s)\$\\pu{/g, '$1$\\\\pu{']
|
||||
] as const;
|
||||
@@ -99,6 +99,9 @@ class ServerStore {
|
||||
}
|
||||
|
||||
get modelName(): string | null {
|
||||
if (this._serverProps?.model_alias) {
|
||||
return this._serverProps.model_alias;
|
||||
}
|
||||
if (!this._serverProps?.model_path) return null;
|
||||
return this._serverProps.model_path.split(/(\\|\/)/).pop() || null;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,355 @@
|
||||
/* eslint-disable no-irregular-whitespace */
|
||||
import { describe, it, expect, test } from 'vitest';
|
||||
import { maskInlineLaTeX, preprocessLaTeX } from './latex-protection';
|
||||
|
||||
describe('maskInlineLaTeX', () => {
|
||||
it('should protect LaTeX $x + y$ but not money $3.99', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = 'I have $10, $3.99 and $x + y$ and $100x$. The amount is $2,000.';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('I have $10, $3.99 and <<LATEX_0>> and <<LATEX_1>>. The amount is $2,000.');
|
||||
expect(latexExpressions).toEqual(['$x + y$', '$100x$']);
|
||||
});
|
||||
|
||||
it('should ignore money like $5 and $12.99', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = 'Prices are $12.99 and $5. Tax?';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('Prices are $12.99 and $5. Tax?');
|
||||
expect(latexExpressions).toEqual([]);
|
||||
});
|
||||
|
||||
it('should protect inline math $a^2 + b^2$ even after text', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = 'Pythagorean: $a^2 + b^2 = c^2$.';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('Pythagorean: <<LATEX_0>>.');
|
||||
expect(latexExpressions).toEqual(['$a^2 + b^2 = c^2$']);
|
||||
});
|
||||
|
||||
it('should not protect math that has letter after closing $ (e.g. units)', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = 'The cost is $99 and change.';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('The cost is $99 and change.');
|
||||
expect(latexExpressions).toEqual([]);
|
||||
});
|
||||
|
||||
it('should allow $x$ followed by punctuation', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = 'We know $x$, right?';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('We know <<LATEX_0>>, right?');
|
||||
expect(latexExpressions).toEqual(['$x$']);
|
||||
});
|
||||
|
||||
it('should work across multiple lines', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = `Emma buys cupcakes for $3 each.\nHow much is $x + y$?`;
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe(`Emma buys cupcakes for $3 each.\nHow much is <<LATEX_0>>?`);
|
||||
expect(latexExpressions).toEqual(['$x + y$']);
|
||||
});
|
||||
|
||||
it('should not protect $100 but protect $matrix$', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = '$100 and $\\mathrm{GL}_2(\\mathbb{F}_7)$ are different.';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('$100 and <<LATEX_0>> are different.');
|
||||
expect(latexExpressions).toEqual(['$\\mathrm{GL}_2(\\mathbb{F}_7)$']);
|
||||
});
|
||||
|
||||
it('should skip if $ is followed by digit and alphanumeric after close (money)', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = 'I paid $5 quickly.';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('I paid $5 quickly.');
|
||||
expect(latexExpressions).toEqual([]);
|
||||
});
|
||||
|
||||
it('should protect LaTeX even with special chars inside', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = 'Consider $\\alpha_1 + \\beta_2$ now.';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('Consider <<LATEX_0>> now.');
|
||||
expect(latexExpressions).toEqual(['$\\alpha_1 + \\beta_2$']);
|
||||
});
|
||||
|
||||
it('short text', () => {
|
||||
const latexExpressions: string[] = ['$0$'];
|
||||
const input = '$a$\n$a$ and $b$';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('<<LATEX_1>>\n<<LATEX_2>> and <<LATEX_3>>');
|
||||
expect(latexExpressions).toEqual(['$0$', '$a$', '$a$', '$b$']);
|
||||
});
|
||||
|
||||
it('empty text', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = '$\n$$\n';
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe('$\n$$\n');
|
||||
expect(latexExpressions).toEqual([]);
|
||||
});
|
||||
|
||||
it('LaTeX-spacer preceded by backslash', () => {
|
||||
const latexExpressions: string[] = [];
|
||||
const input = `\\[
|
||||
\\boxed{
|
||||
\\begin{aligned}
|
||||
N_{\\text{att}}^{\\text{(MHA)}} &=
|
||||
h \\bigl[\\, d_{\\text{model}}\\;d_{k} + d_{\\text{model}}\\;d_{v}\\, \\bigr] && (\\text{Q,K,V の重み})\\\\
|
||||
&\\quad+ h(d_{k}+d_{k}+d_{v}) && (\\text{バイアス Q,K,V)}\\\\[4pt]
|
||||
&\\quad+ (h d_{v})\\, d_{\\text{model}} && (\\text{出力射影 }W^{O})\\\\
|
||||
&\\quad+ d_{\\text{model}} && (\\text{バイアス }b^{O})
|
||||
\\end{aligned}}
|
||||
\\]`;
|
||||
const output = maskInlineLaTeX(input, latexExpressions);
|
||||
|
||||
expect(output).toBe(input);
|
||||
expect(latexExpressions).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('preprocessLaTeX', () => {
|
||||
test('converts inline \\( ... \\) to $...$', () => {
|
||||
const input =
|
||||
'\\( \\mathrm{GL}_2(\\mathbb{F}_7) \\): Group of invertible matrices with entries in \\(\\mathbb{F}_7\\).';
|
||||
const output = preprocessLaTeX(input);
|
||||
expect(output).toBe(
|
||||
'$ \\mathrm{GL}_2(\\mathbb{F}_7) $: Group of invertible matrices with entries in $\\mathbb{F}_7$.'
|
||||
);
|
||||
});
|
||||
|
||||
test("don't inline \\\\( ... \\) to $...$", () => {
|
||||
const input =
|
||||
'Chapter 20 of The TeXbook, in source "Definitions\\\\(also called Macros)", containst the formula \\((x_1,\\ldots,x_n)\\).';
|
||||
const output = preprocessLaTeX(input);
|
||||
expect(output).toBe(
|
||||
'Chapter 20 of The TeXbook, in source "Definitions\\\\(also called Macros)", containst the formula $(x_1,\\ldots,x_n)$.'
|
||||
);
|
||||
});
|
||||
|
||||
test('preserves display math \\[ ... \\] and protects adjacent text', () => {
|
||||
const input = `Some kernel of \\(\\mathrm{SL}_2(\\mathbb{F}_7)\\):
|
||||
\\[
|
||||
\\left\\{ \\begin{pmatrix} 1 & 0 \\\\ 0 & 1 \\end{pmatrix}, \\begin{pmatrix} -1 & 0 \\\\ 0 & -1 \\end{pmatrix} \\right\\} = \\{\\pm I\\}
|
||||
\\]`;
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe(`Some kernel of $\\mathrm{SL}_2(\\mathbb{F}_7)$:
|
||||
$$
|
||||
\\left\\{ \\begin{pmatrix} 1 & 0 \\\\ 0 & 1 \\end{pmatrix}, \\begin{pmatrix} -1 & 0 \\\\ 0 & -1 \\end{pmatrix} \\right\\} = \\{\\pm I\\}
|
||||
$$`);
|
||||
});
|
||||
|
||||
test('handles standalone display math equation', () => {
|
||||
const input = `Algebra:
|
||||
\\[
|
||||
x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}
|
||||
\\]`;
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe(`Algebra:
|
||||
$$
|
||||
x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}
|
||||
$$`);
|
||||
});
|
||||
|
||||
test('does not interpret currency values as LaTeX', () => {
|
||||
const input = 'I have $10, $3.99 and $x + y$ and $100x$. The amount is $2,000.';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe('I have \\$10, \\$3.99 and $x + y$ and $100x$. The amount is \\$2,000.');
|
||||
});
|
||||
|
||||
test('ignores dollar signs followed by digits (money), but keeps valid math $x + y$', () => {
|
||||
const input = 'I have $10, $3.99 and $x + y$ and $100x$. The amount is $2,000.';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe('I have \\$10, \\$3.99 and $x + y$ and $100x$. The amount is \\$2,000.');
|
||||
});
|
||||
|
||||
test('handles real-world word problems with amounts and no math delimiters', () => {
|
||||
const input =
|
||||
'Emma buys 2 cupcakes for $3 each and 1 cookie for $1.50. How much money does she spend in total?';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe(
|
||||
'Emma buys 2 cupcakes for \\$3 each and 1 cookie for \\$1.50. How much money does she spend in total?'
|
||||
);
|
||||
});
|
||||
|
||||
test('handles decimal amounts in word problem correctly', () => {
|
||||
const input =
|
||||
'Maria has $20. She buys a notebook for $4.75 and a pack of pencils for $3.25. How much change does she receive?';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe(
|
||||
'Maria has \\$20. She buys a notebook for \\$4.75 and a pack of pencils for \\$3.25. How much change does she receive?'
|
||||
);
|
||||
});
|
||||
|
||||
test('preserves display math with surrounding non-ASCII text', () => {
|
||||
const input = `1 kg の質量は
|
||||
\\[
|
||||
E = (1\\ \\text{kg}) \\times (3.0 \\times 10^8\\ \\text{m/s})^2 \\approx 9.0 \\times 10^{16}\\ \\text{J}
|
||||
\\]
|
||||
というエネルギーに相当します。これは約 21 百万トンの TNT が爆発したときのエネルギーに匹敵します。`;
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe(
|
||||
`1 kg の質量は
|
||||
$$
|
||||
E = (1\\ \\text{kg}) \\times (3.0 \\times 10^8\\ \\text{m/s})^2 \\approx 9.0 \\times 10^{16}\\ \\text{J}
|
||||
$$
|
||||
というエネルギーに相当します。これは約 21 百万トンの TNT が爆発したときのエネルギーに匹敵します。`
|
||||
);
|
||||
});
|
||||
|
||||
test('LaTeX-spacer preceded by backslash', () => {
|
||||
const input = `\\[
|
||||
\\boxed{
|
||||
\\begin{aligned}
|
||||
N_{\\text{att}}^{\\text{(MHA)}} &=
|
||||
h \\bigl[\\, d_{\\text{model}}\\;d_{k} + d_{\\text{model}}\\;d_{v}\\, \\bigr] && (\\text{Q,K,V の重み})\\\\
|
||||
&\\quad+ h(d_{k}+d_{k}+d_{v}) && (\\text{バイアス Q,K,V)}\\\\[4pt]
|
||||
&\\quad+ (h d_{v})\\, d_{\\text{model}} && (\\text{出力射影 }W^{O})\\\\
|
||||
&\\quad+ d_{\\text{model}} && (\\text{バイアス }b^{O})
|
||||
\\end{aligned}}
|
||||
\\]`;
|
||||
const output = preprocessLaTeX(input);
|
||||
expect(output).toBe(
|
||||
`$$
|
||||
\\boxed{
|
||||
\\begin{aligned}
|
||||
N_{\\text{att}}^{\\text{(MHA)}} &=
|
||||
h \\bigl[\\, d_{\\text{model}}\\;d_{k} + d_{\\text{model}}\\;d_{v}\\, \\bigr] && (\\text{Q,K,V の重み})\\\\
|
||||
&\\quad+ h(d_{k}+d_{k}+d_{v}) && (\\text{バイアス Q,K,V)}\\\\[4pt]
|
||||
&\\quad+ (h d_{v})\\, d_{\\text{model}} && (\\text{出力射影 }W^{O})\\\\
|
||||
&\\quad+ d_{\\text{model}} && (\\text{バイアス }b^{O})
|
||||
\\end{aligned}}
|
||||
$$`
|
||||
);
|
||||
});
|
||||
|
||||
test('converts \\[ ... \\] even when preceded by text without space', () => {
|
||||
const input = 'Some line ...\nAlgebra: \\[x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}\\]';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe(
|
||||
'Some line ...\nAlgebra: \n$$x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}$$\n'
|
||||
);
|
||||
});
|
||||
|
||||
test('converts \\[ ... \\] in table-cells', () => {
|
||||
const input = `| ID | Expression |\n| #1 | \\[
|
||||
x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}
|
||||
\\] |`;
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe(
|
||||
'| ID | Expression |\n| #1 | $x = \\frac{-b \\pm \\sqrt{\\,b^{2}-4ac\\,}}{2a}$ |'
|
||||
);
|
||||
});
|
||||
|
||||
test('escapes isolated $ before digits ($5 → \\$5), but not valid math', () => {
|
||||
const input = 'This costs $5 and this is math $x^2$. $100 is money.';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe('This costs \\$5 and this is math $x^2$. \\$100 is money.');
|
||||
// Note: Since $x^2$ is detected as valid LaTeX, it's preserved.
|
||||
// $5 becomes \$5 only *after* real math is masked — but here it's correct because the masking logic avoids treating $5 as math.
|
||||
});
|
||||
|
||||
test('display with LaTeX-line-breaks', () => {
|
||||
const input = String.raw`- Algebraic topology, Homotopy Groups of $\mathbb{S}^3$:
|
||||
$$\pi_n(\mathbb{S}^3) = \begin{cases}
|
||||
\mathbb{Z} & n = 3 \\
|
||||
0 & n > 3, n \neq 4 \\
|
||||
\mathbb{Z}_2 & n = 4 \\
|
||||
\end{cases}$$`;
|
||||
const output = preprocessLaTeX(input);
|
||||
// If the formula contains '\\' the $$-delimiters should be in their own line.
|
||||
expect(output).toBe(`- Algebraic topology, Homotopy Groups of $\\mathbb{S}^3$:
|
||||
$$\n\\pi_n(\\mathbb{S}^3) = \\begin{cases}
|
||||
\\mathbb{Z} & n = 3 \\\\
|
||||
0 & n > 3, n \\neq 4 \\\\
|
||||
\\mathbb{Z}_2 & n = 4 \\\\
|
||||
\\end{cases}\n$$`);
|
||||
});
|
||||
|
||||
test('handles mhchem notation safely if present', () => {
|
||||
const input = 'Chemical reaction: \\( \\ce{H2O} \\) and $\\ce{CO2}$';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe('Chemical reaction: $ \\ce{H2O} $ and $\\ce{CO2}$');
|
||||
});
|
||||
|
||||
test('preserves code blocks', () => {
|
||||
const input = 'Inline code: `sum $total` and block:\n```\ndollar $amount\n```\nEnd.';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
expect(output).toBe(input); // Code blocks prevent misinterpretation
|
||||
});
|
||||
|
||||
test('escape backslash in mchem ce', () => {
|
||||
const input = 'mchem ce:\n$\\ce{2H2(g) + O2(g) -> 2H2O(l)}$';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
// mhchem-escape would insert a backslash here.
|
||||
expect(output).toBe('mchem ce:\n$\\ce{2H2(g) + O2(g) -> 2H2O(l)}$');
|
||||
});
|
||||
|
||||
test('escape backslash in mchem pu', () => {
|
||||
const input = 'mchem pu:\n$\\pu{-572 kJ mol^{-1}}$';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
// mhchem-escape would insert a backslash here.
|
||||
expect(output).toBe('mchem pu:\n$\\pu{-572 kJ mol^{-1}}$');
|
||||
});
|
||||
|
||||
test('LaTeX in blockquotes with display math', () => {
|
||||
const input =
|
||||
'> **Definition (limit):** \n> \\[\n> \\lim_{x\\to a} f(x) = L\n> \\]\n> means that as \\(x\\) gets close to \\(a\\).';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
// Blockquote markers should be preserved, LaTeX should be converted
|
||||
expect(output).toContain('> **Definition (limit):**');
|
||||
expect(output).toContain('$$');
|
||||
expect(output).toContain('$x$');
|
||||
expect(output).not.toContain('\\[');
|
||||
expect(output).not.toContain('\\]');
|
||||
expect(output).not.toContain('\\(');
|
||||
expect(output).not.toContain('\\)');
|
||||
});
|
||||
|
||||
test('LaTeX in blockquotes with inline math', () => {
|
||||
const input =
|
||||
"> The derivative \\(f'(x)\\) at point \\(x=a\\) measures slope.\n> Formula: \\(f'(a)=\\lim_{h\\to 0}\\frac{f(a+h)-f(a)}{h}\\)";
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
// Blockquote markers should be preserved, inline LaTeX converted to $...$
|
||||
expect(output).toContain("> The derivative $f'(x)$ at point $x=a$ measures slope.");
|
||||
expect(output).toContain("> Formula: $f'(a)=\\lim_{h\\to 0}\\frac{f(a+h)-f(a)}{h}$");
|
||||
});
|
||||
|
||||
test('Mixed content with blockquotes and regular text', () => {
|
||||
const input =
|
||||
'Regular text with \\(x^2\\).\n\n> Quote with \\(y^2\\).\n\nMore text with \\(z^2\\).';
|
||||
const output = preprocessLaTeX(input);
|
||||
|
||||
// All LaTeX should be converted, blockquote markers preserved
|
||||
expect(output).toBe('Regular text with $x^2$.\n\n> Quote with $y^2$.\n\nMore text with $z^2$.');
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,267 @@
|
||||
import {
|
||||
CODE_BLOCK_REGEXP,
|
||||
LATEX_MATH_AND_CODE_PATTERN,
|
||||
LATEX_LINEBREAK_REGEXP,
|
||||
MHCHEM_PATTERN_MAP
|
||||
} from '$lib/constants/latex-protection';
|
||||
|
||||
/**
|
||||
* Replaces inline LaTeX expressions enclosed in `$...$` with placeholders, avoiding dollar signs
|
||||
* that appear to be part of monetary values or identifiers.
|
||||
*
|
||||
* This function processes the input line by line and skips `$` sequences that are likely
|
||||
* part of money amounts (e.g., `$5`, `$100.99`) or code-like tokens (e.g., `var$`, `$var`).
|
||||
* Valid LaTeX inline math is replaced with a placeholder like `<<LATEX_0>>`, and the
|
||||
* actual LaTeX content is stored in the provided `latexExpressions` array.
|
||||
*
|
||||
* @param content - The input text potentially containing LaTeX expressions.
|
||||
* @param latexExpressions - An array used to collect extracted LaTeX expressions.
|
||||
* @returns The processed string with LaTeX replaced by placeholders.
|
||||
*/
|
||||
export function maskInlineLaTeX(content: string, latexExpressions: string[]): string {
|
||||
if (!content.includes('$')) {
|
||||
return content;
|
||||
}
|
||||
return content
|
||||
.split('\n')
|
||||
.map((line) => {
|
||||
if (line.indexOf('$') == -1) {
|
||||
return line;
|
||||
}
|
||||
|
||||
let processedLine = '';
|
||||
let currentPosition = 0;
|
||||
|
||||
while (currentPosition < line.length) {
|
||||
const openDollarIndex = line.indexOf('$', currentPosition);
|
||||
|
||||
if (openDollarIndex == -1) {
|
||||
processedLine += line.slice(currentPosition);
|
||||
break;
|
||||
}
|
||||
|
||||
// Is there a next $-sign?
|
||||
const closeDollarIndex = line.indexOf('$', openDollarIndex + 1);
|
||||
|
||||
if (closeDollarIndex == -1) {
|
||||
processedLine += line.slice(currentPosition);
|
||||
break;
|
||||
}
|
||||
|
||||
const charBeforeOpen = openDollarIndex > 0 ? line[openDollarIndex - 1] : '';
|
||||
const charAfterOpen = line[openDollarIndex + 1];
|
||||
const charBeforeClose =
|
||||
openDollarIndex + 1 < closeDollarIndex ? line[closeDollarIndex - 1] : '';
|
||||
const charAfterClose = closeDollarIndex + 1 < line.length ? line[closeDollarIndex + 1] : '';
|
||||
|
||||
let shouldSkipAsNonLatex = false;
|
||||
|
||||
if (closeDollarIndex == currentPosition + 1) {
|
||||
// No content
|
||||
shouldSkipAsNonLatex = true;
|
||||
}
|
||||
|
||||
if (/[A-Za-z0-9_$-]/.test(charBeforeOpen)) {
|
||||
// Character, digit, $, _ or - before first '$', no TeX.
|
||||
shouldSkipAsNonLatex = true;
|
||||
}
|
||||
|
||||
if (
|
||||
/[0-9]/.test(charAfterOpen) &&
|
||||
(/[A-Za-z0-9_$-]/.test(charAfterClose) || ' ' == charBeforeClose)
|
||||
) {
|
||||
// First $ seems to belong to an amount.
|
||||
shouldSkipAsNonLatex = true;
|
||||
}
|
||||
|
||||
if (shouldSkipAsNonLatex) {
|
||||
processedLine += line.slice(currentPosition, openDollarIndex + 1);
|
||||
currentPosition = openDollarIndex + 1;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// Treat as LaTeX
|
||||
processedLine += line.slice(currentPosition, openDollarIndex);
|
||||
const latexContent = line.slice(openDollarIndex, closeDollarIndex + 1);
|
||||
latexExpressions.push(latexContent);
|
||||
processedLine += `<<LATEX_${latexExpressions.length - 1}>>`;
|
||||
currentPosition = closeDollarIndex + 1;
|
||||
}
|
||||
|
||||
return processedLine;
|
||||
})
|
||||
.join('\n');
|
||||
}
|
||||
|
||||
function escapeBrackets(text: string): string {
|
||||
return text.replace(
|
||||
LATEX_MATH_AND_CODE_PATTERN,
|
||||
(
|
||||
match: string,
|
||||
codeBlock: string | undefined,
|
||||
squareBracket: string | undefined,
|
||||
roundBracket: string | undefined
|
||||
): string => {
|
||||
if (codeBlock != null) {
|
||||
return codeBlock;
|
||||
} else if (squareBracket != null) {
|
||||
return `$$${squareBracket}$$`;
|
||||
} else if (roundBracket != null) {
|
||||
return `$${roundBracket}$`;
|
||||
}
|
||||
|
||||
return match;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// Escape $\\ce{...} → $\\ce{...} but with proper handling
|
||||
function escapeMhchem(text: string): string {
|
||||
return MHCHEM_PATTERN_MAP.reduce((result, [pattern, replacement]) => {
|
||||
return result.replace(pattern, replacement);
|
||||
}, text);
|
||||
}
|
||||
|
||||
const doEscapeMhchem = false;
|
||||
|
||||
/**
|
||||
* Preprocesses markdown content to safely handle LaTeX math expressions while protecting
|
||||
* against false positives (e.g., dollar amounts like $5.99) and ensuring proper rendering.
|
||||
*
|
||||
* This function:
|
||||
* - Protects code blocks (```) and inline code (`...`)
|
||||
* - Safeguards block and inline LaTeX: \(...\), \[...\], $$...$$, and selective $...$
|
||||
* - Escapes standalone dollar signs before numbers (e.g., $5 → \$5) to prevent misinterpretation
|
||||
* - Restores protected LaTeX and code blocks after processing
|
||||
* - Converts \(...\) → $...$ and \[...\] → $$...$$ for compatibility with math renderers
|
||||
* - Applies additional escaping for brackets and mhchem syntax if needed
|
||||
*
|
||||
* @param content - The raw text (e.g., markdown) that may contain LaTeX or code blocks.
|
||||
* @returns The preprocessed string with properly escaped and normalized LaTeX.
|
||||
*
|
||||
* @example
|
||||
* preprocessLaTeX("Price: $10. The equation is \\(x^2\\).")
|
||||
* // → "Price: $10. The equation is $x^2$."
|
||||
*/
|
||||
export function preprocessLaTeX(content: string): string {
|
||||
// See also:
|
||||
// https://github.com/danny-avila/LibreChat/blob/main/client/src/utils/latex.ts
|
||||
|
||||
// Step 0: Temporarily remove blockquote markers (>) to process LaTeX correctly
|
||||
// Store the structure so we can restore it later
|
||||
const blockquoteMarkers: Map<number, string> = new Map();
|
||||
const lines = content.split('\n');
|
||||
const processedLines = lines.map((line, index) => {
|
||||
const match = line.match(/^(>\s*)/);
|
||||
if (match) {
|
||||
blockquoteMarkers.set(index, match[1]);
|
||||
return line.slice(match[1].length);
|
||||
}
|
||||
return line;
|
||||
});
|
||||
content = processedLines.join('\n');
|
||||
|
||||
// Step 1: Protect code blocks
|
||||
const codeBlocks: string[] = [];
|
||||
|
||||
content = content.replace(CODE_BLOCK_REGEXP, (match) => {
|
||||
codeBlocks.push(match);
|
||||
|
||||
return `<<CODE_BLOCK_${codeBlocks.length - 1}>>`;
|
||||
});
|
||||
|
||||
// Step 2: Protect existing LaTeX expressions
|
||||
const latexExpressions: string[] = [];
|
||||
|
||||
// Match \S...\[...\] and protect them and insert a line-break.
|
||||
content = content.replace(/([\S].*?)\\\[([\s\S]*?)\\\](.*)/g, (match, group1, group2, group3) => {
|
||||
// Check if there are characters following the formula (display-formula in a table-cell?)
|
||||
if (group1.endsWith('\\')) {
|
||||
return match; // Backslash before \[, do nothing.
|
||||
}
|
||||
const hasSuffix = /\S/.test(group3);
|
||||
let optBreak;
|
||||
|
||||
if (hasSuffix) {
|
||||
latexExpressions.push(`\\(${group2.trim()}\\)`); // Convert into inline.
|
||||
optBreak = '';
|
||||
} else {
|
||||
latexExpressions.push(`\\[${group2}\\]`);
|
||||
optBreak = '\n';
|
||||
}
|
||||
|
||||
return `${group1}${optBreak}<<LATEX_${latexExpressions.length - 1}>>${optBreak}${group3}`;
|
||||
});
|
||||
|
||||
// Match \(...\), \[...\], $$...$$ and protect them
|
||||
content = content.replace(
|
||||
/(\$\$[\s\S]*?\$\$|(?<!\\)\\\[[\s\S]*?\\\]|(?<!\\)\\\(.*?\\\))/g,
|
||||
(match) => {
|
||||
latexExpressions.push(match);
|
||||
|
||||
return `<<LATEX_${latexExpressions.length - 1}>>`;
|
||||
}
|
||||
);
|
||||
|
||||
// Protect inline $...$ but NOT if it looks like money (e.g., $10, $3.99)
|
||||
content = maskInlineLaTeX(content, latexExpressions);
|
||||
|
||||
// Step 3: Escape standalone $ before digits (currency like $5 → \$5)
|
||||
// (Now that inline math is protected, this will only escape dollars not already protected)
|
||||
content = content.replace(/\$(?=\d)/g, '\\$');
|
||||
|
||||
// Step 4: Restore protected LaTeX expressions (they are valid)
|
||||
content = content.replace(/<<LATEX_(\d+)>>/g, (_, index) => {
|
||||
let expr = latexExpressions[parseInt(index)];
|
||||
const match = expr.match(LATEX_LINEBREAK_REGEXP);
|
||||
if (match) {
|
||||
// Katex: The $$-delimiters should be in their own line
|
||||
// if there are \\-line-breaks.
|
||||
const formula = match[1];
|
||||
const prefix = formula.startsWith('\n') ? '' : '\n';
|
||||
const suffix = formula.endsWith('\n') ? '' : '\n';
|
||||
expr = '$$' + prefix + formula + suffix + '$$';
|
||||
}
|
||||
return expr;
|
||||
});
|
||||
|
||||
// Step 5: Restore code blocks
|
||||
content = content.replace(/<<CODE_BLOCK_(\d+)>>/g, (_, index) => {
|
||||
return codeBlocks[parseInt(index)];
|
||||
});
|
||||
|
||||
// Step 6: Apply additional escaping functions (brackets and mhchem)
|
||||
content = escapeBrackets(content);
|
||||
|
||||
if (doEscapeMhchem && (content.includes('\\ce{') || content.includes('\\pu{'))) {
|
||||
content = escapeMhchem(content);
|
||||
}
|
||||
|
||||
// Final pass: Convert \(...\) → $...$, \[...\] → $$...$$
|
||||
content = content
|
||||
// Using the look‑behind pattern `(?<!\\)` we skip matches
|
||||
// that are preceded by a backslash, e.g.
|
||||
// `Definitions\\(also called macros)` (title of chapter 20 in The TeXbook).
|
||||
.replace(/(?<!\\)\\\((.+?)\\\)/g, '$$$1$') // inline
|
||||
.replace(
|
||||
// Using the look‑behind pattern `(?<!\\)` we skip matches
|
||||
// that are preceded by a backslash, e.g. `\\[4pt]`.
|
||||
/(?<!\\)\\\[([\s\S]*?)\\\]/g, // display, see also PR #16599
|
||||
(_, prefix: string, content: string) => {
|
||||
return `${prefix}$$${content}$$`;
|
||||
}
|
||||
);
|
||||
|
||||
// Step 7: Restore blockquote markers
|
||||
if (blockquoteMarkers.size > 0) {
|
||||
const finalLines = content.split('\n');
|
||||
const restoredLines = finalLines.map((line, index) => {
|
||||
const marker = blockquoteMarkers.get(index);
|
||||
return marker ? marker + line : line;
|
||||
});
|
||||
content = restoredLines.join('\n');
|
||||
}
|
||||
|
||||
return content;
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
/* eslint-disable no-irregular-whitespace */
|
||||
// Math Formulas Content
|
||||
export const MATH_FORMULAS_MD = String.raw`
|
||||
# Mathematical Formulas and Expressions
|
||||
@@ -150,6 +151,70 @@ $$\lim_{x \to 0} \frac{\sin x}{x} = 1$$
|
||||
|
||||
$$\lim_{n \to \infty} \left(1 + \frac{x}{n}\right)^n = e^x$$
|
||||
|
||||
## Further Bracket Styles and Amounts
|
||||
|
||||
- \( \mathrm{GL}_2(\mathbb{F}_7) \): Group of invertible matrices with entries in \(\mathbb{F}_7\).
|
||||
- Some kernel of \(\mathrm{SL}_2(\mathbb{F}_7)\):
|
||||
\[
|
||||
\left\{ \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}, \begin{pmatrix} -1 & 0 \\ 0 & -1 \end{pmatrix} \right\} = \{\pm I\}
|
||||
\]
|
||||
- Algebra:
|
||||
\[
|
||||
x = \frac{-b \pm \sqrt{\,b^{2}-4ac\,}}{2a}
|
||||
\]
|
||||
- $100 and $12.99 are amounts, not LaTeX.
|
||||
- I have $10, $3.99 and $x + y$ and $100x$. The amount is $2,000.
|
||||
- Emma buys 2 cupcakes for $3 each and 1 cookie for $1.50. How much money does she spend in total?
|
||||
- Maria has $20. She buys a notebook for $4.75 and a pack of pencils for $3.25. How much change does she receive?
|
||||
- 1 kg の質量は
|
||||
\[
|
||||
E = (1\ \text{kg}) \times (3.0 \times 10^8\ \text{m/s})^2 \approx 9.0 \times 10^{16}\ \text{J}
|
||||
\]
|
||||
というエネルギーに相当します。これは約 21 百万トンの TNT が爆発したときのエネルギーに匹敵します。
|
||||
- Algebra: \[
|
||||
x = \frac{-b \pm \sqrt{\,b^{2}-4ac\,}}{2a}
|
||||
\]
|
||||
- Algebraic topology, Homotopy Groups of $\mathbb{S}^3$:
|
||||
$$\pi_n(\mathbb{S}^3) = \begin{cases}
|
||||
\mathbb{Z} & n = 3 \\
|
||||
0 & n > 3, n \neq 4 \\
|
||||
\mathbb{Z}_2 & n = 4 \\
|
||||
\end{cases}$$
|
||||
- Spacer preceded by backslash:
|
||||
\[
|
||||
\boxed{
|
||||
\begin{aligned}
|
||||
N_{\text{att}}^{\text{(MHA)}} &=
|
||||
h \bigl[\, d_{\text{model}}\;d_{k} + d_{\text{model}}\;d_{v}\, \bigr] && (\text{Q,K,V の重み})\\
|
||||
&\quad+ h(d_{k}+d_{k}+d_{v}) && (\text{バイアス Q,K,V)}\\[4pt]
|
||||
&\quad+ (h d_{v})\, d_{\text{model}} && (\text{出力射影 }W^{O})\\
|
||||
&\quad+ d_{\text{model}} && (\text{バイアス }b^{O})
|
||||
\end{aligned}}
|
||||
\]
|
||||
|
||||
## Formulas in a Table
|
||||
|
||||
| Area | Expression | Comment |
|
||||
|------|------------|---------|
|
||||
| **Algebra** | \[
|
||||
x = \frac{-b \pm \sqrt{\,b^{2}-4ac\,}}{2a}
|
||||
\] | Quadratic formula |
|
||||
| | \[
|
||||
(a+b)^{n} = \sum_{k=0}^{n}\binom{n}{k}\,a^{\,n-k}\,b^{\,k}
|
||||
\] | Binomial theorem |
|
||||
| | \(\displaystyle \prod_{k=1}^{n}k = n! \) | Factorial definition |
|
||||
| **Geometry** | \( \mathbf{a}\cdot \mathbf{b} = \|\mathbf{a}\|\,\|\mathbf{b}\|\,\cos\theta \) | Dot product & angle |
|
||||
|
||||
## No math (but chemical)
|
||||
|
||||
Balanced chemical reaction with states:
|
||||
|
||||
\[
|
||||
\ce{2H2(g) + O2(g) -> 2H2O(l)}
|
||||
\]
|
||||
|
||||
The standard enthalpy change for the reaction is: $\Delta H^\circ = \pu{-572 kJ mol^{-1}}$.
|
||||
|
||||
---
|
||||
|
||||
*This document showcases various mathematical notation and formulas that can be rendered in markdown using LaTeX syntax.*
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
// Override KaTeX SCSS variables to disable ttf and woff fonts
|
||||
// Only use woff2 format which is embedded in the bundle
|
||||
$use-woff2: true;
|
||||
$use-woff: false;
|
||||
$use-ttf: false;
|
||||
|
||||
// Use Vite alias for font folder
|
||||
$font-folder: 'katex-fonts';
|
||||
|
||||
// Import KaTeX SCSS with overridden variables
|
||||
// Note: @import is deprecated but required because KaTeX uses @import internally
|
||||
// The deprecation warnings are from KaTeX's code and cannot be avoided
|
||||
@import 'katex/src/styles/katex.scss';
|
||||
@@ -22,6 +22,9 @@ const config = {
|
||||
}),
|
||||
output: {
|
||||
bundleStrategy: 'inline'
|
||||
},
|
||||
alias: {
|
||||
$styles: 'src/styles'
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@@ -18,6 +18,15 @@ const GUIDE_FOR_FRONTEND = `
|
||||
|
||||
const MAX_BUNDLE_SIZE = 2 * 1024 * 1024;
|
||||
|
||||
/**
|
||||
* the maximum size of an embedded asset in bytes,
|
||||
* e.g. maximum size of embedded font (see node_modules/katex/dist/fonts/*.woff2)
|
||||
*/
|
||||
const MAX_ASSET_SIZE = 32000;
|
||||
|
||||
/** public/index.html.gz minified flag */
|
||||
const ENABLE_JS_MINIFICATION = true;
|
||||
|
||||
function llamaCppBuildPlugin() {
|
||||
return {
|
||||
name: 'llamacpp:build',
|
||||
@@ -75,12 +84,28 @@ function llamaCppBuildPlugin() {
|
||||
}
|
||||
|
||||
export default defineConfig({
|
||||
build: {
|
||||
chunkSizeWarningLimit: 3072
|
||||
resolve: {
|
||||
alias: {
|
||||
'katex-fonts': resolve('node_modules/katex/dist/fonts')
|
||||
}
|
||||
},
|
||||
build: {
|
||||
assetsInlineLimit: MAX_ASSET_SIZE,
|
||||
chunkSizeWarningLimit: 3072,
|
||||
minify: ENABLE_JS_MINIFICATION
|
||||
},
|
||||
css: {
|
||||
preprocessorOptions: {
|
||||
scss: {
|
||||
additionalData: `
|
||||
$use-woff2: true;
|
||||
$use-woff: false;
|
||||
$use-ttf: false;
|
||||
`
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
plugins: [tailwindcss(), sveltekit(), devtoolsJson(), llamaCppBuildPlugin()],
|
||||
|
||||
test: {
|
||||
projects: [
|
||||
{
|
||||
|
||||
Vendored
+9
-2
@@ -192,18 +192,25 @@ class chat_template {
|
||||
};
|
||||
};
|
||||
const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
|
||||
const auto contains_arg_needle = [&](const std::string & out_str) {
|
||||
return contains(out_str, "<parameter=argument_needle>")
|
||||
|| contains(out_str, "\"argument_needle\":")
|
||||
|| contains(out_str, "'argument_needle':")
|
||||
|| contains(out_str, ">argument_needle<")
|
||||
|| contains(out_str, "<parameter name=\"argument_needle\">");
|
||||
};
|
||||
|
||||
// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
|
||||
}), {}, false);
|
||||
auto tool_call_renders_str_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
||||
auto tool_call_renders_str_arguments = contains_arg_needle(out);
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
|
||||
}), {}, false);
|
||||
auto tool_call_renders_obj_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
||||
auto tool_call_renders_obj_arguments = contains_arg_needle(out);
|
||||
|
||||
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
|
||||
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
|
||||
|
||||
Vendored
+5
-7
@@ -2205,7 +2205,7 @@ private:
|
||||
|
||||
auto value = parseValue();
|
||||
|
||||
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
|
||||
while (it != end && consumeSpaces() && peekSymbols({ "[", ".", "(" })) {
|
||||
if (!consumeToken("[").empty()) {
|
||||
std::shared_ptr<Expression> index;
|
||||
auto slice_loc = get_location();
|
||||
@@ -2250,15 +2250,13 @@ private:
|
||||
auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
|
||||
value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
|
||||
}
|
||||
} else if (peekSymbols({ "(" })) {
|
||||
auto callParams = parseCallArgs();
|
||||
value = std::make_shared<CallExpr>(get_location(), std::move(value), std::move(callParams));
|
||||
}
|
||||
consumeSpaces();
|
||||
}
|
||||
|
||||
if (peekSymbols({ "(" })) {
|
||||
auto location = get_location();
|
||||
auto callParams = parseCallArgs();
|
||||
value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
@@ -2738,7 +2736,7 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
|
||||
throw std::runtime_error(args.at("message").get<std::string>());
|
||||
}));
|
||||
globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
|
||||
globals.set("tojson", simple_function("tojson", { "value", "indent", "ensure_ascii" }, [](const std::shared_ptr<Context> &, Value & args) {
|
||||
return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* to_json= */ true));
|
||||
}));
|
||||
globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {
|
||||
|
||||
Reference in New Issue
Block a user