Compare commits

..

12 Commits

Author SHA1 Message Date
liminfei-amd 1191758c5d vulkan: fail the build when a shader fails to compile (#24450)
* vulkan-shaders-gen: fail the build when a shader fails to compile

vulkan-shaders-gen did not detect shader-compile subprocess failures, so a
broken libggml-vulkan could be produced while the build reported success and
the breakage only surfaced at run time. execute_command() discarded the child
exit code (POSIX waitpid passed nullptr for status; the Windows branch never
called GetExitCodeProcess) and string_to_spv decided success only from whether
stderr was empty, so a non-zero exit with empty stderr, or a subprocess that
failed to launch, was treated as success.

Return the child exit code from execute_command() (WEXITSTATUS on POSIX,
GetExitCodeProcess on Windows), treat a non-zero exit or non-empty stderr or a
launch exception as a failure, and record it in an atomic flag. main() checks
the flag after process_shaders() and returns EXIT_FAILURE before writing the
output files, so the build stops instead of emitting a broken backend.

Fixes #24393

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>

* vulkan-shaders-gen: simplify compile_failed access and drop unreachable return

Address review feedback on #24450:
- Access the std::atomic<bool> compile_failed directly (= / implicit bool)
  instead of .store()/.load(); the flag stays atomic because the worker
  threads in process_shaders() set it concurrently.
- Remove the unreachable trailing return -1 in execute_command(): on POSIX the
  child _exit()s after execvp and the parent returns (fork()<0 throws); on
  Windows the block returns the exit code.

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>

---------

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
2026-06-24 11:42:03 +02:00
Pascal 00139b660b ui: loading bar below the model picker (#24931)
* ui: show model load progress on the selector trigger

Mirror the in-dropdown stage progress as a thin bar on the selector
trigger, so the active model's load percent stays visible when the menu
is closed. Same status gating and composite fraction as the dropdown
row, so both bars track the selected model in sync.

Suggested-by: Julien Chaumond <@julien-c>

* ui: show model load progress bar on the in-conversation model selector

* ui: tune model load indicator to a pulsing highlight (suggested by @ngxson)

Also wire the indicator onto the mobile sheet trigger, which was missing
it since mobile uses the sheet instead of the dropdown.

* ui: thin (@allozaur) pulsating (@ngxson) model load bar
2026-06-24 10:50:44 +02:00
Aleksander Grygier ef9c13d4c2 ui: New Logo + Navigation cleanup & Mobile UI/UX improvements (#24897)
* chore: `npm audit fix --force`

* feat: Update sidebar toggle to use Logo

* refactor: Clean up favicon SVG

* feat: Refactor logo component and implement theme-aware favicon generation

* feat: Add configurable padding to generated PWA assets

* test: Add unit tests for writeThemeFavicons

* refactor: Componentization

* feat: WIP

* feat: WIP

* feat: WIP

* feat: Mobile UI

* feat: add SEARCH route constant

* feat: create SidebarNavigationSearchResults component

* refactor: use SidebarNavigationSearchResults in conversation list

* feat: enable mobile search navigation in sidebar actions

* feat: add mobile search route and page

* fix: prevent sidebar overflow on mobile viewports

* fix: Mobile sidebar

* feat: Mobile Search WIP

* feat: Mobile WIP

* feat: Add PWA standalone detection and refine mobile UI

* feat: Improve mobile layout, sidebar handling, and chat scrolling

* feat: Improve mobile sidebar visibility and iOS Safari chat spacing

* fix: Disable auto-scroll on mobile

* chore: Linting

* fix: Wrong condition

* feat: Mobile chat scroll

* refactor: WIP

* fix: Desktop initial scroll always working again

* fix: Partial fix for mobile auto-scroll / initial scroll

* fix: Desktop auto-scroll on initial load and during streaming

* fix: Mobile scrolling logic

* refactor: Clean up

* feat: Improve start UI

* feat: Add `delay` to `fadeInView`

* feat: Auto-scroll button

* refactor: Cleanup

* refactor: Extract chat dialogs and alerts into dedicated component

* refactor: Reorganize ChatScreen component structure and initialization

* feat: Improve auto-scroll after sending message

* feat: UI improvements

* fix: Settings link

* feat: UI improvements

* fix: better UI spacing

* fix: Remove unneeded logic

* fix: Chat Processing Info UI rendering

* feat: Improve mobile UI

* feat: UI improvement

* fix: Conditional transition delay for Chat Messages based on route from

* fix: Delay mobile sidebar collapse for smoother transitions

* fix: Mobile scroll down button + sidebar pointer events

* fix: Mobile UI

* fix: Auto scrolling

* fix: Implement dynamic height calculations for chat auto-scroll positioning and UI elements

* fix: Retrieve `autofocus` for Chat Form textarea

* fix: Use proper class

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* refactor: extract scroll-to-bottom logic and fix message send flow

* fix: update viewport store usage and remove conflicting autofocus

* feat: add accessibility labels to scroll down button

* fix: correct HTML structure in sidebar empty states

* fix: dynamically toggle processing info visibility

* chore: remove commented exports and fix formatting

* fix

* fix: Mobile Chat Form Add Action Sheet interactions

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-24 10:21:33 +02:00
Tarek Dakhran 88636e178f model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M (#24913)
* model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M

* Restore LFM2 models in README.md
2026-06-24 09:49:46 +03:00
Jeff Bolz ac4105d68b vulkan: Apply bias before softmax in FA, to avoid overflow (#24909) 2026-06-23 22:34:00 -05:00
kononnable be4a6a63eb server : check draft context creation error (#24922) 2026-06-23 16:56:50 +02:00
Jeff Bolz 72a9269172 vulkan: support all backend tests for SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU/NORM (#24582)
* vulkan: make SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU use unary.comp

* vulkan: make NORM support noncontig

* add noncontiguous row test cases for norm/l2_norm, handle this in the CPU backend and l2_norm.comp

* fix supports_op for cuda and webgpu
2026-06-23 09:48:24 -05:00
Jeff Bolz 92e854ab83 vulkan: Support GET_ROWS_BACK (#24883) 2026-06-23 15:39:37 +02:00
Jeff Bolz c5606364b2 vulkan: support CONV_3D (#24612)
* vulkan: support CONV_3D

This is a pretty direct port of conv2d_mm.comp to CONV_3D, done by codex
and cleaned up by me.

* disable slower perf tests
2026-06-23 15:39:20 +02:00
Jeff Bolz 0eb874d374 vulkan: make mul_mm ALIGNED a spec constant (#24689)
This trims down some of the shader variant explosion and reduces binary size.
2026-06-23 14:26:17 +02:00
Xuan-Son Nguyen 75ad0b23ed server: fix remote preset handling, add test (#24938)
* server: add test for remote preset

* fix remote preset handling

* fix

* fix test
2026-06-23 13:28:34 +02:00
Wyatt Caldwell c926ad0985 vulkan: link ggml-cpu when GGML_VULKAN_CHECK_RESULTS / RUN_TESTS are enabled (#24444)
The result-checking and test debug paths in ggml-vulkan.cpp call ggml_graph_compute_with_ctx() to compute a CPU reference graph, but that symbol is defined in ggml-cpu, which ggml-vulkan does not link. Enabling -DGGML_VULKAN_CHECK_RESULTS=ON (or -DGGML_VULKAN_RUN_TESTS=ON) therefore fails to link with an unresolved external (e.g. LNK2019 on MSVC, undefined reference on GCC/Clang). This regressed after ggml-cpu was split into its own library. Link ggml-cpu under those two options so the debug builds link again.

Signed-off-by: Wyatt Caldwell <218154709+Detensable@users.noreply.github.com>
2026-06-23 12:55:46 +02:00
127 changed files with 3432 additions and 2529 deletions
+3 -1
View File
@@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
+7 -4
View File
@@ -301,6 +301,8 @@ static handle_model_result common_params_handle_model(struct common_params_model
const common_download_opts & opts) {
handle_model_result result;
// TODO @ngxson : refactor this into a new common_model_download_context
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
} else if (!model.hf_repo.empty()) {
@@ -396,7 +398,7 @@ static bool parse_bool_value(const std::string & value) {
// CLI argument parsing functions
//
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
@@ -407,9 +409,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex,
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
opts.preset_only = handle_params.preset_only;
if (callback) {
opts.callback = callback;
if (handle_params.callback) {
opts.callback = handle_params.callback;
}
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
@@ -596,7 +599,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex);
common_params_handle_models(params, ctx_arg.ex, {});
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
+6 -1
View File
@@ -130,6 +130,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
// see: https://github.com/ggml-org/llama.cpp/issues/18163
void common_params_add_preset_options(std::vector<common_arg> & args);
struct common_params_handle_models_params {
common_download_callback * callback = nullptr;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
};
// populate model paths (main model, mmproj, etc) from -hf if necessary
// return true if the model is ready to use
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
@@ -137,7 +142,7 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
bool common_params_handle_models(
common_params & params,
llama_example curr_ex,
common_download_callback * callback = nullptr);
const common_params_handle_models_params & handle_params);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
+3 -1
View File
@@ -799,6 +799,7 @@ common_download_model_result common_download_model(const common_params_model &
bool download_mmproj = opts.download_mmproj;
bool download_mtp = opts.download_mtp;
bool preset_only = opts.preset_only;
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
@@ -806,7 +807,8 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else {
} else if (!preset_only) {
// only add other files if we're NOT in preset-only mode (normal run, non-router)
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
+1
View File
@@ -55,6 +55,7 @@ struct common_download_opts {
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
bool download_mmproj = false;
bool download_mtp = false;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
common_download_callback * callback = nullptr;
};
+1
View File
@@ -124,6 +124,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"LLaDAModelLM": "llada",
"LLaMAForCausalLM": "llama",
"Lfm25AudioTokenizer": "lfm2",
"Lfm2BidirectionalModel": "lfm2",
"Lfm2ForCausalLM": "lfm2",
"Lfm2Model": "lfm2",
"Lfm2MoeForCausalLM": "lfm2",
+10 -3
View File
@@ -64,11 +64,17 @@ class LFM2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Lfm2Model")
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
class LFM2ColBertModel(LFM2Model):
model_arch = gguf.MODEL_ARCH.LFM2
dense_tensor_name = "dense_2"
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hf_arch == "Lfm2BidirectionalModel":
self.gguf_writer.add_causal_attention(False)
self._try_set_pooling_type()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith(self.dense_tensor_name):
name = "model." + name
@@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model):
yield from super().modify_tensors(data_torch, name, bid)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# dense tensor is stored in a separate safetensors file
# optional dense tensor is stored in a separate safetensors file
from safetensors.torch import load_file
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
assert tensors_file.is_file()
if not tensors_file.is_file():
return
tensor = load_file(tensors_file)["linear.weight"]
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
yield f"{self.dense_tensor_name}.weight", tensor.clone()
+50 -23
View File
@@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, x);
float mean = sum/ne00;
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
const float * xf = (const float *) x;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
float variance = 0;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, xf);
float mean = sum/ne00;
float * yf = (float *) y;
float variance = 0;
#ifdef GGML_USE_ACCELERATE
mean = -mean;
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
vDSP_measqv(y, 1, &variance, ne00);
mean = -mean;
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
vDSP_measqv(yf, 1, &variance, ne00);
#else
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
#endif //GGML_USE_ACCELERATE
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, y, scale);
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, yf, scale);
} else {
float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += *(const float *) (x + i00*nb00);
}
const float mean = sum/ne00;
float variance = 0.0f;
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float v = *(const float *) (x + i00*nb00) - mean;
*(float *) (y + i00*nb0) = v;
variance += v * v;
}
variance /= ne00;
const float scale = 1.0f/sqrtf(variance + eps);
for (int64_t i00 = 0; i00 < ne00; i00++) {
*(float *) (y + i00*nb0) *= scale;
}
}
}
}
}
@@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
const float xi = *(const float *) (x + i00*nb00);
sum += (ggml_float)(xi * xi);
}
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
ggml_vec_scale_f32(ne00, y, scale);
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
memcpy(y, x, ne00 * sizeof(float));
ggml_vec_scale_f32(ne00, (float *) y, scale);
} else {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float xi = *(const float *) (x + i00*nb00);
*(float *) (y + i00*nb0) = xi * scale;
}
}
}
}
}
+1 -1
View File
@@ -5334,7 +5334,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return true;
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]);
break;
+5
View File
@@ -108,6 +108,9 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_CHECK_RESULTS)
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
# the result-checking path computes a CPU reference graph via
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
if (GGML_VULKAN_DEBUG)
@@ -129,6 +132,8 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_RUN_TESTS)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
+362 -66
View File
@@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state {
}
};
struct vk_conv3d_pipeline_state {
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
uint32_t aligned;
bool operator<(const vk_conv3d_pipeline_state &b) const {
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
}
};
struct vk_solve_tri_pipeline_state {
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
: N(N), K(K) {}
@@ -777,6 +791,7 @@ struct vk_device_struct {
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_back_f32;
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;
@@ -801,14 +816,10 @@ struct vk_device_struct {
vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sqrt_f32;
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_tri[2];
vk_pipeline pipeline_diag[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_clamp[2];
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
@@ -840,6 +851,10 @@ struct vk_device_struct {
vk_pipeline pipeline_gelu_quick[2];
vk_pipeline pipeline_silu[2];
vk_pipeline pipeline_relu[2];
vk_pipeline pipeline_sqr[2];
vk_pipeline pipeline_sqrt[2];
vk_pipeline pipeline_sin[2];
vk_pipeline pipeline_cos[2];
vk_pipeline pipeline_xielu[2];
vk_pipeline pipeline_neg[2];
vk_pipeline pipeline_tanh[2];
@@ -871,7 +886,7 @@ struct vk_device_struct {
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
vk_pipeline pipeline_leaky_relu_f32;
vk_pipeline pipeline_leaky_relu[2];
vk_pipeline pipeline_silu_back_f32;
vk_pipeline pipeline_diag_mask_inf_f32;
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
@@ -924,6 +939,8 @@ struct vk_device_struct {
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
@@ -1669,6 +1686,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
}
struct vk_op_conv3d_push_constants {
uint32_t OC;
uint32_t IC;
uint32_t N;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
};
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
}
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
@@ -4074,19 +4126,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#endif
auto const &ggml_vk_mul_mm_spec = [](std::vector<uint32_t> spec, bool aligned) {
spec.push_back(aligned ? 1u : 0u);
return spec;
};
const int mul_mat_id_param_count = 5;
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector<uint32_t> spec, bool aligned, bool mul_mat_id) {
if (mul_mat_id && spec.size() > 5) {
spec.insert(spec.begin() + 5, aligned ? 1u : 0u);
} else {
spec.push_back(aligned ? 1u : 0u);
}
if (mul_mat_id && spec.size() == 6) {
spec.push_back(32);
}
return spec;
};
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -4161,17 +4229,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -4284,32 +4352,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
// bf16 scalar path promotes to f32, no dot2 variant
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
@@ -4474,17 +4542,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l_int[TYPE]) \
@@ -4879,6 +4947,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -4903,7 +4972,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
@@ -5023,11 +5092,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5037,8 +5101,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5058,6 +5120,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_UNARY(gelu_quick)
CREATE_UNARY(silu)
CREATE_UNARY(relu)
CREATE_UNARY(sqr)
CREATE_UNARY(sqrt)
CREATE_UNARY(sin)
CREATE_UNARY(cos)
CREATE_UNARY(clamp)
CREATE_UNARY(leaky_relu)
CREATE_UNARY(xielu)
CREATE_UNARY(neg)
CREATE_UNARY(tanh)
@@ -5097,7 +5165,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
@@ -5314,7 +5381,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d, conv_transpose_2d
// conv2d, conv_transpose_2d, conv3d
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
@@ -5377,8 +5444,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
};
// coopmat1 needs to store the output through shared memory, so check up front
// whether it'll fit and disable it before applying coopmat1 parameters.
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
// layout. cm1 needs Csh for output, so check before applying cm1 params.
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
conv2d_use_cm1 = false;
}
@@ -5470,6 +5537,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#undef CREATE_CONV
#undef CREATE_CONVS
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
#define CREATE_CONV3D(type_suffix, spv_suffix) \
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
const vk_conv3d_pipeline_state &state = c.first; \
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
spec_constants_cpy.push_back(state.s0); \
spec_constants_cpy.push_back(state.s1); \
spec_constants_cpy.push_back(state.s2); \
spec_constants_cpy.push_back(state.p0); \
spec_constants_cpy.push_back(state.p1); \
spec_constants_cpy.push_back(state.p2); \
spec_constants_cpy.push_back(state.d0); \
spec_constants_cpy.push_back(state.d1); \
spec_constants_cpy.push_back(state.d2); \
spec_constants_cpy.push_back(state.KW); \
spec_constants_cpy.push_back(state.KH); \
spec_constants_cpy.push_back(state.KD); \
spec_constants_cpy.push_back(state.aligned); \
spec_constants_cpy.push_back(conv2d_csh_store); \
spec_constants_cpy.push_back(conv2d_WM); \
spec_constants_cpy.push_back(conv2d_WN); \
ggml_vk_create_pipeline( \
device, c.second, "conv3d" #type_suffix, \
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
}
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_CONV3D(_f32, _cm2)
CREATE_CONV3D(_f16_f32, _cm2)
} else
#endif
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (conv2d_use_cm1) {
CREATE_CONV3D(_f32, _cm1)
CREATE_CONV3D(_f16_f32, _cm1)
} else
#endif
if (conv2d_UNROLL) {
CREATE_CONV3D(_f32, _unroll)
CREATE_CONV3D(_f16_f32, _unroll)
} else {
CREATE_CONV3D(_f32, )
CREATE_CONV3D(_f16_f32, )
}
#undef CREATE_CONV3D
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -10294,6 +10408,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_get_rows_f32[src0->type];
}
return nullptr;
case GGML_OP_GET_ROWS_BACK:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_get_rows_back_f32;
}
return nullptr;
case GGML_OP_ACC:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_acc_f32;
@@ -10400,23 +10519,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_SQR:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sqr_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_SQRT:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sqrt_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_SIN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sin_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_COS:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cos_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_LOG:
@@ -10438,8 +10561,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_PAD:
@@ -10807,8 +10931,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_LEAKY_RELU:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32;
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_CONV_2D:
@@ -10885,6 +11010,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
}
return nullptr;
case GGML_OP_CONV_3D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
const uint32_t KW = (uint32_t)src0->ne[0];
const uint32_t KH = (uint32_t)src0->ne[1];
const uint32_t KD = (uint32_t)src0->ne[2];
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
const uint32_t CRS = IC * KW * KH * KD;
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
const uint32_t aligned = ((OC % BS_K == 0) &&
(CRS % BS_CRS == 0) &&
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
if (src0->type == GGML_TYPE_F32) {
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
} else {
return nullptr;
}
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto it = pipelines->find(conv3d_pipeline_state);
if (it != pipelines->end()) {
pipeline = it->second;
} else {
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
}
}
return pipeline;
}
return nullptr;
case GGML_OP_ADD1:
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add1_f16_f16;
@@ -11135,6 +11315,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break;
case GGML_OP_GET_ROWS_BACK:
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_ARGSORT:
GGML_ASSERT(0);
break;
@@ -11220,6 +11404,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
GGML_ABORT("invalid push constant type for CONV_2D");
}
break;
case GGML_OP_CONV_3D:
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
elements = { pc.OC, NPQ_blocks, 1 };
if (elements[1] > 512) {
elements[2] = CEIL_DIV(elements[1], 512);
elements[1] = 512;
}
} else {
GGML_ABORT("invalid push constant type for CONV_3D");
}
break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
@@ -11236,6 +11435,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_CLAMP:
case GGML_OP_LEAKY_RELU:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_REPEAT:
@@ -11380,6 +11580,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
});
}
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
});
}
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -12087,8 +12302,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p));
}
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -13118,6 +13335,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
}
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
vk_op_conv3d_push_constants p{};
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
p.IW = static_cast<uint32_t>(ne10);
p.IH = static_cast<uint32_t>(ne11);
p.ID = static_cast<uint32_t>(ne12);
p.OW = static_cast<uint32_t>(ne0);
p.OH = static_cast<uint32_t>(ne1);
p.OD = static_cast<uint32_t>(ne2);
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
// total input element count must fit in a uint32.
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
}
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
@@ -13144,7 +13406,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const float * op_params = (const float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p));
}
#ifdef GGML_VULKAN_RUN_TESTS
@@ -14247,6 +14512,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_GET_ROWS:
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_GET_ROWS_BACK:
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_ADD:
if (ctx->num_additional_fused_ops) {
@@ -14515,6 +14784,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_3D:
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
@@ -16964,6 +17237,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
}
case GGML_OP_GET_ROWS_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SET_ROWS:
{
switch (op->type) {
@@ -17060,12 +17335,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_TRANSPOSE:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous_rows(op->src[0]) &&
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -17084,8 +17358,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LEAKY_RELU:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->type == op->src[0]->type;
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -17285,6 +17560,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op));
}
case GGML_OP_CONV_3D:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op);
default:
return false;
}
@@ -18128,6 +18410,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t d0 = tensor->op_params[4];
const int32_t d1 = tensor->op_params[5];
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_CONV_3D) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
const int32_t s2 = tensor->op_params[2];
const int32_t p0 = tensor->op_params[3];
const int32_t p1 = tensor->op_params[4];
const int32_t p2 = tensor->op_params[5];
const int32_t d0 = tensor->op_params[6];
const int32_t d1 = tensor->op_params[7];
const int32_t d2 = tensor->op_params[8];
const int32_t IC = tensor->op_params[9];
const int32_t N = tensor->op_params[10];
const int32_t OC = tensor->op_params[11];
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}
@@ -0,0 +1,431 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#include "types.glsl"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, KD, IC*OC]
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, OD, OC*N]
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t OC;
uint32_t IC;
uint32_t N;
// Tensor spatial sizes: input, output
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
// Strides in elements
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
}
p;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint SHMEM_PAD = 4;
// Stride, padding, dilation
layout(constant_id = 6) const uint s0 = 1;
layout(constant_id = 7) const uint s1 = 1;
layout(constant_id = 8) const uint s2 = 1;
layout(constant_id = 9) const uint p0 = 0;
layout(constant_id = 10) const uint p1 = 0;
layout(constant_id = 11) const uint p2 = 0;
layout(constant_id = 12) const uint d0 = 1;
layout(constant_id = 13) const uint d1 = 1;
layout(constant_id = 14) const uint d2 = 1;
// Kernel spatial sizes
layout(constant_id = 15) const uint KW = 1;
layout(constant_id = 16) const uint KH = 1;
layout(constant_id = 17) const uint KD = 1;
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
layout(constant_id = 18) const uint aligned = 0;
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
layout(constant_id = 19) const uint csh_store = 0;
#ifdef COOPMAT
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
layout(constant_id = 20) const uint WM = 32;
layout(constant_id = 21) const uint WN = 32;
const uint TM = 16;
const uint TN = 16;
const uint TK = 16;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint warps_M = BS_K / WM;
const uint warps_N = BS_NPQ / WN;
#endif
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint splitWork(uint work_size, uint block_size) {
return (block_size + work_size - 1) / block_size;
}
uint32_t K = p.OC;
uint32_t CRS = p.IC * KD * KH * KW;
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#if defined(COOPMAT2) || defined(COOPMAT)
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
const uint32_t Csh_stride = BS_NPQ;
#ifdef COOPMAT
const uint32_t Csh_len = BS_K * Csh_stride;
#else
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
#endif
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
#endif
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
// Number of threadtiles per blocktile
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=OC
C=IC
D,R,S=KD,KH,KW
Z,P,Q=OD,OH,OW
*/
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
const uint32_t KHKW = KH * KW;
const uint32_t KDKHKW = KD * KHKW;
ic = crs_idx / KDKHKW;
uint32_t rem = crs_idx - ic * KDKHKW;
kd = rem / KHKW;
rem = rem - kd * KHKW;
kh = rem / KW;
kw = rem - kh * KW;
}
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
const uint32_t OWOH = p.OW * p.OH;
n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
uint32_t rem = npq_idx - n * p.OD * OWOH;
od = fastdiv(rem, p.OWOHmp, p.OWOHL);
rem = rem - od * OWOH;
oh = fastdiv(rem, p.OWmp, p.OWL);
ow = rem - oh * p.OW;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
if (B_idx_NPQ * BS_NPQ >= NPQ) {
return;
}
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#elif defined(COOPMAT)
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
}
const uint warp_r = gl_SubgroupID / warps_N;
const uint warp_c = gl_SubgroupID % warps_N;
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
uint32_t IC_idx_a;
uint32_t KD_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
/* Load kernel to A_block: (BS_K x BS_CRS)*/
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
if (aligned == 0) {
knl_idx = min(knl_idx, K * CRS - 1);
}
float val = knl_data[knl_idx];
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
uint32_t IC_idx_b;
uint32_t KD_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
// skip clamp when address can't go OOB
if (aligned == 0 || !dhw_in_bounds) {
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
}
float val = src_data[src_idx];
bool oob = false;
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
oob = true;
}
// also catches lower-bound underflow (idx wraps to 0x80000000+)
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
oob = true;
}
if (oob) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#elif defined(COOPMAT)
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
barrier();
}
/* Save C* */
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
#ifdef COOPMAT
const bool use_staged_store = true;
#else
const bool use_staged_store = (csh_store != 0);
#endif
if (use_staged_store) {
#ifdef COOPMAT
// cm1: each subgroup stores its fragment grid into its Csh slot
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
}
#else
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
#endif
barrier();
// cooperative shmem->global: WG threads spread across BS_NPQ (the
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
const uint32_t store_iters = BS_K / store_rows_per_iter;
const uint32_t k_thread_offset = tid / BS_NPQ;
const uint32_t npq_thread = tid % BS_NPQ;
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
uint32_t K_idx = B_idx_K * BS_K + k_local;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
}
}
}
#ifdef COOPMAT2
else {
coopMatPerElementNV(matC, matC, perElemOpStore);
}
#endif
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
}
}
}
}
#endif
}
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
}
@@ -463,6 +463,7 @@ void main() {
}
rowmaxf = max(rowmaxf, float(Sf[r][c]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// M = max(rowmax, Mold)
@@ -352,6 +352,7 @@ void main() {
}
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// Compute max across the row
@@ -0,0 +1,25 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint col = gl_GlobalInvocationID.x;
if (col >= p.ne20) {
return;
}
for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
float sum = 0.0f;
for (uint i = 0; i < p.ne10; ++i) {
if (data_b[get_boffset() + i*p.nb10] == int(row)) {
sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00];
}
}
data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum;
}
}
@@ -14,16 +14,13 @@ void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i3 = row / (p.ne11 * p.ne12);
const uint i3_offset = i3 * p.ne12 * p.ne11;
const uint i2 = (row - i3_offset) / p.ne11;
const uint i2_offset = i2 * p.ne11;
const uint i1 = row - i3_offset - i2_offset;
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]);
sum[tid] += xi * xi;
}
@@ -39,6 +36,6 @@ void main() {
const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00]));
}
}
@@ -1,22 +0,0 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float val = float(data_a[i]);
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
}
+31 -23
View File
@@ -38,17 +38,7 @@
#define LOAD_VEC_B 1
#endif
// Load 2 values at once without affecting index calculations through LOAD_VEC
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
#define LOAD_VEC_BATCH_A 2
#else
#define LOAD_VEC_BATCH_A 1
#endif
#if !defined(ALIGNED)
#define LOAD_VEC_BATCH_B 2
#else
#define LOAD_VEC_BATCH_B 1
#endif
layout (constant_id = 11) const uint ALIGNED = 0;
#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
@@ -57,6 +47,13 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(DATA_A_F32)
layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];};
#elif defined(DATA_A_F16)
layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];};
#elif defined(DATA_A_BF16)
layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];};
#endif
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
@@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
@@ -194,13 +192,23 @@ void main() {
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1;
const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2;
#else
const uint LOAD_VEC_A_EFF = LOAD_VEC_A;
const uint LOAD_VEC_BATCH_A = 1;
#endif
const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1;
const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2;
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK;
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
@@ -239,15 +247,15 @@ void main() {
uint pos_a =
#ifdef MUL_MAT_ID
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) +
#else
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) +
#endif
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF;
#ifdef MUL_MAT_ID
uint pos_b = 0;
#else
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF;
#endif
#ifdef COOPMAT
@@ -287,8 +295,8 @@ void main() {
barrier();
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
pos_a += BK / LOAD_VEC_A_EFF;
pos_b += BK / LOAD_VEC_B_EFF;
#ifdef COOPMAT
[[unroll]] for (uint i = 0; i < BK; i += TK) {
@@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
layout (constant_id = 5) const uint ALIGNED = 0;
layout (push_constant) uniform parameter
{
@@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
};
uint _ne1;
layout (constant_id = 5) const uint subgroup_size = 32;
layout (constant_id = 6) const uint subgroup_size = 32;
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
@@ -297,12 +298,12 @@ void main() {
// Hint to the compiler that values are aligned (want 16B alignment).
// Quants are always block-aligned, no alignment needed.
#if ALIGNED
if (ALIGNED != 0) {
#if QUANT_K == 1
stride_a &= ~7;
#endif
stride_b &= ~7;
stride_a &= ~7;
#endif
stride_b &= ~7;
}
// Create layouts for both clamped and unclamped accesses
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
@@ -1,50 +1,57 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
return;
}
#elif LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
data_a[idx + 1]);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx],
data_a_scalar[idx + 1]);
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
TO_FLOAT_TYPE(data_a[idx + 1]));
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]),
TO_FLOAT_TYPE(data_a_scalar[idx + 1]));
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#if !defined(MUL_MAT_ID)
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
#elif LOAD_VEC_B == 4
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
if (ALIGNED != 0) {
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint idx = pos_b + col * p.stride_b + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
} else if (idx_n < p.N && block + row * 2 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#else
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
#elif LOAD_VEC_B == 4
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
if (ALIGNED != 0) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint row_i = ic * BN + col;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
} else if (row_i < _ne1 && block + row * 2 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#endif
+10 -10
View File
@@ -1,26 +1,26 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared vec2 sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
sum[tid] = vec2(0.0f, 0.0f);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const float xi = float(data_a[row*p.KX + col]);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const float xi = float(data_a[a_base + i0*p.nb00]);
sum[tid].x += xi;
sum[tid].y += xi * xi;
}
@@ -34,11 +34,11 @@ void main() {
barrier();
}
const float mean = sum[0].x / p.KX;
const float var = sum[0].y / p.KX - mean * mean;
const float mean = sum[0].x / p.ne00;
const float var = sum[0].y / p.ne00 - mean * mean;
const float inv_std = inversesqrt(var + p.param1);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std);
}
}
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
}
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
}
@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
}
@@ -17,6 +17,30 @@ float op_neg(float x) {
return -x;
}
float op_sqr(float x) {
return x * x;
}
float op_sqrt(float x) {
return sqrt(x);
}
float op_sin(float x) {
return sin(x);
}
float op_cos(float x) {
return cos(x);
}
float op_clamp(float x) {
return clamp(x, p.param1, p.param2);
}
float op_leaky_relu(float x) {
return max(x, 0.0f) + min(x, 0.0f) * p.param1;
}
float op_step(float x) {
return x >= 0.0f ? 1.0f : 0.0f;
}
@@ -11,6 +11,7 @@
#include <future>
#include <queue>
#include <condition_variable>
#include <atomic>
#include <cstdio>
#include <cstring>
#include <cstdlib>
@@ -34,6 +35,9 @@
std::mutex lock;
std::vector<std::pair<std::string, std::string>> shader_fnames;
// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the
// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393)
static std::atomic<bool> compile_failed{false};
std::locale c_locale("C");
std::string GLSLC = "glslc";
@@ -78,7 +82,7 @@ enum MatMulIdType {
namespace {
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
int execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
HANDLE stderr_read, stderr_write;
@@ -127,8 +131,11 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
CloseHandle(stdout_read);
CloseHandle(stderr_read);
WaitForSingleObject(pi.hProcess, INFINITE);
DWORD exit_code = 1;
GetExitCodeProcess(pi.hProcess, &exit_code);
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);
return (int)exit_code;
#else
int stdout_pipe[2];
int stderr_pipe[2];
@@ -175,7 +182,9 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
close(stdout_pipe[0]);
close(stderr_pipe[0]);
waitpid(pid, nullptr, 0);
int status = 0;
waitpid(pid, &status, 0);
return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
}
#endif
}
@@ -372,13 +381,14 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// }
// std::cout << std::endl;
execute_command(cmd, stdout_str, stderr_str);
if (!stderr_str.empty()) {
std::cerr << "cannot compile " << name << "\n\n";
int exit_code = execute_command(cmd, stdout_str, stderr_str);
if (exit_code != 0 || !stderr_str.empty()) {
std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n";
for (const auto& part : cmd) {
std::cerr << part << " ";
}
std::cerr << "\n\n" << stderr_str << std::endl;
compile_failed = true;
return;
}
@@ -398,6 +408,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
shader_fnames.push_back(std::make_pair(name, out_path));
} catch (const std::exception& e) {
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
compile_failed = true;
}
}
@@ -539,11 +550,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};
// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
@@ -565,8 +574,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
#endif
{
if (!dot2) {
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
}
@@ -583,8 +591,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// For unaligned, load one at a time for f32/f16, or two at a time for quants
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
// For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
@@ -597,13 +603,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@@ -850,21 +854,12 @@ void process_shaders() {
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}});
@@ -891,6 +886,18 @@ void process_shaders() {
string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}});
string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}});
string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}});
string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}});
string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}});
string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}});
string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqrt"}});
string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}});
string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}});
string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}});
string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}});
string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}});
string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}});
string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}});
string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}});
string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}});
string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}});
string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}});
@@ -948,7 +955,6 @@ void process_shaders() {
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -1060,6 +1066,31 @@ void process_shaders() {
}
}
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {
std::map<std::string, std::string> defines = {
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
{"UNROLL", unroll ? "[[unroll]]" : ""},
};
std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32";
string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines);
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (unroll) {
auto cm2_defines = defines;
cm2_defines["COOPMAT2"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (unroll) {
auto cm1_defines = defines;
cm1_defines["COOPMAT"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false);
}
#endif
}
}
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
@@ -1251,6 +1282,11 @@ int main(int argc, char** argv) {
process_shaders();
if (compile_failed) {
std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl;
return EXIT_FAILURE;
}
write_output_files();
return EXIT_SUCCESS;
+1 -1
View File
@@ -4270,7 +4270,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_RMS_NORM:
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0);
break;
case GGML_OP_ROPE:
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
+14 -4
View File
@@ -190,7 +190,15 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
bx = ggml_concat(ctx0, conv, bx, 0);
// causal prepends the state, non-causal pads symmetrically for a centered window
if (hparams.causal_attn) {
bx = ggml_concat(ctx0, conv, bx, 0);
} else {
const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2;
auto * left = ggml_cont(ctx0,
ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0]));
bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0);
}
GGML_ASSERT(bx->ne[0] > conv->ne[0]);
// last d_conv columns is a new conv state
@@ -266,10 +274,12 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
if (!cparams.embeddings) {
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;
res->t_logits = cur;
}
ggml_build_forward_expand(gf, cur);
}
+54 -8
View File
@@ -3298,21 +3298,29 @@ struct test_norm : public test_case {
const std::array<int64_t, 4> ne;
const bool v; // whether a is a non-contiguous view
const float eps;
const bool noncontig_rows;
std::string vars() override {
return VARS_TO_STR4(type, ne, v, eps);
return VARS_TO_STR5(type, ne, v, eps, noncontig_rows);
}
test_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
bool v = false,
float eps = 1e-6f)
: type(type), ne(ne), v(v), eps(eps) {}
float eps = 1e-6f,
bool noncontig_rows = false)
: type(type), ne(ne), v(v), eps(eps), noncontig_rows(noncontig_rows) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
const std::array<int64_t, 4> ne_a = noncontig_rows ?
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_name(a, "a");
if (noncontig_rows) {
a = ggml_permute(ctx, a, 1, 0, 2, 3);
ggml_set_name(a, "permuted a");
}
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
@@ -6193,21 +6201,29 @@ struct test_l2_norm : public test_case {
const std::array<int64_t, 4> ne;
const float eps;
bool v;
bool noncontig_rows;
std::string vars() override {
return VARS_TO_STR4(type, ne, eps, v);
return VARS_TO_STR5(type, ne, eps, v, noncontig_rows);
}
test_l2_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 64, 320, 1},
float eps = 1e-12f,
bool v = false)
: type(type), ne(ne), eps(eps), v(v) {}
bool v = false,
bool noncontig_rows = false)
: type(type), ne(ne), eps(eps), v(v), noncontig_rows(noncontig_rows) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
const std::array<int64_t, 4> ne_a = noncontig_rows ?
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_name(a, "a");
if (noncontig_rows) {
a = ggml_permute(ctx, a, 1, 0, 2, 3);
ggml_set_name(a, "permuted a");
}
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
@@ -8282,9 +8298,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
}
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, false, eps, true));
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false, true));
}
}
@@ -9272,6 +9290,34 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
struct conv3d_perf_case {
int N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1, s2, p0, p1, p2, d0, d1, d2;
};
const std::vector<conv3d_perf_case> conv3d_cases = {
{1, 320, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 1280, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 1280, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
#if 0
// too slow on some devices
{1, 1280, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 320, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
{1, 640, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
#endif
};
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (const conv3d_perf_case & c : conv3d_cases) {
test_cases.emplace_back(new test_conv_3d(
c.N, c.IC, c.ID, c.IH, c.IW,
c.OC, c.KD, c.KH, c.KW,
c.s0, c.s1, c.s2, c.p0, c.p1, c.p2, c.d0, c.d1, c.d2,
kernel_type));
}
}
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
+7 -1
View File
@@ -89,7 +89,9 @@ struct server_batch {
}
~server_batch() {
llama_batch_free(batch);
if (batch.token != nullptr) {
llama_batch_free(batch);
}
}
void init(int32_t n_tokens_alloc) {
@@ -1215,6 +1217,10 @@ private:
cparams.ctx_other = ctx_tgt;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
if (ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return false;
}
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
+4 -2
View File
@@ -224,7 +224,7 @@ void server_model_meta::update_caps() {
});
params.offline = true;
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
if (params.mmproj.path.empty()) {
multimodal = { false, false };
} else {
@@ -1393,7 +1393,9 @@ struct server_download_state : public common_download_callback {
bool run(common_params & params) {
try {
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
common_params_handle_models_params p;
p.callback = this;
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p);
is_ok = true;
} catch (const std::exception & e) {
auto model_name = params.model.get_name();
+12 -1
View File
@@ -89,6 +89,17 @@ int llama_server(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);
// note: router mode also accepts -hf remote-preset, so we need to check that first
if (!params.model.hf_repo.empty()) {
try {
common_params_handle_models_params handle_params;
handle_params.preset_only = true;
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params);
} catch (const std::exception & e) {
// ignored for now
}
}
// router server never loads a model and must not touch the GPU
const bool is_router_server = params.model.path.empty()
&& params.model.hf_repo.empty();
@@ -263,7 +274,7 @@ int llama_server(int argc, char ** argv) {
return child.run_download(params);
} else if (!is_router_server) {
// single-model mode (NOT spawned by router)
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
}
//
+19
View File
@@ -256,6 +256,25 @@ def test_router_reload_models():
os.remove(preset_path)
def test_router_remote_preset():
global server
server.model_hf_repo = "ggml-org/test-preset-ci"
server.model_hf_file = None
server.offline = False
server.start()
# Should see preset models in GET /models
res = server.make_request("GET", "/models")
assert res.status_code == 200
ids = {item["id"] for item in res.body.get("data", [])}
assert "tinygemma3-preset" in ids
assert "stories260K-test" in ids
# Should be able to load a preset model
model_id = "tinygemma3-preset"
_load_model_and_wait(model_id)
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
MODEL_DOWNLOAD_TIMEOUT = 30
+1 -2
View File
@@ -28,10 +28,9 @@ vite.config.ts.timestamp-*
# PWA Artifacts
apple-splash-*.png
apple-touch-icon-*.png
favicon.ico
favicon-dark.ico
maskable-icon-*.png
pwa-*.png
static/favicon*
# Storybook
*storybook.log
+7 -7
View File
@@ -35,7 +35,7 @@
"bits-ui": "2.18.1",
"clsx": "2.1.1",
"dexie": "4.4.3",
"dompurify": "3.4.5",
"dompurify": "3.4.11",
"eslint": "9.39.4",
"eslint-config-prettier": "10.1.8",
"eslint-plugin-storybook": "10.4.2",
@@ -8653,9 +8653,9 @@
"peer": true
},
"node_modules/dompurify": {
"version": "3.4.5",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.5.tgz",
"integrity": "sha512-OrwIBKsdNSVEeubdJ1HBv/wNENRM9ytAVCv7YXt//A3vPdVMNuACRqK9mXCGCBW2ln7BT/A4X0jXHo2Gu89miA==",
"version": "3.4.11",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.11.tgz",
"integrity": "sha512-zhlUV12GsaRzMsf9q5M254YhA4+VuF0fG+QFqu6aYpoGlKtz+w8//jBcGVYBgQkR5GHjUomejY84AV+/uPbWdw==",
"dev": true,
"license": "(MPL-2.0 OR Apache-2.0)",
"optionalDependencies": {
@@ -10226,9 +10226,9 @@
}
},
"node_modules/hono": {
"version": "4.12.23",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.23.tgz",
"integrity": "sha512-eIaZ9qDgu7XV0pxOCrg7/WhnQ6Ivm22UcxhXx/A3dcbqbbYgBEkc6e/J/s7j2tS96zoB0S9VBdLwQNCWwUo4LA==",
"version": "4.12.26",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.26.tgz",
"integrity": "sha512-uyZtpnYxM9CmQ7QsQknM4zN8EftNqhON1qYeIKM0Se67CCEe2c44xyGURwB0axX2fBDu1dqHrHAc1hmNT8ITkw==",
"dev": true,
"license": "MIT",
"engines": {
+1 -1
View File
@@ -54,7 +54,7 @@
"bits-ui": "2.18.1",
"clsx": "2.1.1",
"dexie": "4.4.3",
"dompurify": "3.4.5",
"dompurify": "3.4.11",
"eslint": "9.39.4",
"eslint-config-prettier": "10.1.8",
"eslint-plugin-storybook": "10.4.2",
+8 -1
View File
@@ -1,4 +1,10 @@
import { defineConfig } from '@vite-pwa/assets-generator/config';
import { FAVICON_COLORS, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
import { writeThemeFavicons } from './scripts/favicon-colorize';
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
});
export default defineConfig({
headLinkOptions: {
@@ -7,7 +13,8 @@ export default defineConfig({
preset: {
transparent: {
sizes: [],
favicons: [[48, 'favicon-dark.ico']]
favicons: [[48, 'favicon-dark.ico']],
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
},
maskable: {
sizes: []
+19 -2
View File
@@ -5,15 +5,32 @@ import {
} from '@vite-pwa/assets-generator/config';
import { readFileSync } from 'node:fs';
import { resolve } from 'node:path';
import { THEME_COLORS, PWA_GENERATOR_DEVICES, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
import {
THEME_COLORS,
PWA_GENERATOR_DEVICES,
PWA_ASSET_GENERATOR,
FAVICON_COLORS
} from './src/lib/constants/pwa';
import { SplashOrientation } from './src/lib/enums/splash.enums';
import { writeThemeFavicons } from './scripts/favicon-colorize';
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
});
export default defineConfig({
headLinkOptions: {
preset: PWA_ASSET_GENERATOR.LINK_PRESET
},
preset: combinePresetAndAppleSplashScreens(
minimal2023Preset,
{
...minimal2023Preset,
// tiny margin so favicon.ico / pwa-*.png breathe inside the canvas
transparent: {
...minimal2023Preset.transparent,
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
}
},
{
padding: PWA_ASSET_GENERATOR.SPLASH_PADDING,
resizeOptions: {
+107
View File
@@ -0,0 +1,107 @@
import { mkdirSync, readFileSync, writeFileSync } from 'node:fs';
import { dirname, resolve } from 'node:path';
import { fileURLToPath } from 'node:url';
const HERE = dirname(fileURLToPath(import.meta.url));
const PROJECT_ROOT = resolve(HERE, '..');
const DEFAULT_LOGO = resolve(PROJECT_ROOT, 'src/lib/assets/logo.svg');
const DEFAULT_OUT_DIR = resolve(PROJECT_ROOT, 'static');
const DEFAULT_OUT_LIGHT = resolve(DEFAULT_OUT_DIR, 'favicon.svg');
const DEFAULT_OUT_DARK = resolve(DEFAULT_OUT_DIR, 'favicon-dark.svg');
const CURRENT_COLOR = 'currentColor';
export interface ColorizedFavicon {
light: string;
dark: string;
}
export interface WriteThemeFaviconsOptions {
sourcePath?: string;
lightOutPath?: string;
darkOutPath?: string;
/**
* Fraction of the icon (0..1) to leave as an even margin on each side.
* Applied by wrapping the inner content in a `<g transform="...">` so the
* source `src/lib/assets/logo.svg` is not modified. Pass 0 to disable.
*/
padding?: number;
}
/**
* Replace every `currentColor` occurrence in the SVG with the given color.
* Pure: no filesystem access, so it is straightforward to unit-test.
*/
export function colorizeFaviconSvg(
svg: string,
lightColor: string,
darkColor: string
): ColorizedFavicon {
return {
light: svg.replaceAll(CURRENT_COLOR, lightColor),
dark: svg.replaceAll(CURRENT_COLOR, darkColor)
};
}
/**
* Shrink the inner SVG content uniformly and re-center it so `padding` (a
* 0..1 fraction) is reserved as equal margin on each side. Returns the input
* unchanged for non-positive padding, missing/invalid `viewBox`, or unexpected
* markup so the caller always gets a renderable SVG.
*/
export function padFaviconSvg(svg: string, padding: number): string {
if (!(padding > 0) || padding >= 1) return svg;
const viewBoxMatch = svg.match(/viewBox\s*=\s*["']([^"']+)["']/i);
if (!viewBoxMatch) return svg;
const parts = viewBoxMatch[1]
.trim()
.split(/[\s,]+/)
.map(Number);
if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) return svg;
const [, , width, height] = parts;
if (width <= 0 || height <= 0) return svg;
const scale = 1 - padding;
const translateX = (padding * width) / 2;
const translateY = (padding * height) / 2;
const openTagStart = svg.search(/<svg\b/i);
if (openTagStart === -1) return svg;
const openTagEnd = svg.indexOf('>', openTagStart);
if (openTagEnd === -1) return svg;
const closeStart = svg.lastIndexOf('</svg');
if (closeStart === -1 || closeStart <= openTagEnd) return svg;
const openTag = svg.slice(0, openTagEnd + 1);
const inner = svg.slice(openTagEnd + 1, closeStart);
const closeTag = svg.slice(closeStart);
const group = `<g transform="translate(${translateX} ${translateY}) scale(${scale})">`;
return `${openTag}${group}${inner}</g>${closeTag}`;
}
/**
* Read `src/lib/assets/logo.svg`, colorize it for both themes, and write
* the results to the static directory so the PWA asset generator can consume
* them. Paths can be overridden for tests.
*/
export function writeThemeFavicons(
lightColor: string,
darkColor: string,
{
sourcePath = DEFAULT_LOGO,
lightOutPath = DEFAULT_OUT_LIGHT,
darkOutPath = DEFAULT_OUT_DARK,
padding = 0
}: WriteThemeFaviconsOptions = {}
): void {
const source = readFileSync(sourcePath, 'utf-8');
const { light, dark } = colorizeFaviconSvg(source, lightColor, darkColor);
mkdirSync(dirname(lightOutPath), { recursive: true });
writeFileSync(lightOutPath, padFaviconSvg(light, padding));
writeFileSync(darkOutPath, padFaviconSvg(dark, padding));
}
+6 -1
View File
@@ -48,6 +48,7 @@
--chat-form-area-height: 8rem;
--chat-form-area-offset: 2rem;
--chat-form-padding-top: 6rem;
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
}
@@ -55,6 +56,7 @@
:root {
--chat-form-area-height: 24rem;
--chat-form-area-offset: 12rem;
--chat-form-padding-top: 6rem;
}
}
@@ -141,7 +143,6 @@
@apply bg-background text-foreground;
scrollbar-width: thin;
scrollbar-gutter: stable;
overflow: hidden; /* Added due to Mermaid rendering somehow causing the double scrollbar */
}
/* Global scrollbar styling - visible only on hover */
@@ -193,3 +194,7 @@
scrollbar-width: none;
}
}
.mermaidTooltip {
display: none !important;
}
@@ -10,9 +10,9 @@ import { isElementInViewport } from '$lib/utils/viewport';
*/
export function fadeInView(
node: HTMLElement,
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
options: { duration?: number; y?: number; delay?: number; skipIfVisible?: boolean } = {}
) {
const { duration = 300, y = 0, skipIfVisible = false } = options;
const { duration = 300, y = 0, delay = 0, skipIfVisible = false } = options;
if (skipIfVisible && isElementInViewport(node)) {
return;
@@ -27,10 +27,12 @@ export function fadeInView(
(entries) => {
for (const entry of entries) {
if (entry.isIntersecting) {
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
setTimeout(() => {
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
}, delay);
observer.disconnect();
}
}
+7
View File
@@ -0,0 +1,7 @@
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M244.95 8C215.233 8 187.774 23.8591 172.923 49.5999L95.6009 183.625C60.2162 244.959 104.481 321.6 175.29 321.6H208L316.977 132.708C348.959 77.2719 308.95 8 244.95 8ZM208 321.6H351.947C415.982 321.6 456.013 390.91 424.013 446.377C409.155 472.132 381.681 488 351.947 488H271.29C200.481 488 156.216 411.359 191.601 350.026L208 321.6Z" fill="currentColor"/>
<path d="M208 321.6H16L106.462 164.8L208 321.6Z" fill="currentColor"/>
<path d="M388.923 8L208 321.6L253.6 8H388.923Z" fill="currentColor"/>
<path d="M304 488H112L202.462 331.2L304 488Z" fill="currentColor"/>
<path d="M496 321.6H208L419.399 454.4L496 321.6Z" fill="currentColor"/>
</svg>

After

Width:  |  Height:  |  Size: 771 B

@@ -8,12 +8,13 @@
ariaLabel?: string;
class?: string;
disabled?: boolean;
href?: string;
icon: Component;
iconSize?: string;
onclick: (e?: MouseEvent) => void;
onclick?: (e?: MouseEvent) => void;
size?: ButtonSize;
stopPropagationOnClick?: boolean;
tooltip: string;
tooltip?: string;
variant?: ButtonVariant;
tooltipSide?: TooltipSide;
}
@@ -22,6 +23,7 @@
icon,
tooltip,
variant = 'ghost',
href = '',
size = 'sm',
class: className = '',
disabled = false,
@@ -31,34 +33,49 @@
onclick,
ariaLabel
}: Props = $props();
let innerWidth = $state(0);
const showTooltip = $derived(!!tooltip && innerWidth > 768);
</script>
<Tooltip.Root>
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
<Button
{...props}
{variant}
{size}
{disabled}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
{#snippet button(props = {})}
<Button
{...props}
{href}
{variant}
{size}
{disabled}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{#if icon}
{@const IconComponent = icon}
<IconComponent class={iconSize} />
{/if}
</Button>
{/snippet}
</Tooltip.Trigger>
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{#if icon}
{@const IconComponent = icon}
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
<IconComponent class={iconSize} />
{/if}
</Button>
{/snippet}
{#if showTooltip}
<Tooltip.Root>
<Tooltip.Trigger>
<!-- prevent another nested button element -->
{#snippet child({ props })}
{@render button(props)}
{/snippet}
</Tooltip.Trigger>
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
{:else}
{@render button({ href })}
{/if}
<svelte:window bind:innerWidth />
@@ -494,7 +494,7 @@
/>
<div
class="{INPUT_CLASSES} overflow-hidden rounded-3xl backdrop-blur-md {disabled
class="{INPUT_CLASSES} overflow-hidden rounded-4xl md:rounded-3xl backdrop-blur-md {disabled
? 'cursor-not-allowed opacity-60'
: ''}"
data-slot="input-area"
@@ -510,7 +510,7 @@
/>
<div
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
class="flex-column relative min-h-12 items-center rounded-4xl md:rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:py-3!"
onpaste={handlePaste}
>
<ChatFormTextarea
@@ -15,7 +15,7 @@
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
class="file-upload-button md:h-8 md:w-8 h-9 w-9 rounded-full p-0"
{disabled}
{onclick}
variant="secondary"
@@ -15,6 +15,7 @@
import { McpLogo } from '$lib/components/app';
import { PencilRuler, ChevronDown, ChevronRight } from '@lucide/svelte';
import { HealthCheckStatus } from '$lib/enums';
import { AttachmentAction } from '$lib/enums/attachment.enums';
interface Props {
class?: string;
@@ -270,14 +271,22 @@
</Collapsible.Root>
{/if}
<button type="button" class={sheetItemClass} onclick={onSystemPromptClick}>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.SYSTEM_PROMPT_CLICK]()}
>
<MessageSquare class="h-4 w-4 shrink-0" />
<span>System Message</span>
</button>
{#if hasMcpPromptsSupport}
<button type="button" class={sheetItemClass} onclick={onMcpPromptClick}>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_PROMPT_CLICK]()}
>
<Zap class="h-4 w-4 shrink-0" />
<span>MCP Prompt</span>
@@ -285,7 +294,11 @@
{/if}
{#if hasMcpResourcesSupport}
<button type="button" class={sheetItemClass} onclick={onMcpResourcesClick}>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_RESOURCES_CLICK]()}
>
<FolderOpen class="h-4 w-4 shrink-0" />
<span>MCP Resources</span>
@@ -42,6 +42,7 @@
{hasMcpPromptsSupport}
{hasMcpResourcesSupport}
{onFileUpload}
{onSystemPromptClick}
{onMcpPromptClick}
{onMcpResourcesClick}
>
@@ -20,7 +20,7 @@
type="submit"
disabled={isDisabled}
class={[
'h-8 w-8 rounded-full p-0',
'md:h-8 md:w-8 h-9 w-9 rounded-full p-0',
showErrorState &&
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
]}
@@ -1,4 +1,5 @@
<script lang="ts">
import { isMobile } from '$lib/stores/viewport.svelte';
import { autoResizeTextarea } from '$lib/utils';
import { onMount } from 'svelte';
@@ -37,7 +38,9 @@
}
export function focus() {
textareaElement?.focus();
if (isMobile.current) return;
textareaElement?.focus({ preventScroll: true });
}
export function resetHeight() {
@@ -231,7 +231,7 @@
editedContent = message.content;
}
textareaElement?.focus();
textareaElement?.focus({ preventScroll: true });
editedExtras = message.extra ? [...message.extra] : [];
editedUploadedFiles = [];
@@ -324,7 +324,7 @@
}
</script>
<div use:fadeInView>
<div use:fadeInView class="chat-message">
{#if message.role === MessageRole.SYSTEM}
<ChatMessageSystem
bind:textareaElement
@@ -180,6 +180,9 @@
let displayedModel = $derived(message.model ?? null);
// model being switched to while it loads, so the selector bar tracks it
let pendingModel = $state<string | null>(null);
let isCurrentlyLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
let hasNoContent = $derived(!message?.content?.trim());
@@ -207,6 +210,42 @@
isLastAssistantMessage
);
let assistantEl: HTMLDivElement | undefined = $state();
let lastUserMessageHeight = $state(0);
let assistantMarginTop = $state(0);
$effect(() => {
if (!assistantEl) return;
assistantMarginTop = Math.round(parseFloat(getComputedStyle(assistantEl).marginTop));
const chatMessageEl = assistantEl.closest('.chat-message');
const previousChatMessage = chatMessageEl?.previousElementSibling;
const userMessageEl = previousChatMessage?.querySelector(
'.chat-message-user'
) as HTMLElement | null;
if (!userMessageEl) {
lastUserMessageHeight = 0;
return;
}
const updateHeight = () => {
const rect = userMessageEl.getBoundingClientRect();
const marginTop = Math.round(parseFloat(getComputedStyle(userMessageEl).marginTop));
lastUserMessageHeight = Math.round(rect.height + marginTop);
};
updateHeight();
const resizeObserver = new ResizeObserver(updateHeight);
resizeObserver.observe(userMessageEl);
return () => {
resizeObserver.disconnect();
};
});
function handleCopyModel() {
void copyToClipboard(displayedModel ?? '');
}
@@ -219,12 +258,17 @@
</script>
<div
class="text-md group w-full leading-7.5 {className}"
bind:this={assistantEl}
class="chat-message-assistant text-md group w-full leading-7.5 {className}"
style:--last-user-message-height={lastUserMessageHeight > 0
? `${lastUserMessageHeight}px`
: undefined}
style:--assistant-margin-top={assistantMarginTop > 0 ? `${assistantMarginTop}px` : undefined}
role="group"
aria-label="Assistant message with actions"
>
{#if showProcessingInfoTop}
<div class="mt-6 w-full max-w-[48rem]" in:fade>
<div class="mt-6 w-full max-w-3xl" in:fade>
<div class="processing-container">
<span class="processing-text">
{modelLoadingText ??
@@ -257,7 +301,7 @@
{/if}
{#if showProcessingInfoBottom}
<div class="mt-4 w-full max-w-[48rem]" in:fade>
<div class="mt-4 w-full max-w-3xl" in:fade>
<div class="processing-container">
<span class="processing-text">
{modelLoadingText ??
@@ -277,13 +321,19 @@
>
{#if isRouter}
<ModelsSelectorDropdown
currentModel={displayedModel}
currentModel={pendingModel ?? displayedModel}
disabled={isLoading()}
onModelChange={async (modelId: string, modelName: string) => {
const status = modelsStore.getModelStatus(modelId);
if (status !== ServerModelStatus.LOADED) {
await modelsStore.loadModel(modelId);
pendingModel = modelId;
try {
await modelsStore.loadModel(modelId);
} finally {
pendingModel = null;
}
}
onRegenerate(modelName);
@@ -351,6 +401,23 @@
</div>
<style>
:global(.chat-message):last-child .chat-message-assistant {
--assistant-min-height-offset: calc(
var(--last-user-message-height, 19rem) + var(--chat-form-height, 6rem) +
var(--chat-form-bottom-position, 0.5rem) + var(--chat-form-padding-top, 6rem) +
var(--assistant-margin-top, 3rem)
);
min-height: calc(100dvh - var(--assistant-min-height-offset));
@media (width > 768px) {
--assistant-min-height-offset: calc(
var(--last-user-message-height, 18rem) + var(--chat-form-height, 6rem) +
var(--chat-form-bottom-position, 1rem) + var(--chat-form-padding-top, 6rem) +
var(--assistant-margin-top, 3rem)
);
}
}
.processing-container {
display: flex;
flex-direction: column;
@@ -48,7 +48,7 @@
<div
aria-label="User message with actions"
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
class="chat-message-user group flex flex-col items-end gap-3 md:gap-2 {className}"
role="group"
>
{#if editCtx.isEditing}
@@ -19,7 +19,7 @@
renderMarkdown = false,
textColorClass = 'text-foreground',
cardBgClass = 'dark:bg-primary/15',
maxHeightStyle = 'max-height: var(--max-message-height);'
maxHeightStyle = ''
}: Props = $props();
let isMultiline = $state(false);
@@ -59,7 +59,7 @@
{#if content.trim()}
<Card
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-[multiline]:py-2.5 {cardBgClass}"
class="chat-message-user-bubble max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-multiline:py-2.5 {cardBgClass}"
data-multiline={isMultiline ? '' : undefined}
style="{maxHeightStyle} overflow-wrap: anywhere; word-break: break-word;"
>
@@ -37,6 +37,7 @@
let allConversationMessages = $state<DatabaseMessage[]>([]);
let isVisible = $state(false);
let previousConversationId = $state<string | null>(null);
let previousRouteId = $state<string | null>(null);
const currentConfig = config();
@@ -157,8 +158,9 @@
});
});
beforeNavigate(() => {
beforeNavigate((navigation) => {
isVisible = false;
previousRouteId = navigation.from?.route.id ?? null;
});
afterNavigate(() => {
@@ -249,12 +251,13 @@
</script>
<div
class="transition-opacity delay-300 duration-500 ease-out
{isVisible ? 'opacity-100' : 'opacity-0'}"
class="transition-opacity duration-500 ease-out
{isVisible ? 'opacity-100' : 'opacity-0'}
{previousRouteId === '/(chat)/chat/[id]' ? '' : 'delay-300'}"
>
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
<ChatMessage
class="mx-auto mt-12 w-full max-w-[48rem]"
class="mx-auto mt-12 w-full max-w-3xl"
{message}
{toolMessages}
{isLastAssistantMessage}
@@ -1,31 +1,28 @@
<script lang="ts">
import { Trash2 } from '@lucide/svelte';
import { afterNavigate } from '$app/navigation';
import { page } from '$app/state';
import {
ChatScreenForm,
ChatMessages,
ChatScreenDragOverlay,
ChatScreenProcessingInfo,
ChatScreenActionScrollDown,
DialogEmptyFileAlert,
DialogFileUploadError,
DialogChatError,
ServerLoadingSplash,
DialogConfirmation,
ChatScreenServerError
} from '$lib/components/app';
import { setProcessingInfoContext } from '$lib/contexts';
import { ErrorDialogType } from '$lib/enums';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import { useChatScreenActiveModel } from '$lib/hooks/use-chat-screen-active-model.svelte';
import { useChatScreenDragAndDrop } from '$lib/hooks/use-chat-screen-drag-and-drop.svelte';
import { useChatScreenFileUpload } from '$lib/hooks/use-chat-screen-file-upload.svelte';
import { useChatScreenScroll } from '$lib/hooks/use-chat-screen-scroll.svelte';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { device } from '$lib/stores/device.svelte';
import { isMobile } from '$lib/stores/viewport.svelte';
import {
chatStore,
errorDialog,
isLoading,
isChatStreaming,
isEditing,
getAddFilesHandler,
activeProcessingState
} from '$lib/stores/chat.svelte';
import {
@@ -34,138 +31,81 @@
activeConversation
} from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { serverLoading, serverError, isRouterMode } from '$lib/stores/server.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
import { onMount } from 'svelte';
import { serverLoading, serverError } from '$lib/stores/server.svelte';
import { parseFilesToMessageExtras } from '$lib/utils/browser-only';
import { onDestroy, onMount } from 'svelte';
import ChatScreenGreeting from './ChatScreenGreeting.svelte';
import ChatScreenActionScrollDown from './ChatScreenActionScrollDown.svelte';
import ChatScreenDialogsAndAlerts from './ChatScreenDialogsAndAlerts.svelte';
import { ROUTES } from '$lib/constants';
let { showCenteredEmpty = false } = $props();
const autoScroll = createAutoScrollController();
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll));
let chatScrollContainer: HTMLDivElement | undefined = $state();
let dragCounter = $state(0);
let isDragOver = $state(false);
let showFileErrorDialog = $state(false);
let uploadedFiles = $state<ChatUploadedFile[]>([]);
let fileErrorData = $state<{
generallyUnsupported: File[];
modalityUnsupported: File[];
modalityReasons: Record<string, string>;
supportedTypes: string[];
}>({
generallyUnsupported: [],
modalityUnsupported: [],
modalityReasons: {},
supportedTypes: []
});
let showDeleteDialog = $state(false);
let showEmptyFileDialog = $state(false);
let processingInfoVisible = $state(false);
let emptyFileNames = $state<string[]>([]);
let initialMessage = $state('');
let isEmpty = $derived(
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
let hasPropsError = $derived(!!serverError());
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
let showProcessingInfo = $derived(
isCurrentConversationLoading ||
(config().keepStatsVisible && !!page.params.id) ||
activeProcessingState() !== null
);
let isRouter = $derived(isRouterMode());
let conversationModel = $derived(
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let activeModelId = $derived.by(() => {
const options = modelOptions();
if (!isRouter) {
return options.length > 0 ? options[0].model : null;
}
const selectedId = selectedModelId();
if (selectedId) {
const model = options.find((m) => m.id === selectedId);
if (model) return model.model;
}
if (conversationModel) {
const model = options.find((m) => m.model === conversationModel);
if (model) return model.model;
}
return null;
});
let modelPropsVersion = $state(0);
setProcessingInfoContext({
get showProcessingInfo() {
return showProcessingInfo;
}
});
$effect(() => {
if (activeModelId) {
const cached = modelsStore.getModelProps(activeModelId);
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll) || isMobile.current);
let isMobileUserScrolledUp = $state(false);
let mobileScrollDownHint = $state(false);
let mobileScrollDownHintLockedUntil = $state(0);
let emptyFileNames = $state<string[]>([]);
let initialMessage = $state('');
let showDeleteDialog = $state(false);
let showEmptyFileDialog = $state(false);
let isEmpty = $derived(
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
let hasPropsError = $derived(!!serverError());
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
let showProcessingInfo = $derived(
isCurrentConversationLoading ||
(config().keepStatsVisible && !!page.params.id) ||
activeProcessingState() !== null
);
let chatFormBottomPosition = $derived.by(() => {
if (!isMobile.current) return '1rem';
if (device.isStandalone) return '1.5rem';
if (device.isIOSSafari) return '0.25rem';
return '0.5rem';
});
if (!cached) {
modelsStore.fetchModelProps(activeModelId).then(() => {
modelPropsVersion++;
});
const autoScroll = createAutoScrollController();
const scroll = useChatScreenScroll(autoScroll);
const activeModel = useChatScreenActiveModel();
const fileUpload = useChatScreenFileUpload({
capabilities: () => ({
hasVision: activeModel.hasVisionModality,
hasAudio: activeModel.hasAudioModality,
hasVideo: activeModel.hasVideoModality
}),
activeModelId: () => activeModel.activeModelId
});
const dragAndDrop = useChatScreenDragAndDrop({
onDrop: fileUpload.handleFileUpload
});
const { handleKeydown } = useKeyboardShortcuts({
deleteActiveConversation: () => {
if (activeConversation()) {
showDeleteDialog = true;
}
}
});
let hasAudioModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
function handleMobileScroll() {
if (!isMobile.current) return;
return modelsStore.modelSupportsAudio(activeModelId);
}
const container = scroll.chatScrollContainer;
if (!container) return;
return false;
});
let hasVideoModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVideo(activeModelId);
}
return false;
});
let hasVisionModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion;
return modelsStore.modelSupportsVision(activeModelId);
}
return false;
});
const distanceFromBottom =
container.scrollHeight - container.clientHeight - container.scrollTop;
isMobileUserScrolledUp = distanceFromBottom > 300;
}
async function handleDeleteConfirm() {
const conversation = activeConversation();
@@ -177,27 +117,69 @@
showDeleteDialog = false;
}
function handleProcessingInfoVisibility(visible: boolean) {
processingInfoVisible = visible;
}
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
const plainFiles = files ? $state.snapshot(files) : undefined;
const result = plainFiles
? await parseFilesToMessageExtras(plainFiles, activeModel.activeModelId ?? undefined)
: undefined;
function handleDragEnter(event: DragEvent) {
event.preventDefault();
dragCounter++;
if (event.dataTransfer?.types.includes('Files')) {
isDragOver = true;
if (result?.emptyFiles && result.emptyFiles.length > 0) {
emptyFileNames = result.emptyFiles;
showEmptyFileDialog = true;
if (files) {
const emptyFileNamesSet = new Set(result.emptyFiles);
fileUpload.uploadedFiles = fileUpload.uploadedFiles.filter(
(file) => !emptyFileNamesSet.has(file.name)
);
}
return false;
}
handleSendLikeScroll();
await chatStore.sendMessage(message, result?.extras);
return true;
}
function handleDragLeave(event: DragEvent) {
event.preventDefault();
function handleSendLikeScroll() {
if (!isMobile.current) {
autoScroll.enable();
}
dragCounter--;
setTimeout(() => {
const container = scroll.chatScrollContainer;
if (!container) return;
if (dragCounter === 0) {
isDragOver = false;
const lastUserBubble = container.querySelector(
'.chat-message:nth-last-child(2) .chat-message-user .chat-message-user-bubble'
) as HTMLElement | null;
if (isMobile.current) {
// Keep the last user message bubble just above the input on mobile
const bubbleHeight = lastUserBubble?.scrollHeight ?? 0;
const baseHeight = container.scrollHeight - innerHeight;
container.scrollTo({
top: bubbleHeight > 0 ? baseHeight - bubbleHeight : baseHeight,
behavior: 'smooth'
});
} else if (lastUserBubble) {
// On desktop, place the last user message near the top of the viewport
const topPadding = 24;
const bubbleRect = lastUserBubble.getBoundingClientRect();
container.scrollTo({
top: Math.max(0, container.scrollTop + bubbleRect.top - topPadding),
behavior: 'smooth'
});
} else {
autoScroll.scrollToBottom();
}
}, 100);
if (isMobile.current) {
autoScroll.setDisabled(disableAutoScroll);
mobileScrollDownHint = true;
mobileScrollDownHintLockedUntil = Date.now() + 500;
}
}
@@ -207,273 +189,138 @@
}
}
function handleDragOver(event: DragEvent) {
event.preventDefault();
}
function handleDrop(event: DragEvent) {
event.preventDefault();
isDragOver = false;
dragCounter = 0;
if (event.dataTransfer?.files) {
const files = Array.from(event.dataTransfer.files);
if (isEditing()) {
const handler = getAddFilesHandler();
if (handler) {
handler(files);
return;
}
}
processFiles(files);
}
}
function handleFileRemove(fileId: string) {
uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);
}
function handleFileUpload(files: File[]) {
processFiles(files);
}
const { handleKeydown } = useKeyboardShortcuts({
deleteActiveConversation: () => {
if (activeConversation()) {
showDeleteDialog = true;
}
}
});
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
if (draft.message || draft.files.length > 0) {
chatStore.savePendingDraft(draft.message, draft.files);
}
await chatStore.addSystemPrompt();
}
function handleScroll() {
autoScroll.handleScroll();
}
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
const plainFiles = files ? $state.snapshot(files) : undefined;
const result = plainFiles
? await parseFilesToMessageExtras(plainFiles, activeModelId ?? undefined)
: undefined;
if (result?.emptyFiles && result.emptyFiles.length > 0) {
emptyFileNames = result.emptyFiles;
showEmptyFileDialog = true;
if (files) {
const emptyFileNamesSet = new Set(result.emptyFiles);
uploadedFiles = uploadedFiles.filter((file) => !emptyFileNamesSet.has(file.name));
}
return false;
}
const extras = result?.extras;
// Enable autoscroll for user-initiated message sending
autoScroll.enable();
await chatStore.sendMessage(message, extras);
autoScroll.scrollToBottom();
return true;
}
async function processFiles(files: File[]) {
const generallySupported: File[] = [];
const generallyUnsupported: File[] = [];
for (const file of files) {
if (isFileTypeSupported(file.name, file.type)) {
generallySupported.push(file);
} else {
generallyUnsupported.push(file);
}
}
// Use model-specific capabilities for file validation
const capabilities = {
hasVision: hasVisionModality,
hasAudio: hasAudioModality,
hasVideo: hasVideoModality
};
const { supportedFiles, unsupportedFiles, modalityReasons } = filterFilesByModalities(
generallySupported,
capabilities
);
const allUnsupportedFiles = [...generallyUnsupported, ...unsupportedFiles];
if (allUnsupportedFiles.length > 0) {
const supportedTypes: string[] = ['text files', 'PDFs'];
if (hasVisionModality) supportedTypes.push('images');
if (hasAudioModality) supportedTypes.push('audio files');
if (hasVideoModality) supportedTypes.push('video files');
fileErrorData = {
generallyUnsupported,
modalityUnsupported: unsupportedFiles,
modalityReasons,
supportedTypes
};
showFileErrorDialog = true;
}
if (supportedFiles.length > 0) {
const processed = await processFilesToChatUploaded(
supportedFiles,
activeModelId ?? undefined
);
uploadedFiles = [...uploadedFiles, ...processed];
}
}
afterNavigate(() => {
if (!disableAutoScroll) {
$effect(() => {
const shouldDisableAutoScroll =
config().disableAutoScroll || (isMobile.current && isCurrentConversationLoading);
autoScroll.setDisabled(shouldDisableAutoScroll);
if (!shouldDisableAutoScroll) {
autoScroll.enable();
}
});
function handleMessagesReady() {
if (disableAutoScroll) return;
if (!autoScroll.userScrolledUp) {
requestAnimationFrame(() => {
autoScroll.scrollToBottom('instant');
});
}
}
onMount(() => {
const pendingDraft = chatStore.consumePendingDraft();
if (pendingDraft) {
initialMessage = pendingDraft.message;
fileUpload.uploadedFiles = pendingDraft.files;
}
autoScroll.startObserving();
if (!disableAutoScroll) {
autoScroll.enable();
}
const pendingDraft = chatStore.consumePendingDraft();
if (pendingDraft) {
initialMessage = pendingDraft.message;
uploadedFiles = pendingDraft.files;
if (isMobile.current && isCurrentConversationLoading) {
mobileScrollDownHint = true;
mobileScrollDownHintLockedUntil = Date.now() + 500;
}
handleMobileScroll();
});
$effect(() => {
autoScroll.setContainer(chatScrollContainer);
});
$effect(() => {
autoScroll.setDisabled(disableAutoScroll);
});
onDestroy(() => autoScroll.destroy());
</script>
{#if isDragOver}
{#if dragAndDrop.isDragOver}
<ChatScreenDragOverlay />
{/if}
<svelte:window onkeydown={handleKeydown} />
<svelte:window
onkeydown={handleKeydown}
onscroll={(e) => {
scroll.handleScroll(e);
handleMobileScroll();
if (e.isTrusted && Date.now() > mobileScrollDownHintLockedUntil) {
mobileScrollDownHint = false;
}
}}
/>
{#if isServerLoading}
<ServerLoadingSplash />
{:else}
<div
bind:this={chatScrollContainer}
aria-label="Chat interface with file drop zone"
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
ondragenter={handleDragEnter}
ondragleave={handleDragLeave}
ondragover={handleDragOver}
ondrop={handleDrop}
onscroll={handleScroll}
class="chat-screen flex grow flex-col min-h-[calc(100dvh-1rem)] md:min-h-full px-4 md:py-0 pt-12 pb-48 md:pb-4"
style:--chat-form-bottom-position={chatFormBottomPosition}
ondragenter={dragAndDrop.dragHandlers.dragenter}
ondragleave={dragAndDrop.dragHandlers.dragleave}
ondragover={dragAndDrop.dragHandlers.dragover}
ondrop={dragAndDrop.dragHandlers.drop}
role="main"
>
<div class="flex grow flex-col pt-14">
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onMessagesReady={handleMessagesReady}
onUserAction={() => {
autoScroll.enable();
if (!autoScroll.userScrolledUp) {
autoScroll.scrollToBottom();
}
}}
/>
{/if}
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onUserAction={() => {
handleSendLikeScroll();
}}
/>
{/if}
<div
class={[
'pointer-events-none sticky right-4 left-4 mt-auto transition-all duration-200',
isEmpty ? 'bottom-[calc(50dvh-7rem)]' : 'bottom-4 pt-24 md:pt-32'
]}
>
<ChatScreenGreeting {isEmpty} />
<div
class={[
'pointer-events-none md:sticky fixed mt-auto transition-all duration-200',
device.isStandalone
? 'bottom-6 right-4 left-4'
: device.isIOSSafari
? 'bottom-1 left-2 right-2'
: 'bottom-2 right-2 left-2',
isEmpty ? 'md:bottom-[calc(50dvh-7rem)] 2xl:bottom-[calc(50dvh-4rem)]' : 'md:bottom-4'
]}
style:padding-top={!isEmpty ? 'var(--chat-form-padding-top)' : undefined}
>
<ChatScreenGreeting {isEmpty} />
<ChatScreenActionScrollDown
container={chatScrollContainer}
hasProcessingInfoVisible={processingInfoVisible}
/>
<ChatScreenServerError />
<ChatScreenProcessingInfo onVisibilityChange={handleProcessingInfoVisibility} />
<ChatScreenServerError />
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
<ChatScreenForm
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
bind:uploadedFiles
<div class="pointer-events-none flex flex-col gap-6 items-center w-full">
{#if (isMobile.current ? mobileScrollDownHint || isMobileUserScrolledUp : autoScroll.userScrolledUp) && page.url.hash.includes(ROUTES.CHAT) && page.params.id}
<ChatScreenActionScrollDown
onclick={() => {
mobileScrollDownHint = false;
scroll.chatScrollContainer?.scrollTo({
top: scroll.chatScrollContainer.scrollHeight,
behavior: 'smooth'
});
}}
/>
</div>
{/if}
{#if showProcessingInfo}
<ChatScreenProcessingInfo />
{/if}
</div>
<ChatScreenForm
class="pointer-events-auto conversation-chat-form"
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={fileUpload.handleFileRemove}
onFileUpload={fileUpload.handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
bind:uploadedFiles={fileUpload.uploadedFiles}
/>
</div>
</div>
{/if}
<DialogFileUploadError bind:open={showFileErrorDialog} {fileErrorData} />
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleDeleteConfirm}
onCancel={() => (showDeleteDialog = false)}
/>
<DialogEmptyFileAlert
bind:open={showEmptyFileDialog}
emptyFiles={emptyFileNames}
onOpenChange={(open) => {
if (!open) {
emptyFileNames = [];
}
}}
/>
<DialogChatError
message={activeErrorDialog?.message ?? ''}
contextInfo={activeErrorDialog?.contextInfo}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
<ChatScreenDialogsAndAlerts
{showDeleteDialog}
{handleDeleteConfirm}
{showEmptyFileDialog}
{emptyFileNames}
{activeErrorDialog}
{handleErrorDialogOpenChange}
{fileUpload}
/>
@@ -1,58 +1,18 @@
<script lang="ts">
import { ArrowDown } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import ActionIcon from '$lib/components/app/actions/ActionIcon.svelte';
interface Props {
container: HTMLDivElement | undefined;
hasProcessingInfoVisible: boolean;
}
let { container, hasProcessingInfoVisible }: Props = $props();
let show = $state(false);
let buttonBottom = $derived(hasProcessingInfoVisible ? '2rem' : '0');
function checkVisibility() {
if (!container) return;
const { scrollTop, scrollHeight, clientHeight } = container;
const distanceFromBottom = scrollHeight - clientHeight - scrollTop;
show = distanceFromBottom > clientHeight * 0.5;
}
function scrollToBottom() {
if (container) {
container.scrollTo({
top: container.scrollHeight,
behavior: 'smooth'
});
}
}
$effect(() => {
const c = container;
if (c) {
c.addEventListener('scroll', checkVisibility);
checkVisibility();
return () => {
c.removeEventListener('scroll', checkVisibility);
};
}
});
let { onclick }: { onclick: (e?: MouseEvent) => void } = $props();
</script>
<div class="relative z-50 mx-auto mb-4 flex max-w-[48rem] justify-center">
<Button
onclick={scrollToBottom}
variant="secondary"
size="icon"
disabled={!show}
class="pointer-events-auto absolute h-10 w-10 rounded-full bg-background/80 shadow-lg backdrop-blur-sm transition-all duration-200 hover:bg-muted/80"
style="bottom: {buttonBottom}; transform: translateY({show ? '0' : '2rem'}); opacity: {show
? 1
: 0};"
aria-label="Scroll to bottom"
>
<ArrowDown class="h-4 w-4" />
</Button>
<div class="pointer-events-auto flex justify-center relative h-0">
<ActionIcon
icon={ArrowDown}
{onclick}
ariaLabel="Scroll to bottom"
tooltip="Scroll to bottom"
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full bg-accent text-accent-foreground absolute bottom-4 shadow-md"
/>
</div>
@@ -0,0 +1,55 @@
<script lang="ts">
import { Trash2 } from '@lucide/svelte';
import { ErrorDialogType } from '$lib/enums';
import {
DialogChatError,
DialogConfirmation,
DialogEmptyFileAlert,
DialogFileUploadError
} from '$lib/components/app';
let {
showDeleteDialog,
handleDeleteConfirm,
showEmptyFileDialog,
emptyFileNames,
activeErrorDialog,
handleErrorDialogOpenChange,
fileUpload
} = $props();
</script>
<DialogFileUploadError
bind:open={fileUpload.showFileErrorDialog}
fileErrorData={fileUpload.fileErrorData}
/>
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleDeleteConfirm}
onCancel={() => (showDeleteDialog = false)}
/>
<DialogEmptyFileAlert
bind:open={showEmptyFileDialog}
emptyFiles={emptyFileNames}
onOpenChange={(open) => {
if (!open) {
emptyFileNames = [];
}
}}
/>
<DialogChatError
message={activeErrorDialog?.message ?? ''}
contextInfo={activeErrorDialog?.contextInfo}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
/>
@@ -2,6 +2,7 @@
import { afterNavigate } from '$app/navigation';
import { page } from '$app/state';
import { ChatForm } from '$lib/components/app';
import { isMobile } from '$lib/stores/viewport.svelte';
import { onMount } from 'svelte';
import { useDraftMessages } from '$lib/hooks/use-draft-messages.svelte';
@@ -32,7 +33,30 @@
}: Props = $props();
let chatFormRef: ChatForm | undefined = $state(undefined);
let formWrapperEl: HTMLDivElement | undefined = $state();
let chatId = $derived(page.params.id as string | undefined);
$effect(() => {
if (!formWrapperEl) return;
const formEl = formWrapperEl.querySelector('form') as HTMLElement | null;
if (!formEl) return;
const updateHeight = () => {
const height = Math.round(formEl.getBoundingClientRect().height);
document.documentElement.style.setProperty('--chat-form-height', `${height}px`);
};
updateHeight();
const resizeObserver = new ResizeObserver(updateHeight);
resizeObserver.observe(formEl);
return () => {
resizeObserver.disconnect();
document.documentElement.style.removeProperty('--chat-form-height');
};
});
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
let message = $derived(initialMessage);
let previousIsLoading = $derived(isLoading);
@@ -83,12 +107,14 @@
}
onMount(() => {
setTimeout(() => chatFormRef?.focus(), 10);
if (!isMobile.current) {
setTimeout(() => chatFormRef?.focus(), 100);
}
});
afterNavigate((navigation) => {
if (navigation?.from != null) {
setTimeout(() => chatFormRef?.focus(), 10);
if (navigation?.from != null && !isMobile.current) {
setTimeout(() => chatFormRef?.focus(), 100);
}
});
@@ -108,12 +134,12 @@
});
</script>
<div class="relative mx-auto max-w-[48rem]">
<div class="chat-screen-form-wrapper" bind:this={formWrapperEl}>
<ChatForm
class="mx-auto max-w-3xl {className}"
bind:this={chatFormRef}
bind:value={message}
bind:uploadedFiles
class={className}
{disabled}
{isLoading}
showMcpPromptButton
@@ -1,5 +1,4 @@
<script lang="ts">
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { serverStore } from '$lib/stores/server.svelte';
interface Props {
@@ -11,10 +10,9 @@
<div
class={[
'pointer-events-none mb-4 hidden px-4 text-center',
isEmpty && 'pointer-events-auto block!'
'pointer-events-none mb-4 hidden px-4 text-center text-balance',
isEmpty && 'mb-[calc(50dvh-8rem)] md:mb-6 pointer-events-auto block!'
]}
use:fadeInView={{ duration: 300 }}
>
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
@@ -5,13 +5,8 @@
import { chatStore, isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { getProcessingInfoContext } from '$lib/contexts';
import { page } from '$app/state';
const processingState = useProcessingState();
const processingInfoCtx = getProcessingInfoContext();
let showProcessingInfo = $derived(processingInfoCtx.showProcessingInfo);
let isCurrentConversationLoading = $derived(isLoading());
let isStreaming = $derived(isChatStreaming());
@@ -70,8 +65,8 @@
<div
class={[
'chat-processing-info-container pointer-events-none relative',
page.params.id && showProcessingInfo && 'visible'
'chat-processing-info-container pointer-events-none relative w-full hidden md:block',
processingVisible && 'visible'
]}
>
<div class="chat-processing-info-content absolute bottom-4 left-1/2 -translate-x-1/2">
@@ -677,13 +677,6 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
*/
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
/**
* Scroll-to-bottom action button. Displays a floating button when the user
* has scrolled up more than half a viewport height from the bottom.
* Takes the chat container element as a prop to manage scroll state internally.
*/
export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte';
/**
* Server error alert displayed when the server is unreachable.
* Shows the error message with a retry button.
@@ -3,6 +3,7 @@
import { Search, X } from '@lucide/svelte';
interface Props {
autofocus?: boolean;
value?: string;
placeholder?: string;
onInput?: (value: string) => void;
@@ -15,6 +16,7 @@
}
let {
autofocus,
value = $bindable(''),
placeholder = 'Search...',
onInput,
@@ -39,7 +41,7 @@
if (value) {
value = '';
onInput?.('');
ref?.focus();
ref?.focus({ preventScroll: true });
} else {
onClose?.();
}
@@ -52,6 +54,7 @@
/>
<Input
{autofocus}
{id}
bind:value
bind:ref
@@ -0,0 +1,15 @@
<script>
import logoMark from '$lib/assets/logo.svg?raw';
let { class: className = '', style = '' } = $props();
</script>
<div class={className} {style}>
{@html logoMark}
</div>
<style>
div :global(svg) {
width: var(--size, 1rem);
height: var(--size, 1rem);
}
</style>
@@ -51,3 +51,11 @@ export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';
* Preview button is shown only for HTML code blocks.
*/
export { default as CodeBlockActions } from './CodeBlockActions.svelte';
/**
* **Logo** - Application brand mark
*
* Inline SVG of the application logo. Accepts styling via the standard
* `class` and `style` props and inherits color via `currentColor`.
*/
export { default as Logo } from './Logo.svelte';
@@ -0,0 +1,11 @@
<script lang="ts">
let { percent }: { percent: number } = $props();
</script>
<!-- thin determinate load bar pinned to the bottom edge, pulsing while it fills -->
<div class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm">
<div
class="h-full animate-pulse bg-primary transition-[width] duration-200 ease-out"
style="width: {percent}%"
></div>
</div>
@@ -2,8 +2,10 @@
import { ChevronDown, Loader2, Package } from '@lucide/svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import { KeyboardKey } from '$lib/enums';
import { KeyboardKey, ServerModelStatus } from '$lib/enums';
import { useModelsSelector } from '$lib/hooks/use-models-selector.svelte';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
import { modelLoadFraction } from '$lib/utils';
import {
DialogModelInformation,
DropdownMenuSearchable,
@@ -11,6 +13,7 @@
ModelsSelectorList,
ModelsSelectorOption
} from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import type { ModelItem } from './utils';
interface Props {
@@ -113,6 +116,17 @@
{/if}
{:else}
{@const selectedOption = ms.getDisplayOption()}
{@const triggerModel = selectedOption?.model}
{@const triggerStatus = triggerModel
? routerModels().find((m) => m.id === triggerModel)?.status?.value
: undefined}
{@const triggerLoading =
!!triggerModel &&
(triggerStatus === ServerModelStatus.LOADING ||
modelsStore.isModelOperationInProgress(triggerModel))}
{@const triggerLoadPercent = triggerLoading
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
: 0}
{#if ms.isRouter}
<DropdownMenu.Root bind:open={isOpen} onOpenChange={ms.handleOpenChange}>
@@ -123,7 +137,7 @@
<DropdownMenu.Trigger
{...props}
class={[
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
`relative inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
!ms.isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
@@ -154,6 +168,10 @@
{:else}
<ChevronDown class="h-3 w-3.5 shrink-0" />
{/if}
{#if triggerLoading}
<ModelLoadHighlight percent={triggerLoadPercent} />
{/if}
</DropdownMenu.Trigger>
{/snippet}
</Tooltip.Trigger>
@@ -10,6 +10,7 @@
RotateCw
} from '@lucide/svelte';
import { ActionIcon, ModelId } from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import type { ModelOption } from '$lib/types/models';
import { ServerModelStatus } from '$lib/enums';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
@@ -119,11 +120,11 @@
</div>
{#if isLoading}
<div class="flex w-4 [@media(pointer:coarse)]:w-5 items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-5">
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
</div>
{:else if isFailed}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<CircleAlert
class="h-3.5 w-3.5 text-red-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
/>
@@ -140,7 +141,7 @@
</div>
</div>
{:else if isSleeping}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<span
class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -159,7 +160,7 @@
</div>
</div>
{:else if isLoaded}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<span
class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -176,7 +177,7 @@
</div>
</div>
{:else}
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
<span
class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden [@media(pointer:coarse)]:hidden"
></span>
@@ -196,13 +197,6 @@
</div>
{#if isLoading}
<div
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
>
<div
class="h-full bg-primary transition-[width] duration-200 ease-out"
style="width: {loadPercent}%"
></div>
</div>
<ModelLoadHighlight percent={loadPercent} />
{/if}
</div>
@@ -8,6 +8,10 @@
ModelsSelectorList,
SearchInput
} from '$lib/components/app';
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
import { ServerModelStatus } from '$lib/enums';
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
import { modelLoadFraction } from '$lib/utils';
interface Props {
class?: string;
@@ -61,12 +65,23 @@
<p class="text-xs text-muted-foreground">No models available.</p>
{:else}
{@const selectedOption = ms.getDisplayOption()}
{@const triggerModel = selectedOption?.model}
{@const triggerStatus = triggerModel
? routerModels().find((m) => m.id === triggerModel)?.status?.value
: undefined}
{@const triggerLoading =
!!triggerModel &&
(triggerStatus === ServerModelStatus.LOADING ||
modelsStore.isModelOperationInProgress(triggerModel))}
{@const triggerLoadPercent = triggerLoading
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
: 0}
{#if ms.isRouter}
<button
type="button"
class={[
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 max-sm:px-3 max-sm:py-2 text-xs max-sm:text-sm shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
`relative inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 max-sm:px-3 max-sm:py-2 max-sm:text-sm dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
!ms.isCurrentModelInCache
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
: forceForegroundText
@@ -99,6 +114,10 @@
{:else}
<ChevronDown class="h-3 w-3.5 shrink-0" />
{/if}
{#if triggerLoading}
<ModelLoadHighlight percent={triggerLoadPercent} />
{/if}
</button>
<Sheet.Root bind:open={sheetOpen} onOpenChange={handleSheetOpenChange}>
@@ -1,84 +0,0 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { ActionIcon } from '$lib/components/app';
import {
ICON_STRIP_TRANSITION_DURATION,
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
SIDEBAR_ACTIONS_ITEMS
} from '$lib/constants';
import { TooltipSide } from '$lib/enums';
import { fade } from 'svelte/transition';
import { circIn } from 'svelte/easing';
import { onMount } from 'svelte';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
interface Props {
sidebarOpen: boolean;
onSearchClick: () => void;
}
let { sidebarOpen = false, onSearchClick }: Props = $props();
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
let initialized = $state(false);
let showIcons = $derived(!sidebarOpen);
showIcons = false;
onMount(() => {
showIcons = !sidebarOpen;
setTimeout(() => {
initialized = true;
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
});
</script>
<svelte:window onkeydown={handleKeydown} />
<div
class="hidden shrink-0 transition-[width] duration-200 ease-linear md:block {sidebarOpen
? 'w-0'
: 'w-[calc(var(--sidebar-width-icon)+1.5rem)]'}"
></div>
<aside
class="fixed top-0 bottom-0 left-0 z-10 hidden w-[calc(var(--sidebar-width-icon)+1.5rem)] flex-col items-center justify-between py-3 transition-opacity duration-200 ease-linear md:flex {sidebarOpen
? 'pointer-events-none opacity-0'
: 'opacity-100'}"
>
<div class="mt-12 flex flex-col items-center gap-1">
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const onclick = item.route ? () => goto(item.route!) : onSearchClick}
{@const isActive = item.activeRouteId
? page.route.id === item.activeRouteId
: item.activeRoutePrefix
? !!page.route.id?.startsWith(item.activeRoutePrefix)
: false}
{#if showIcons}
<div
in:fade={{
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
>
<ActionIcon
icon={item.icon}
tooltip={item.tooltip}
tooltipSide={TooltipSide.RIGHT}
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
? 'bg-accent text-accent-foreground'
: ''}"
{onclick}
/>
</div>
{/if}
{/each}
</div>
</aside>
@@ -1,40 +1,67 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { Trash2, Pencil, Pin, X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { DialogConfirmation } from '$lib/components/app';
import SidebarNavigationActions from './SidebarNavigationActions.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
import { Checkbox } from '$lib/components/ui/checkbox';
import Label from '$lib/components/ui/label/label.svelte';
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
import * as Sidebar from '$lib/components/ui/sidebar';
import Input from '$lib/components/ui/input/input.svelte';
import { ROUTES } from '$lib/constants/routes';
import { RouterService } from '$lib/services/router.service';
import { PanelLeftClose, PanelLeftOpen, X } from '@lucide/svelte';
import {
conversationsStore,
conversations,
buildConversationTree
} from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { getPreviewText } from '$lib/utils';
import { APP_NAME } from '$lib/constants';
ActionIcon,
Logo,
SidebarNavigationConversationList,
SidebarNavigationActions
} from '$lib/components/app';
import { ROUTES } from '$lib/constants';
import { fade } from 'svelte/transition';
const sidebar = Sidebar.useSidebar();
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { RouterService } from '$lib/services/router.service';
import { isMobile } from '$lib/stores/viewport.svelte';
import { TooltipSide } from '$lib/enums';
import { device } from '$lib/stores/device.svelte';
import { circIn } from 'svelte/easing';
interface Props {
onSearchClick?: () => void;
}
let { onSearchClick = () => {} }: Props = $props();
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
let isExpandedMode = $state(false);
let hoveredTooltip = $state<string | null>(null);
let logoHovered = $state(false);
const isStripExpanded = $derived(isExpandedMode || hoveredTooltip !== null);
const isOnMobile = $derived(isMobile.current);
function toggleExpandedMode() {
isExpandedMode = !isExpandedMode;
if (!isExpandedMode) {
hoveredTooltip = null;
}
}
$effect(() => {
if (!isExpandedMode) {
isSearchModeActive = false;
searchQuery = '';
cancelMobileCollapse();
}
});
// On mobile the dedicated /search route hides the sidebar (see the aside
// render guard below). Collapse it as we enter /search so it doesn't
// reappear expanded when the user navigates back via the back button.
$effect(() => {
if (isMobile.current && page.url.hash.includes(ROUTES.SEARCH)) {
isExpandedMode = false;
}
});
let currentChatId = $derived(page.params.id);
let isSearchModeActive = $state(false);
let searchQuery = $state('');
let showDeleteDialog = $state(false);
let deleteWithForks = $state(false);
let showEditDialog = $state(false);
let selectedConversation = $state<DatabaseConversation | null>(null);
let editedName = $state('');
let selectedConversationNamePreview = $derived.by(() =>
selectedConversation ? getPreviewText(selectedConversation.name) : ''
);
let filteredConversations = $derived.by(() => {
if (isSearchModeActive) {
@@ -50,294 +77,206 @@
return conversations();
});
let conversationTree = $derived(buildConversationTree(filteredConversations));
let pinnedConversations = $derived.by(() => {
return conversationTree.filter(({ conversation }) => conversation.pinned);
});
let unpinnedConversations = $derived.by(() => {
return conversationTree.filter(({ conversation }) => !conversation.pinned);
});
let selectedConversationHasDescendants = $derived.by(() => {
if (!selectedConversation) return false;
const allConvs = conversations();
const queue = [selectedConversation.id];
while (queue.length > 0) {
const parentId = queue.pop()!;
for (const c of allConvs) {
if (c.forkedFromConversationId === parentId) return true;
}
}
return false;
});
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (conversation) {
selectedConversation = conversation;
deleteWithForks = false;
showDeleteDialog = true;
async function selectConversation(id: string) {
if (isMobile.current) {
scheduleMobileCollapse();
}
await goto(RouterService.chat(id));
}
async function handleEditConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (conversation) {
selectedConversation = conversation;
editedName = conversation.name;
showEditDialog = true;
if (!conversation) return;
const newName = window.prompt('Rename conversation', conversation.name);
if (newName && newName.trim()) {
await conversationsStore.updateConversationName(id, newName.trim());
}
}
function handleConfirmDelete() {
if (selectedConversation) {
const convId = selectedConversation.id;
const withForks = deleteWithForks;
showDeleteDialog = false;
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (!conversation) return;
setTimeout(() => {
conversationsStore.deleteConversation(convId, {
deleteWithForks: withForks
});
}, 100); // Wait for animation to finish
}
}
const confirmed = window.confirm(
`Delete "${conversation.name}"? This action cannot be undone.`
);
if (!confirmed) return;
function handleConfirmEdit() {
if (!editedName.trim() || !selectedConversation) return;
showEditDialog = false;
conversationsStore.updateConversationName(selectedConversation.id, editedName);
selectedConversation = null;
}
export function handleMobileSidebarItemClick() {
if (sidebar.isMobile) {
sidebar.toggle();
}
}
let chatSidebarActions: { activateSearch?: () => void } | undefined = $state();
let openedForSearch = $state(false);
export function activateSearchMode() {
if (!sidebar.open) {
openedForSearch = true;
}
chatSidebarActions?.activateSearch?.();
}
function handleSearchDeactivated() {
if (openedForSearch) {
openedForSearch = false;
sidebar.toggle();
}
}
$effect(() => {
if (!sidebar.open) {
isSearchModeActive = false;
searchQuery = '';
openedForSearch = false;
}
});
export function editActiveConversation() {
if (currentChatId) {
const activeConversation = filteredConversations.find((conv) => conv.id === currentChatId);
if (activeConversation) {
const event = new CustomEvent('edit-active-conversation', {
detail: { conversationId: currentChatId }
});
document.dispatchEvent(event);
}
}
}
async function selectConversation(id: string) {
if (isSearchModeActive) {
isSearchModeActive = false;
searchQuery = '';
}
handleMobileSidebarItemClick();
await goto(RouterService.chat(id));
await conversationsStore.deleteConversation(id, { deleteWithForks: false });
}
function handleStopGeneration(id: string) {
chatStore.stopGenerationForChat(id);
}
let innerWidth = $state(0);
let pendingCollapse = $state<ReturnType<typeof setTimeout> | null>(null);
function scheduleMobileCollapse() {
if (pendingCollapse) {
clearTimeout(pendingCollapse);
}
pendingCollapse = setTimeout(() => {
isExpandedMode = false;
pendingCollapse = null;
}, 100);
}
function cancelMobileCollapse() {
if (pendingCollapse) {
clearTimeout(pendingCollapse);
pendingCollapse = null;
}
}
</script>
<div class="flex h-full flex-col">
<ScrollArea class="h-full flex-1">
<Sidebar.Header class="gap-4 bg-sidebar/50 p-3 backdrop-blur-lg md:pt-4 md:pb-2">
<div class="flex items-center justify-between">
<a href={ROUTES.START} onclick={handleMobileSidebarItemClick}>
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">
{APP_NAME}
</h1>
</a>
<svelte:window onkeydown={handleKeydown} bind:innerWidth />
<Button
class="rounded-full md:hidden"
variant="ghost"
size="icon"
onclick={() => sidebar.toggle()}
>
<X class="h-4 w-4" />
<span class="sr-only">Close sidebar</span>
</Button>
{#if innerWidth > 768 || (!page.url.hash.includes(ROUTES.SETTINGS) && !page.url.hash.includes(ROUTES.MCP_SERVERS) && !page.url.hash.includes(ROUTES.SEARCH))}
<aside
class={[
// Layout & positioning
'fixed md:sticky top-2 left-2 md:left-0 md:ml-2 md:mt-2 pt-2 z-10 w-[calc(100dvw-1rem)]',
// Dimensions & overflow
'md:h-[calc(100dvh-1.125rem)]',
isExpandedMode &&
(device.isStandalone
? 'h-[calc(100dvh-2rem)]'
: device.isIOSDevice
? 'h-[calc(100dvh-0.5rem)]'
: 'h-[calc(100dvh-1rem)]'),
// Shape & depth
'rounded-3xl md:rounded-2xl',
// Flex layout
'flex flex-col justify-between',
// Transition
'md:transition-[width,padding] duration-200 ease-out',
// Expanded state: width, surface, depth
isStripExpanded && 'md:w-72 md:bg-muted/60 md:backdrop-blur-xl border-border shadow-md',
// Collapsed state
!isStripExpanded && 'md:w-12',
// Expanded mode flag (for mobile ::before overlay)
isExpandedMode && 'is-expanded'
]}
>
<div class="px-2 flex items-center justify-between">
<div
role="button"
tabindex="0"
class="relative"
onmouseenter={() => (logoHovered = true)}
onmouseleave={() => (logoHovered = false)}
>
<ActionIcon
icon={!isExpandedMode && logoHovered && innerWidth > 768 ? PanelLeftOpen : Logo}
size="lg"
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
class="{isExpandedMode
? 'bg-muted! md:bg-foreground/5!'
: 'bg-transparent!'} md:h-9 md:w-9 h-10 w-10 rounded-full md:hover:bg-foreground/10! pointer-events-auto"
href={isExpandedMode ? ROUTES.START : undefined}
onclick={isExpandedMode ? undefined : toggleExpandedMode}
tooltip={isExpandedMode ? undefined : 'Open Sidebar'}
tooltipSide={TooltipSide.RIGHT}
ariaLabel={isExpandedMode ? 'Go to start' : 'Expand navigation'}
/>
</div>
<SidebarNavigationActions
bind:this={chatSidebarActions}
{handleMobileSidebarItemClick}
bind:isSearchModeActive
bind:searchQuery
onSearchDeactivated={handleSearchDeactivated}
/>
</Sidebar.Header>
{#if !isSearchModeActive && pinnedConversations.length > 0}
<Sidebar.Group class="p-0 px-4">
<Sidebar.GroupLabel>
<div class="flex items-center gap-1">
<Pin class="h-3.5 w-3.5" />
<span>Pinned</span>
</div>
</Sidebar.GroupLabel>
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each pinnedConversations as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
</Sidebar.MenuItem>
{/each}
</Sidebar.Menu>
</Sidebar.GroupContent>
</Sidebar.Group>
{/if}
<Sidebar.Group class="mt-2 h-[calc(100vh-21rem)] space-y-2 p-0 px-3">
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
<Sidebar.GroupLabel>
{isSearchModeActive ? 'Search results' : 'Recent conversations'}
</Sidebar.GroupLabel>
{#if isExpandedMode || isOnMobile}
<div
class="flex items-center transition-all duration-150 ease-out {isMobile.current &&
!isExpandedMode
? 'opacity-0 h-0!'
: ''}"
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
out:fade={{ duration: 100 }}
>
<ActionIcon
icon={isMobile.current ? X : PanelLeftClose}
size="lg"
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
class="backdrop-blur-none md:h-9 md:w-9 h-10 w-10 rounded-full mr-1 hover:bg-accent!"
onclick={toggleExpandedMode}
tooltip="Close Sidebar"
tooltipSide={TooltipSide.LEFT}
ariaLabel="Collapse navigation"
/>
</div>
{/if}
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each isSearchModeActive ? conversationTree : unpinnedConversations as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
</Sidebar.MenuItem>
{/each}
{#if (isSearchModeActive ? conversationTree : unpinnedConversations).length === 0}
<div class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{searchQuery.length > 0
? 'No results found'
: isSearchModeActive
? 'Start typing to see results'
: 'No conversations yet'}
</p>
</div>
{/if}
</Sidebar.Menu>
</Sidebar.GroupContent>
</Sidebar.Group>
</ScrollArea>
</div>
<DialogConfirmation
bind:open={showDeleteDialog}
title="Delete Conversation"
description={selectedConversation
? `Are you sure you want to delete "${selectedConversationNamePreview}"? This action cannot be undone and will permanently remove all messages in this conversation.`
: ''}
confirmText="Delete"
cancelText="Cancel"
variant="destructive"
icon={Trash2}
onConfirm={handleConfirmDelete}
onCancel={() => {
showDeleteDialog = false;
selectedConversation = null;
}}
>
{#if selectedConversationHasDescendants}
<div class="flex items-center gap-2 py-2">
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
</div>
{/if}
</DialogConfirmation>
<DialogConfirmation
bind:open={showEditDialog}
title="Edit Conversation Name"
description=""
confirmText="Save"
cancelText="Cancel"
icon={Pencil}
onConfirm={handleConfirmEdit}
onCancel={() => {
showEditDialog = false;
selectedConversation = null;
}}
onKeydown={(event) => {
if (event.key === 'Enter') {
event.preventDefault();
event.stopImmediatePropagation();
handleConfirmEdit();
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-1 overflow-y-auto">
<div
class="flex min-h-0 flex-1 flex-col gap-4 md:gap-1 {isMobile.current
? 'transition-[opacity,height] duration-200 ease-out'
: ''} {isMobile.current && !isExpandedMode ? 'opacity-0 !h-0' : ''}"
in:fade={{ duration: 200 }}
out:fade={{ duration: 200 }}
>
<SidebarNavigationActions
isExpandedMode={innerWidth > 768 ? isExpandedMode : true}
class="px-2"
bind:isSearchModeActive
bind:searchQuery
onSearchDeactivated={() => {
isSearchModeActive = false;
searchQuery = '';
}}
onSearchClick={() => {
isExpandedMode = true;
isSearchModeActive = true;
}}
onNewChat={() => {
if (isMobile.current) {
scheduleMobileCollapse();
}
}}
/>
{#if isExpandedMode || isOnMobile}
<SidebarNavigationConversationList
class="px-2"
{filteredConversations}
{currentChatId}
{isSearchModeActive}
{searchQuery}
onSelect={selectConversation}
onEdit={handleEditConversation}
onDelete={handleDeleteConversation}
onStop={handleStopGeneration}
/>
{/if}
</div>
</div>
</aside>
{/if}
<style>
aside {
@media (max-width: 768px) {
--size: 1.125rem;
}
}}
>
<Input
class="text-foreground"
placeholder="Enter a new name"
type="text"
bind:value={editedName}
/>
</DialogConfirmation>
}
@media (max-width: 768px) {
aside {
&:not(.is-expanded) {
pointer-events: none;
}
}
aside.is-expanded::before {
content: '';
position: fixed;
top: -0.5rem;
bottom: -0.25rem;
left: -0.5rem;
right: -0.5rem;
z-index: -1;
background: var(--background);
backdrop-filter: blur(1rem);
pointer-events: none;
}
}
</style>
@@ -1,39 +1,86 @@
<script lang="ts">
import { KeyboardShortcutInfo } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import type { Component } from 'svelte';
import { SearchInput } from '$lib/components/app';
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { SIDEBAR_ACTIONS_ITEMS } from '$lib/constants/ui';
import { Search } from '@lucide/svelte';
import { ActionIcon, KeyboardShortcutInfo, SearchInput } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import {
ICON_STRIP_TRANSITION_DURATION,
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
ROUTES,
SIDEBAR_ACTIONS_ITEMS
} from '$lib/constants';
import { isMobile } from '$lib/stores/viewport.svelte';
import { TooltipSide } from '$lib/enums';
import { fade } from 'svelte/transition';
import { circIn } from 'svelte/easing';
import { onMount } from 'svelte';
import type { Component } from 'svelte';
interface Props {
handleMobileSidebarItemClick: () => void;
class: string;
isExpandedMode: boolean;
isSearchModeActive: boolean;
searchQuery: string;
isCancelAlwaysVisible?: boolean;
onSearchDeactivated?: () => void;
onSearchClick?: () => void;
onNewChat?: () => void;
}
let {
handleMobileSidebarItemClick,
isSearchModeActive = $bindable(),
searchQuery = $bindable(),
isCancelAlwaysVisible = false,
onSearchDeactivated
class: className,
isExpandedMode = false,
isSearchModeActive = $bindable(false),
searchQuery = $bindable(''),
onSearchDeactivated,
onSearchClick,
onNewChat
}: Props = $props();
let initialized = $state(false);
let showIcons = $state(false);
let searchInputRef = $state<HTMLInputElement | null>(null);
const isOnMobile = $derived(isMobile.current);
$effect(() => {
if (isSearchModeActive && searchInputRef) {
searchInputRef.focus();
}
});
onMount(() => {
showIcons = true;
setTimeout(() => {
initialized = true;
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
});
function handleSearchModeDeactivate() {
isSearchModeActive = false;
searchQuery = '';
onSearchDeactivated?.();
}
export function activateSearch() {
isSearchModeActive = true;
// Focus after Svelte renders the input
queueMicrotask(() => searchInputRef?.focus());
function isItemActive(item: {
activeRouteId?: string;
activeRoutePrefix?: string;
activeUrlIncludes?: string;
}): boolean {
if (item.activeRouteId) {
return page.route.id === item.activeRouteId;
}
if (item.activeRoutePrefix) {
return !!page.route.id?.startsWith(item.activeRoutePrefix);
}
if (item.activeUrlIncludes) {
return page.url?.hash?.includes(item.activeUrlIncludes) ?? false;
}
return false;
}
</script>
@@ -41,56 +88,109 @@
<IconComponent class="h-4 w-4" />
{/snippet}
<div class="my-1 space-y-1">
{#if isSearchModeActive}
{#if isSearchModeActive}
<div class="px-4 my-2">
<SearchInput
bind:value={searchQuery}
bind:ref={searchInputRef}
onClose={handleSearchModeDeactivate}
onKeyDown={(e) => e.key === 'Escape' && handleSearchModeDeactivate()}
placeholder="Search conversations..."
{isCancelAlwaysVisible}
/>
{:else}
{#each SIDEBAR_ACTIONS_ITEMS as item (item.route)}
{#if !item.route}
<Button
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100"
onclick={activateSearch}
variant="ghost"
>
<div class="flex items-center gap-2">
{@render itemIcon(item.icon)}
</div>
{:else if isExpandedMode || isOnMobile}
<div
class="{className} flex flex-col gap-5 md:gap-1 mt-2 md:mt-0 {!isExpandedMode && isOnMobile
? 'hidden pointer-events-none'
: ''}"
>
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const isActive = isItemActive(item)}
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
{@const itemHref = isSearchOnMobile ? ROUTES.SEARCH : item.route}
{@const itemOnClick = item.route
? () => {
onNewChat?.();
goto(item.route!);
}
: isSearchOnMobile
? undefined
: onSearchClick}
{@const itemTransition = {
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
{item.tooltip}
</div>
{#if showIcons}
<div transition:fade={itemTransition}>
<Button
class="w-full min-w-9 justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {isActive
? 'bg-accent text-accent-foreground'
: ''}"
href={itemHref}
onclick={itemOnClick}
variant="ghost"
size="default"
>
<span class="flex min-w-0 items-center px-0.5 gap-2">
{@render itemIcon(item.icon)}
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
{:else}
<Button
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {(item.activeRouteId &&
page.route.id === item.activeRouteId) ||
(item.activeRoutePrefix && page.route.id?.startsWith(item.activeRoutePrefix))
? 'bg-accent text-accent-foreground'
: ''}"
href={item.route}
onclick={handleMobileSidebarItemClick}
variant="ghost"
>
<div class="flex items-center gap-2">
{@render itemIcon(item.icon)}
{#if showIcons}
<span
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
out:fade={{ duration: 100 }}
class="min-w-0 truncate">{item.tooltip}</span
>
{/if}
</span>
{item.tooltip}
</div>
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
{#if item.keys}
<KeyboardShortcutInfo keys={item.keys} />
{/if}
</Button>
</div>
{/if}
{/each}
{/if}
</div>
</div>
{:else}
<div class="{className} flex-col gap-1 hidden md:flex">
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
{@const isActive = isItemActive(item)}
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
{@const itemOnClick = item.route
? () => {
onNewChat?.();
goto(item.route!);
}
: isSearchOnMobile
? undefined
: onSearchClick}
{@const itemTransition = {
duration: ICON_STRIP_TRANSITION_DURATION,
delay: !initialized
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
: 0,
easing: circIn
}}
{#if showIcons}
<div transition:fade={itemTransition}>
<ActionIcon
icon={item.icon}
tooltip={item.tooltip}
tooltipSide={TooltipSide.RIGHT}
size="lg"
iconSize="h-4 w-4"
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
? 'bg-accent text-accent-foreground'
: ''}"
onclick={itemOnClick}
/>
</div>
{/if}
{/each}
</div>
{/if}
@@ -0,0 +1,135 @@
<script lang="ts">
import { Pin } from '@lucide/svelte';
import { buildConversationTree } from '$lib/stores/conversations.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
import SidebarNavigationSearchResults from './SidebarNavigationSearchResults.svelte';
interface Props {
class: string;
filteredConversations: DatabaseConversation[];
currentChatId: string | undefined;
isSearchModeActive: boolean;
searchQuery: string;
onSelect: (id: string) => void;
onEdit: (id: string) => void;
onDelete: (id: string) => void;
onStop: (id: string) => void;
}
let {
class: className,
filteredConversations,
currentChatId,
isSearchModeActive,
searchQuery,
onSelect,
onEdit,
onDelete,
onStop
}: Props = $props();
let conversationTree = $derived(buildConversationTree(filteredConversations));
let pinnedConversations = $derived(
conversationTree.filter(({ conversation }) => conversation.pinned)
);
let unpinnedConversations = $derived(
conversationTree.filter(({ conversation }) => !conversation.pinned)
);
const recentEmptyMessage = $derived(
searchQuery.length > 0 ? 'No results found' : 'No conversations yet'
);
</script>
{#if isSearchModeActive}
<SidebarNavigationSearchResults
class={className}
{searchQuery}
{filteredConversations}
{currentChatId}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
{:else}
{#if pinnedConversations.length > 0}
<div class="py-2 flex whitespace-nowrap {className}">
<div
class="text-muted-foreground inline-flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium gap-1"
>
<Pin class="h-3.5 w-3.5" />
<span>Pinned</span>
</div>
</div>
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1 {className}">
{#each pinnedConversations as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
</ul>
{/if}
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-2 whitespace-nowrap {className}">
{#if filteredConversations.length > 0}
<div
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
>
Recent conversations
</div>
{/if}
<div class="min-h-0 flex-1 md:overflow-y-auto">
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1">
{#each unpinnedConversations as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
{#if unpinnedConversations.length === 0}
<li class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{recentEmptyMessage}
</p>
</li>
{/if}
</ul>
</div>
</div>
{/if}
@@ -16,4 +16,6 @@
}: Props = $props();
</script>
<SearchInput bind:value {placeholder} {onInput} class="mb-4 {className}" />
<div class="mb-4 px-2 {className}">
<SearchInput bind:value {placeholder} {onInput} />
</div>
@@ -0,0 +1,76 @@
<script lang="ts">
import { buildConversationTree } from '$lib/stores/conversations.svelte';
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
interface Props {
class?: string;
searchQuery: string;
filteredConversations: DatabaseConversation[];
currentChatId: string | undefined;
onSelect: (id: string) => void;
onEdit: (id: string) => void;
onDelete: (id: string) => void;
onStop: (id: string) => void;
}
let {
class: className = '',
searchQuery,
filteredConversations,
currentChatId,
onSelect,
onEdit,
onDelete,
onStop
}: Props = $props();
let tree = $derived(buildConversationTree(filteredConversations));
const hasQuery = $derived(searchQuery.trim().length > 0);
const showHeader = $derived(hasQuery && filteredConversations.length > 0);
const emptyMessage = $derived(hasQuery ? 'No results found' : 'Start typing to see results');
</script>
<div class="flex min-h-0 flex-1 flex-col gap-2 whitespace-nowrap {className}">
{#if showHeader}
<div
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
>
Search results
</div>
{/if}
<div class="min-h-0 flex-1 overflow-y-auto">
<ul class="flex w-full min-w-0 flex-col gap-1">
{#each tree as { conversation, depth } (conversation.id)}
<li class="group/item relative mb-1 p-0">
<SidebarNavigationConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId,
pinned: conversation.pinned
}}
{depth}
isActive={currentChatId === conversation.id}
{onSelect}
{onEdit}
{onDelete}
{onStop}
/>
</li>
{/each}
{#if tree.length === 0}
<li class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{emptyMessage}
</p>
</li>
{/if}
</ul>
</div>
</div>
@@ -63,15 +63,6 @@ export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svel
* ```
*/
export { default as DropdownMenuActions } from './DropdownMenuActions.svelte';
/**
* **DesktopIconStrip** - Fixed icon strip for desktop sidebar
*
* Vertical icon strip shown on desktop when the sidebar is collapsed.
* Contains navigation shortcuts for new chat, search, MCP, import/export, and settings.
*/
export { default as DesktopIconStrip } from './DesktopIconStrip.svelte';
/**
* **SidebarNavigation** - Sidebar with actions menu and conversation list
*
@@ -115,13 +106,6 @@ export { default as DesktopIconStrip } from './DesktopIconStrip.svelte';
*/
export { default as SidebarNavigation } from './SidebarNavigation/SidebarNavigation.svelte';
/**
* Action buttons for sidebar header. Contains new chat button, settings button,
* and delete all conversations button. Manages dialog states for settings and
* delete confirmation.
*/
export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte';
/**
* Single conversation item in sidebar. Displays conversation title (truncated),
* last message preview, and timestamp. Shows context menu on right-click with
@@ -130,6 +114,58 @@ export { default as SidebarNavigationActions } from './SidebarNavigation/Sidebar
*/
export { default as SidebarNavigationConversationItem } from './SidebarNavigation/SidebarNavigationConversationItem.svelte';
/**
* **SidebarNavigationConversationList** - Grouped conversation list
*
* Pure-presentational list of conversations. Splits items into a Pinned
* section (when not in search mode) and a Recent Conversations / Search
* Results section with the unpinned items. Item selection, edit, delete,
* and stop-generation are delegated to the caller via callbacks.
*
* @example
* ```svelte
* <SidebarNavigationConversationList
* {filteredConversations}
* {currentChatId}
* {isSearchModeActive}
* {searchQuery}
* onSelect={...}
* onEdit={...}
* onDelete={...}
* onStop={...}
* />
* ```
*/
export { default as SidebarNavigationConversationList } from './SidebarNavigation/SidebarNavigationConversationList.svelte';
export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte';
/**
* **SidebarNavigationSearchResults** - Filtered conversation list for search.
*
* Pure-presentational rendering of the search-mode subtree: "Search results"
* header, the matching items rendered through {@link SidebarNavigationConversationItem},
* and contextual empty-state messages. Used both inline inside
* {@link SidebarNavigationConversationList} (when search mode is active in the
* sidebar) and as the body of the mobile `/search` route.
*
* The caller is expected to provide an already-filtered list via
* `filteredConversations` and a `searchQuery` for the empty-state messages.
*
* @example
* ```svelte
* <SidebarNavigationSearchResults
* {searchQuery}
* {filteredConversations}
* {currentChatId}
* onSelect={...}
* onEdit={...}
* onDelete={...}
* onStop={...}
* />
* ```
*/
export { default as SidebarNavigationSearchResults } from './SidebarNavigation/SidebarNavigationSearchResults.svelte';
/**
* Search input for filtering conversations in sidebar. Filters conversation
* list by title as user types. Shows clear button when query is not empty.
@@ -126,10 +126,7 @@
});
</script>
<div
class="mx-auto flex h-full max-h-[100dvh] w-full flex-col overflow-y-auto md:pl-8"
in:fade={{ duration: 150 }}
>
<div class="mx-auto flex h-full w-full flex-col md:pl-8" in:fade={{ duration: 150 }}>
<div class="flex flex-1 flex-col gap-4 md:flex-row">
<SettingsChatDesktopSidebar
sections={SETTINGS_CHAT_SECTIONS}
@@ -12,11 +12,13 @@
let { sections, isActive, getHref, onSectionChange }: Props = $props();
</script>
<div class="sticky top-0 hidden w-64 flex-col self-start bg-background pt-10 pb-4 md:flex">
<div class="flex items-center gap-2 pb-10">
<Settings class="h-6 w-6" />
<h1 class="text-2xl font-semibold">Settings</h1>
<div class="sticky top-2 hidden w-64 flex-col self-start bg-background py-4 md:flex gap-6">
<div class="flex items-center gap-2 py-2">
<Settings class="h-5 w-5 md:h-6 md:w-6" />
<h1 class="text-xl font-semibold md:text-2xl">Settings</h1>
</div>
<nav class="space-y-1">
{#each sections as section (section.title)}
{#if getHref}
@@ -1,17 +1,19 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { X, Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { toolsStore } from '$lib/stores/tools.svelte';
import { McpServerCard, McpServerCardSkeleton } from '$lib/components/app/mcp';
import { ActionIcon, McpServerCard, McpServerCardSkeleton } from '$lib/components/app';
import { DialogMcpServerAddNew } from '$lib/components/app/dialogs';
import { HealthCheckStatus } from '$lib/enums';
import { ROUTES } from '$lib/constants';
import { fade } from 'svelte/transition';
import { onMount } from 'svelte';
import McpLogo from '../mcp/McpLogo.svelte';
import { browser } from '$app/environment';
import { page } from '$app/state';
import { replaceState } from '$app/navigation';
import { goto, replaceState } from '$app/navigation';
interface Props {
class?: string;
@@ -24,6 +26,24 @@
let initialLoadComplete = $state(false);
let isAddingServer = $state(false);
let previousRouteId = $state<string | null>(null);
$effect(() => {
const currentId = page.route.id;
return () => {
previousRouteId = currentId;
};
});
function handleClose() {
const prevIsMcpServers = previousRouteId === '/mcp-servers';
if (browser && window.history.length > 1 && !prevIsMcpServers) {
history.back();
} else {
goto(ROUTES.START);
}
}
onMount(() => {
if (page.url.searchParams.has('add')) {
isAddingServer = true;
@@ -54,15 +74,26 @@
});
</script>
<div in:fade={{ duration: 150 }} class="h-full max-h-[100dvh] overflow-y-auto">
<div class="flex items-center gap-2 p-4 md:absolute md:top-8 md:left-8 md:px-0 md:py-2">
<McpLogo class="h-5 w-5 md:h-6 md:w-6" />
<h1 class="text-xl font-semibold md:text-2xl">MCP Servers</h1>
<div in:fade={{ duration: 150 }}>
<div class="fixed top-4.5 right-4 z-50 md:hidden">
<ActionIcon icon={X} tooltip="Close" onclick={handleClose} />
</div>
<div class="sticky top-0 z-10 mt-4 flex items-start gap-4 p-4 md:justify-end md:px-8">
<Button variant="outline" size="sm" class="shrink-0" onclick={() => (isAddingServer = true)}>
<div
class="sticky top-0 z-10 mt-4 mb-2 flex items-start gap-4 md:p-4 p-0 px-4 md:justify-between md:px-8"
>
<div class="flex items-center gap-2">
<McpLogo class="h-5 w-5 md:h-6 md:w-6" />
<h1 class="text-lg font-semibold md:text-2xl">MCP Servers</h1>
</div>
<Button
variant="outline"
size="lg"
class="shrink-0 fixed md:static bottom-6 right-6"
onclick={() => (isAddingServer = true)}
>
<Plus class="h-4 w-4" />
Add New Server
@@ -20,7 +20,7 @@
size: {
default: 'h-9 px-4 py-2 has-[>svg]:px-3',
sm: 'h-8 gap-1.5 rounded-md px-3 has-[>svg]:px-2.5',
lg: 'h-10 rounded-md px-6 has-[>svg]:px-4',
lg: 'h-10 rounded-lg px-6 has-[>svg]:px-4',
'icon-lg': 'size-10',
icon: 'size-9',
'icon-sm': 'size-5 rounded-sm'
@@ -1,7 +0,0 @@
export const SIDEBAR_COOKIE_NAME = 'sidebar:state';
export const SIDEBAR_COOKIE_MAX_AGE = 60 * 60 * 24 * 7;
export const SIDEBAR_MIN_WIDTH = '18rem';
export const SIDEBAR_MAX_WIDTH = '32rem';
export const SIDEBAR_WIDTH_MOBILE = '18rem';
export const SIDEBAR_WIDTH_ICON = '3rem';
export const SIDEBAR_KEYBOARD_SHORTCUT = 'b';
@@ -1,79 +0,0 @@
import { isMobile } from '$lib/stores/viewport.svelte.js';
import { getContext, setContext } from 'svelte';
import { SIDEBAR_KEYBOARD_SHORTCUT, SIDEBAR_MIN_WIDTH } from './constants.js';
type Getter<T> = () => T;
export type SidebarStateProps = {
/**
* A getter function that returns the current open state of the sidebar.
* We use a getter function here to support `bind:open` on the `Sidebar.Provider`
* component.
*/
open: Getter<boolean>;
/**
* A function that sets the open state of the sidebar. To support `bind:open`, we need
* a source of truth for changing the open state to ensure it will be synced throughout
* the sub-components and any `bind:` references.
*/
setOpen: (open: boolean) => void;
};
class SidebarState {
readonly props: SidebarStateProps;
open = $derived.by(() => this.props.open());
openMobile = $state(false);
sidebarWidth = $state(SIDEBAR_MIN_WIDTH);
isResizing = $state(false);
setOpen: SidebarStateProps['setOpen'];
state = $derived.by(() => (this.open ? 'expanded' : 'collapsed'));
constructor(props: SidebarStateProps) {
this.setOpen = props.setOpen;
this.props = props;
}
// Convenience getter for checking if the sidebar is mobile
// without this, we would need to use `sidebar.isMobile.current` everywhere
get isMobile() {
return isMobile.current;
}
// Event handler to apply to the `<svelte:window>`
handleShortcutKeydown = (e: KeyboardEvent) => {
if (e.key === SIDEBAR_KEYBOARD_SHORTCUT && (e.metaKey || e.ctrlKey)) {
e.preventDefault();
this.toggle();
}
};
setOpenMobile = (value: boolean) => {
this.openMobile = value;
};
toggle = () => {
this.setOpen(!this.open);
};
}
const SYMBOL_KEY = 'scn-sidebar';
/**
* Instantiates a new `SidebarState` instance and sets it in the context.
*
* @param props The constructor props for the `SidebarState` class.
* @returns The `SidebarState` instance.
*/
export function setSidebar(props: SidebarStateProps): SidebarState {
return setContext(Symbol.for(SYMBOL_KEY), new SidebarState(props));
}
/**
* Retrieves the `SidebarState` instance from the context. This is a class instance,
* so you cannot destructure it.
* @returns The `SidebarState` instance.
*/
export function useSidebar(): SidebarState {
return getContext(Symbol.for(SYMBOL_KEY));
}
@@ -1,75 +0,0 @@
import { useSidebar } from './context.svelte.js';
import Content from './sidebar-content.svelte';
import Footer from './sidebar-footer.svelte';
import GroupAction from './sidebar-group-action.svelte';
import GroupContent from './sidebar-group-content.svelte';
import GroupLabel from './sidebar-group-label.svelte';
import Group from './sidebar-group.svelte';
import Header from './sidebar-header.svelte';
import Input from './sidebar-input.svelte';
import Inset from './sidebar-inset.svelte';
import MenuAction from './sidebar-menu-action.svelte';
import MenuBadge from './sidebar-menu-badge.svelte';
import MenuButton from './sidebar-menu-button.svelte';
import MenuItem from './sidebar-menu-item.svelte';
import MenuSkeleton from './sidebar-menu-skeleton.svelte';
import MenuSubButton from './sidebar-menu-sub-button.svelte';
import MenuSubItem from './sidebar-menu-sub-item.svelte';
import MenuSub from './sidebar-menu-sub.svelte';
import Menu from './sidebar-menu.svelte';
import Provider from './sidebar-provider.svelte';
import Rail from './sidebar-rail.svelte';
import Separator from './sidebar-separator.svelte';
import Trigger from './sidebar-trigger.svelte';
import Root from './sidebar.svelte';
export {
Content,
Footer,
Group,
GroupAction,
GroupContent,
GroupLabel,
Header,
Input,
Inset,
Menu,
MenuAction,
MenuBadge,
MenuButton,
MenuItem,
MenuSkeleton,
MenuSub,
MenuSubButton,
MenuSubItem,
Provider,
Rail,
Root,
Separator,
//
Root as Sidebar,
Content as SidebarContent,
Footer as SidebarFooter,
Group as SidebarGroup,
GroupAction as SidebarGroupAction,
GroupContent as SidebarGroupContent,
GroupLabel as SidebarGroupLabel,
Header as SidebarHeader,
Input as SidebarInput,
Inset as SidebarInset,
Menu as SidebarMenu,
MenuAction as SidebarMenuAction,
MenuBadge as SidebarMenuBadge,
MenuButton as SidebarMenuButton,
MenuItem as SidebarMenuItem,
MenuSkeleton as SidebarMenuSkeleton,
MenuSub as SidebarMenuSub,
MenuSubButton as SidebarMenuSubButton,
MenuSubItem as SidebarMenuSubItem,
Provider as SidebarProvider,
Rail as SidebarRail,
Separator as SidebarSeparator,
Trigger as SidebarTrigger,
Trigger,
useSidebar
};
@@ -1,24 +0,0 @@
<script lang="ts">
import type { HTMLAttributes } from 'svelte/elements';
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
</script>
<div
bind:this={ref}
data-slot="sidebar-content"
data-sidebar="content"
class={cn(
'flex min-h-0 flex-1 flex-col gap-2 overflow-auto group-data-[collapsible=icon]:overflow-hidden',
className
)}
{...restProps}
>
{@render children?.()}
</div>
@@ -1,21 +0,0 @@
<script lang="ts">
import type { HTMLAttributes } from 'svelte/elements';
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
</script>
<div
bind:this={ref}
data-slot="sidebar-footer"
data-sidebar="footer"
class={cn('flex flex-col gap-2 p-3', className)}
{...restProps}
>
{@render children?.()}
</div>
@@ -1,36 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { Snippet } from 'svelte';
import type { HTMLButtonAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
class: className,
children,
child,
...restProps
}: WithElementRef<HTMLButtonAttributes> & {
child?: Snippet<[{ props: Record<string, unknown> }]>;
} = $props();
const mergedProps = $derived({
class: cn(
'text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground outline-hidden absolute right-3 top-3.5 flex aspect-square w-5 items-center justify-center rounded-md p-0 transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0',
// Increases the hit area of the button on mobile.
'after:absolute after:-inset-2 md:after:hidden',
'group-data-[collapsible=icon]:hidden',
className
),
'data-slot': 'sidebar-group-action',
'data-sidebar': 'group-action',
...restProps
});
</script>
{#if child}
{@render child({ props: mergedProps })}
{:else}
<button bind:this={ref} {...mergedProps}>
{@render children?.()}
</button>
{/if}
@@ -1,21 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { HTMLAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLDivElement>> = $props();
</script>
<div
bind:this={ref}
data-slot="sidebar-group-content"
data-sidebar="group-content"
class={cn('w-full text-sm', className)}
{...restProps}
>
{@render children?.()}
</div>
@@ -1,34 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { Snippet } from 'svelte';
import type { HTMLAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
children,
child,
class: className,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> & {
child?: Snippet<[{ props: Record<string, unknown> }]>;
} = $props();
const mergedProps = $derived({
class: cn(
'text-sidebar-foreground/70 ring-sidebar-ring outline-hidden flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium transition-[margin,opacity] duration-200 ease-linear focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0',
'group-data-[collapsible=icon]:-mt-8 group-data-[collapsible=icon]:opacity-0',
className
),
'data-slot': 'sidebar-group-label',
'data-sidebar': 'group-label',
...restProps
});
</script>
{#if child}
{@render child({ props: mergedProps })}
{:else}
<div bind:this={ref} {...mergedProps}>
{@render children?.()}
</div>
{/if}
@@ -1,21 +0,0 @@
<script lang="ts">
import type { HTMLAttributes } from 'svelte/elements';
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
</script>
<div
bind:this={ref}
data-slot="sidebar-group"
data-sidebar="group"
class={cn('relative flex w-full min-w-0 flex-col p-2', className)}
{...restProps}
>
{@render children?.()}
</div>
@@ -1,21 +0,0 @@
<script lang="ts">
import type { HTMLAttributes } from 'svelte/elements';
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
</script>
<div
bind:this={ref}
data-slot="sidebar-header"
data-sidebar="header"
class={cn('flex flex-col gap-2 p-2', className)}
{...restProps}
>
{@render children?.()}
</div>
@@ -1,21 +0,0 @@
<script lang="ts">
import type { ComponentProps } from 'svelte';
import { Input } from '$lib/components/ui/input/index.js';
import { cn } from '$lib/components/ui/utils.js';
let {
ref = $bindable(null),
value = $bindable(''),
class: className,
...restProps
}: ComponentProps<typeof Input> = $props();
</script>
<Input
bind:ref
bind:value
data-slot="sidebar-input"
data-sidebar="input"
class={cn('h-8 w-full bg-background shadow-none', className)}
{...restProps}
/>
@@ -1,24 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { HTMLAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
</script>
<main
bind:this={ref}
data-slot="sidebar-inset"
class={cn(
'relative flex w-full flex-1 flex-col',
'md:peer-data-[variant=inset]:m-2 md:peer-data-[variant=inset]:ml-0 md:peer-data-[variant=inset]:rounded-xl md:peer-data-[variant=inset]:shadow-sm md:peer-data-[variant=inset]:peer-data-[state=collapsed]:ml-2',
className
)}
{...restProps}
>
{@render children?.()}
</main>
@@ -1,43 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { Snippet } from 'svelte';
import type { HTMLButtonAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
class: className,
showOnHover = false,
children,
child,
...restProps
}: WithElementRef<HTMLButtonAttributes> & {
child?: Snippet<[{ props: Record<string, unknown> }]>;
showOnHover?: boolean;
} = $props();
const mergedProps = $derived({
class: cn(
'text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground peer-hover/menu-button:text-sidebar-accent-foreground outline-hidden absolute right-1 top-1.5 flex aspect-square w-5 items-center justify-center rounded-md p-0 transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0',
// Increases the hit area of the button on mobile.
'after:absolute after:-inset-2 md:after:hidden',
'peer-data-[size=sm]/menu-button:top-1',
'peer-data-[size=default]/menu-button:top-1.5',
'peer-data-[size=lg]/menu-button:top-2.5',
'group-data-[collapsible=icon]:hidden',
showOnHover &&
'peer-data-[active=true]/menu-button:text-sidebar-accent-foreground group-focus-within/menu-item:opacity-100 group-hover/menu-item:opacity-100 data-[state=open]:opacity-100 md:opacity-0',
className
),
'data-slot': 'sidebar-menu-action',
'data-sidebar': 'menu-action',
...restProps
});
</script>
{#if child}
{@render child({ props: mergedProps })}
{:else}
<button bind:this={ref} {...mergedProps}>
{@render children?.()}
</button>
{/if}
@@ -1,29 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { HTMLAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
</script>
<div
bind:this={ref}
data-slot="sidebar-menu-badge"
data-sidebar="menu-badge"
class={cn(
'pointer-events-none absolute right-1 flex h-5 min-w-5 items-center justify-center rounded-md px-1 text-xs font-medium text-sidebar-foreground tabular-nums select-none',
'peer-hover/menu-button:text-sidebar-accent-foreground peer-data-[active=true]/menu-button:text-sidebar-accent-foreground',
'peer-data-[size=sm]/menu-button:top-1',
'peer-data-[size=default]/menu-button:top-1.5',
'peer-data-[size=lg]/menu-button:top-2.5',
'group-data-[collapsible=icon]:hidden',
className
)}
{...restProps}
>
{@render children?.()}
</div>
@@ -1,106 +0,0 @@
<script lang="ts" module>
import { tv, type VariantProps } from 'tailwind-variants';
export const sidebarMenuButtonVariants = tv({
base: 'peer/menu-button outline-hidden ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground active:bg-sidebar-accent active:text-sidebar-accent-foreground group-has-data-[sidebar=menu-action]/menu-item:pr-8 data-[active=true]:bg-sidebar-accent data-[active=true]:text-sidebar-accent-foreground data-[state=open]:hover:bg-sidebar-accent data-[state=open]:hover:text-sidebar-accent-foreground group-data-[collapsible=icon]:size-8! group-data-[collapsible=icon]:p-2! flex w-full items-center gap-2 overflow-hidden rounded-md py-2 px-1 text-left text-sm transition-[width,height,padding] focus-visible:ring-2 disabled:pointer-events-none disabled:opacity-50 aria-disabled:pointer-events-none aria-disabled:opacity-50 data-[active=true]:font-medium [&>span:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0',
variants: {
variant: {
default: 'hover:bg-sidebar-accent hover:text-sidebar-accent-foreground',
outline:
'bg-background hover:bg-sidebar-accent hover:text-sidebar-accent-foreground shadow-[0_0_0_1px_var(--sidebar-border)] hover:shadow-[0_0_0_1px_var(--sidebar-accent)]'
},
size: {
default: 'h-8 text-sm',
sm: 'h-7 text-xs',
lg: 'group-data-[collapsible=icon]:p-0! h-12 text-sm'
}
},
defaultVariants: {
variant: 'default',
size: 'default'
}
});
export type SidebarMenuButtonVariant = VariantProps<typeof sidebarMenuButtonVariants>['variant'];
export type SidebarMenuButtonSize = VariantProps<typeof sidebarMenuButtonVariants>['size'];
</script>
<script lang="ts">
import * as Tooltip from '$lib/components/ui/tooltip/index.js';
import {
cn,
type WithElementRef,
type WithoutChildrenOrChild
} from '$lib/components/ui/utils.js';
import { mergeProps } from 'bits-ui';
import type { ComponentProps, Snippet } from 'svelte';
import type { HTMLAttributes } from 'svelte/elements';
import { useSidebar } from './context.svelte.js';
let {
ref = $bindable(null),
class: className,
children,
child,
variant = 'default',
size = 'default',
isActive = false,
tooltipContent,
tooltipContentProps,
...restProps
}: WithElementRef<HTMLAttributes<HTMLButtonElement>, HTMLButtonElement> & {
isActive?: boolean;
variant?: SidebarMenuButtonVariant;
size?: SidebarMenuButtonSize;
tooltipContent?: Snippet | string;
tooltipContentProps?: WithoutChildrenOrChild<ComponentProps<typeof Tooltip.Content>>;
child?: Snippet<[{ props: Record<string, unknown> }]>;
} = $props();
const sidebar = useSidebar();
const buttonProps = $derived({
class: cn(sidebarMenuButtonVariants({ variant, size }), className),
'data-slot': 'sidebar-menu-button',
'data-sidebar': 'menu-button',
'data-size': size,
'data-active': isActive,
...restProps
});
</script>
{#snippet Button({ props }: { props?: Record<string, unknown> })}
{@const mergedProps = mergeProps(buttonProps, props)}
{#if child}
{@render child({ props: mergedProps })}
{:else}
<button bind:this={ref} {...mergedProps}>
{@render children?.()}
</button>
{/if}
{/snippet}
{#if !tooltipContent}
{@render Button({})}
{:else}
<Tooltip.Root>
<Tooltip.Trigger>
{#snippet child({ props })}
{@render Button({ props })}
{/snippet}
</Tooltip.Trigger>
<Tooltip.Content
side="right"
align="center"
hidden={sidebar.state !== 'collapsed' || sidebar.isMobile}
{...tooltipContentProps}
>
{#if typeof tooltipContent === 'string'}
{tooltipContent}
{:else if tooltipContent}
{@render tooltipContent()}
{/if}
</Tooltip.Content>
</Tooltip.Root>
{/if}
@@ -1,21 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { HTMLAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
class: className,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLLIElement>, HTMLLIElement> = $props();
</script>
<li
bind:this={ref}
data-slot="sidebar-menu-item"
data-sidebar="menu-item"
class={cn('group/menu-item relative', className)}
{...restProps}
>
{@render children?.()}
</li>
@@ -1,36 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import { Skeleton } from '$lib/components/ui/skeleton/index.js';
import type { HTMLAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
class: className,
showIcon = false,
children,
...restProps
}: WithElementRef<HTMLAttributes<HTMLElement>> & {
showIcon?: boolean;
} = $props();
// Random width between 50% and 90%
const width = `${Math.floor(Math.random() * 40) + 50}%`;
</script>
<div
bind:this={ref}
data-slot="sidebar-menu-skeleton"
data-sidebar="menu-skeleton"
class={cn('flex h-8 items-center gap-2 rounded-md px-2', className)}
{...restProps}
>
{#if showIcon}
<Skeleton class="size-4 rounded-md" data-sidebar="menu-skeleton-icon" />
{/if}
<Skeleton
class="h-4 max-w-(--skeleton-width) flex-1"
data-sidebar="menu-skeleton-text"
style="--skeleton-width: {width};"
/>
{@render children?.()}
</div>
@@ -1,43 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { Snippet } from 'svelte';
import type { HTMLAnchorAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
children,
child,
class: className,
size = 'md',
isActive = false,
...restProps
}: WithElementRef<HTMLAnchorAttributes> & {
child?: Snippet<[{ props: Record<string, unknown> }]>;
size?: 'sm' | 'md';
isActive?: boolean;
} = $props();
const mergedProps = $derived({
class: cn(
'text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground active:bg-sidebar-accent active:text-sidebar-accent-foreground [&>svg]:text-sidebar-accent-foreground outline-hidden flex h-7 min-w-0 -translate-x-px items-center gap-2 overflow-hidden rounded-md px-2 focus-visible:ring-2 disabled:pointer-events-none disabled:opacity-50 aria-disabled:pointer-events-none aria-disabled:opacity-50 [&>span:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0',
'data-[active=true]:bg-sidebar-accent data-[active=true]:text-sidebar-accent-foreground',
size === 'sm' && 'text-xs',
size === 'md' && 'text-sm',
'group-data-[collapsible=icon]:hidden',
className
),
'data-slot': 'sidebar-menu-sub-button',
'data-sidebar': 'menu-sub-button',
'data-size': size,
'data-active': isActive,
...restProps
});
</script>
{#if child}
{@render child({ props: mergedProps })}
{:else}
<a bind:this={ref} {...mergedProps}>
{@render children?.()}
</a>
{/if}
@@ -1,21 +0,0 @@
<script lang="ts">
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
import type { HTMLAttributes } from 'svelte/elements';
let {
ref = $bindable(null),
children,
class: className,
...restProps
}: WithElementRef<HTMLAttributes<HTMLLIElement>> = $props();
</script>
<li
bind:this={ref}
data-slot="sidebar-menu-sub-item"
data-sidebar="menu-sub-item"
class={cn('group/menu-sub-item relative', className)}
{...restProps}
>
{@render children?.()}
</li>

Some files were not shown because too many files have changed in this diff Show More