mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-24 14:47:39 +02:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1191758c5d | |||
| 00139b660b | |||
| ef9c13d4c2 | |||
| 88636e178f | |||
| ac4105d68b | |||
| be4a6a63eb | |||
| 72a9269172 | |||
| 92e854ab83 | |||
| c5606364b2 | |||
| 0eb874d374 | |||
| 75ad0b23ed | |||
| c926ad0985 | |||
| a3900a6694 | |||
| 7c908502ea | |||
| 035cd8f9a6 | |||
| 73618f27a8 | |||
| 23ee8797e1 |
+1
-1
@@ -10,7 +10,7 @@
|
||||
# ggml-org/ggml-rpc : rgerganov
|
||||
# ggml-org/ggml-sycl : arthw
|
||||
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
|
||||
# ggml-org/ggml-webgpu : reeselevine
|
||||
# ggml-org/ggml-webgpu : reeselevine, yomaytk
|
||||
# ggml-org/ggml-zdnn : taronaeo
|
||||
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
|
||||
# ggml-org/llama-mtmd : ngxson
|
||||
|
||||
@@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
|
||||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
|
||||
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
|
||||
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
|
||||
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
|
||||
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
|
||||
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
|
||||
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
|
||||
|
||||
+7
-4
@@ -301,6 +301,8 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
const common_download_opts & opts) {
|
||||
handle_model_result result;
|
||||
|
||||
// TODO @ngxson : refactor this into a new common_model_download_context
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
@@ -396,7 +398,7 @@ static bool parse_bool_value(const std::string & value) {
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
@@ -407,9 +409,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex,
|
||||
opts.skip_download = params.skip_download;
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
opts.preset_only = handle_params.preset_only;
|
||||
|
||||
if (callback) {
|
||||
opts.callback = callback;
|
||||
if (handle_params.callback) {
|
||||
opts.callback = handle_params.callback;
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
@@ -596,7 +599,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
common_params_handle_models(params, ctx_arg.ex);
|
||||
common_params_handle_models(params, ctx_arg.ex, {});
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
|
||||
+6
-1
@@ -130,6 +130,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/18163
|
||||
void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
|
||||
struct common_params_handle_models_params {
|
||||
common_download_callback * callback = nullptr;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
};
|
||||
|
||||
// populate model paths (main model, mmproj, etc) from -hf if necessary
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
@@ -137,7 +142,7 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
common_download_callback * callback = nullptr);
|
||||
const common_params_handle_models_params & handle_params);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
+103
-53
@@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
|
||||
return text;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
|
||||
if (delims.empty() || prompt.empty()) {
|
||||
return {};
|
||||
common_chat_role common_chat_role_from_string(const std::string & role) {
|
||||
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
|
||||
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
|
||||
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
|
||||
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
|
||||
return COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
|
||||
const char * common_chat_role_to_string(common_chat_role role) {
|
||||
switch (role) {
|
||||
case COMMON_CHAT_ROLE_SYSTEM: return "system";
|
||||
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
|
||||
case COMMON_CHAT_ROLE_USER: return "user";
|
||||
case COMMON_CHAT_ROLE_TOOL: return "tool";
|
||||
case COMMON_CHAT_ROLE_UNKNOWN: return "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
json common_chat_msg_delimiters::to_json() const {
|
||||
json result = json::array();
|
||||
for (const auto & d : delimiters) {
|
||||
result.push_back({
|
||||
{ "role", common_chat_role_to_string(d.role) },
|
||||
{ "delimiter", d.delimiter },
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
|
||||
common_chat_msg_delimiters result;
|
||||
|
||||
if (!delimiters.is_array()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
|
||||
std::vector<std::string> all_delims;
|
||||
std::vector<common_peg_parser> tagged_messages;
|
||||
|
||||
all_delims.reserve(delims.size());
|
||||
tagged_messages.reserve(delims.size());
|
||||
for (const auto & d : delims) {
|
||||
all_delims.push_back(d.delimiter);
|
||||
result.delimiters.reserve(delimiters.size());
|
||||
for (const auto & d : delimiters) {
|
||||
if (!d.is_object()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto any_delim = p.until_one_of(all_delims);
|
||||
for (const auto & d : delims) {
|
||||
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
|
||||
}
|
||||
|
||||
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(prompt);
|
||||
const auto result = parser.parse(ctx);
|
||||
if (!result.success()) {
|
||||
return {};
|
||||
result.delimiters.push_back({
|
||||
common_chat_role_from_string(d.value("role", std::string())),
|
||||
d.value("delimiter", std::string()),
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
if (!node.tag.empty()) {
|
||||
spans.push_back({ node.tag, node.start, node.end - node.start });
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
|
||||
for (auto & d : delimiters) {
|
||||
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
|
||||
std::vector<std::pair<common_chat_role, size_t>> matches;
|
||||
|
||||
auto skip = skips.begin();
|
||||
for (size_t i = 0; i < tokens.size();) {
|
||||
if (skip != skips.end() && i == skip->first) {
|
||||
i += skip->second;
|
||||
++skip;
|
||||
continue;
|
||||
}
|
||||
});
|
||||
for (const auto & d : delimiters) {
|
||||
if (i + d.tokens.size() > tokens.size()) {
|
||||
continue;
|
||||
}
|
||||
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
|
||||
matches.emplace_back(d.role, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
|
||||
|
||||
common_chat_msg_spans spans;
|
||||
for (size_t i = 0; i + 1 < matches.size(); i++) {
|
||||
const auto & curr = matches[i];
|
||||
const auto & next = matches[i + 1];
|
||||
spans.add(curr.first, curr.second, next.second - curr.second);
|
||||
}
|
||||
|
||||
return spans;
|
||||
}
|
||||
@@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.message_spans = common_chat_split_by_role(prompt, {
|
||||
{ "assistant", "<|start|>assistant" },
|
||||
{ "user", "<|start|>user" },
|
||||
{ "system", "<|start|>developer" },
|
||||
{ "system", "<|start|>system" },
|
||||
{ "tool", "<|start|>functions" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
@@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "user", "<|turn>user\n" },
|
||||
{ "assistant", "<|turn>model\n" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
@@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
|
||||
RESULT_START, RESULT_END,
|
||||
};
|
||||
|
||||
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
|
||||
// Declare per-role message delimiters. Tool results are rendered with the
|
||||
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
|
||||
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "assistant", GEN_PREFIX },
|
||||
{ "user", TURN_START + USER },
|
||||
{ "tool", TURN_START + SYSTEM + RESULT_START },
|
||||
{ "system", TURN_START + SYSTEM },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
|
||||
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
|
||||
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
@@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
common_chat_msg_delimiters delimiters;
|
||||
if (!autoparser.assistant_start.empty()) {
|
||||
delimiters.push_back({ "assistant", autoparser.assistant_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
|
||||
}
|
||||
if (!autoparser.user_start.empty()) {
|
||||
delimiters.push_back({ "user", autoparser.user_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
|
||||
}
|
||||
|
||||
if (!delimiters.empty()) {
|
||||
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
|
||||
}
|
||||
auto_params.message_delimiters = std::move(delimiters);
|
||||
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
|
||||
+65
-6
@@ -143,15 +143,75 @@ struct common_chat_msg_diff {
|
||||
}
|
||||
};
|
||||
|
||||
enum common_chat_role {
|
||||
COMMON_CHAT_ROLE_UNKNOWN,
|
||||
COMMON_CHAT_ROLE_SYSTEM,
|
||||
COMMON_CHAT_ROLE_ASSISTANT,
|
||||
COMMON_CHAT_ROLE_USER,
|
||||
COMMON_CHAT_ROLE_TOOL
|
||||
};
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role);
|
||||
const char * common_chat_role_to_string(common_chat_role role);
|
||||
|
||||
struct common_chat_msg_span {
|
||||
std::string role;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::size_t pos = 0;
|
||||
std::size_t len = 0;
|
||||
|
||||
bool valid() const {
|
||||
return role != COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_spans {
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
|
||||
void add(common_chat_role role, size_t pos, size_t len) {
|
||||
spans.push_back({ role, pos, len });
|
||||
}
|
||||
|
||||
bool is_user_start(int32_t pos) const {
|
||||
for (auto it = spans.begin(); it != spans.end(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t last_user_message_pos() const {
|
||||
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER) {
|
||||
return (int32_t) it->pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiter {
|
||||
std::string role;
|
||||
std::string delimiter;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string delimiter;
|
||||
llama_tokens tokens = {};
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiters {
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
|
||||
common_chat_msg_delimiters() = default;
|
||||
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
|
||||
|
||||
void add(common_chat_role role, const std::string & delimiter) {
|
||||
delimiters.push_back({ role, delimiter });
|
||||
}
|
||||
|
||||
void tokenize(const llama_vocab * vocab);
|
||||
|
||||
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
|
||||
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
|
||||
|
||||
nlohmann::ordered_json to_json() const;
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -219,7 +279,7 @@ struct common_chat_params {
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
std::vector<common_chat_msg_span> message_spans;
|
||||
common_chat_msg_delimiters message_delimiters;
|
||||
};
|
||||
|
||||
// per-message parsing syntax
|
||||
@@ -325,5 +385,4 @@ struct common_chat_prompt_preset {
|
||||
|
||||
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
|
||||
|
||||
+1
-1
@@ -609,7 +609,7 @@ struct common_params {
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
|
||||
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
|
||||
+3
-1
@@ -799,6 +799,7 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
|
||||
bool download_mmproj = opts.download_mmproj;
|
||||
bool download_mtp = opts.download_mtp;
|
||||
bool preset_only = opts.preset_only;
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
@@ -806,7 +807,8 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini exists, only download that file alone
|
||||
tasks.push_back({hf.preset.url, hf.preset.local_path});
|
||||
} else {
|
||||
} else if (!preset_only) {
|
||||
// only add other files if we're NOT in preset-only mode (normal run, non-router)
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
|
||||
@@ -55,6 +55,7 @@ struct common_download_opts {
|
||||
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
|
||||
bool download_mmproj = false;
|
||||
bool download_mtp = false;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
common_download_callback * callback = nullptr;
|
||||
};
|
||||
|
||||
|
||||
@@ -96,6 +96,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"GraniteMoeHybridForCausalLM": "granite",
|
||||
"GraniteMoeSharedForCausalLM": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"Grok1ForCausalLM": "grok",
|
||||
"GrokForCausalLM": "grok",
|
||||
"GroveMoeForCausalLM": "grovemoe",
|
||||
@@ -123,6 +124,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"LLaDAModelLM": "llada",
|
||||
"LLaMAForCausalLM": "llama",
|
||||
"Lfm25AudioTokenizer": "lfm2",
|
||||
"Lfm2BidirectionalModel": "lfm2",
|
||||
"Lfm2ForCausalLM": "lfm2",
|
||||
"Lfm2Model": "lfm2",
|
||||
"Lfm2MoeForCausalLM": "lfm2",
|
||||
@@ -261,6 +263,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"GlmasrModel": "ultravox",
|
||||
"Granite4VisionForConditionalGeneration": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"HunYuanVLForConditionalGeneration": "hunyuan",
|
||||
"Idefics3ForConditionalGeneration": "smolvlm",
|
||||
"InternVisionModel": "internvl",
|
||||
|
||||
@@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
|
||||
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
|
||||
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
|
||||
has_vision_encoder = False
|
||||
has_audio_encoder = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
assert self.hparams_audio is not None
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Add feature_layer if present in encoder config
|
||||
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
|
||||
self.gguf_writer.add_audio_feature_layers(feature_layers)
|
||||
logger.info(f"gguf: audio feature_layers = {feature_layers}")
|
||||
|
||||
# Validate projector dimension matches concatenated encoder output
|
||||
hidden_dim = self.hparams_audio["hidden_dim"]
|
||||
expected_dim = hidden_dim * (len(feature_layers) + 1)
|
||||
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
|
||||
|
||||
if projector_dim != expected_dim:
|
||||
raise ValueError(
|
||||
f"Projector encoder_hidden_size ({projector_dim}) does not match "
|
||||
f"expected concatenated dimension ({expected_dim}). "
|
||||
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
|
||||
)
|
||||
|
||||
|
||||
@ModelBase.register("Granite4VisionForConditionalGeneration")
|
||||
class Granite4VisionMmprojModel(MmprojModel):
|
||||
has_vision_encoder = True
|
||||
|
||||
+10
-3
@@ -64,11 +64,17 @@ class LFM2Model(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Lfm2Model")
|
||||
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
|
||||
class LFM2ColBertModel(LFM2Model):
|
||||
model_arch = gguf.MODEL_ARCH.LFM2
|
||||
dense_tensor_name = "dense_2"
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hf_arch == "Lfm2BidirectionalModel":
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self._try_set_pooling_type()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if not name.startswith(self.dense_tensor_name):
|
||||
name = "model." + name
|
||||
@@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# dense tensor is stored in a separate safetensors file
|
||||
# optional dense tensor is stored in a separate safetensors file
|
||||
from safetensors.torch import load_file
|
||||
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
|
||||
assert tensors_file.is_file()
|
||||
if not tensors_file.is_file():
|
||||
return
|
||||
tensor = load_file(tensors_file)["linear.weight"]
|
||||
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
|
||||
yield f"{self.dense_tensor_name}.weight", tensor.clone()
|
||||
|
||||
+50
-23
@@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, x);
|
||||
float mean = sum/ne00;
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
const float * xf = (const float *) x;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
float variance = 0;
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, xf);
|
||||
float mean = sum/ne00;
|
||||
|
||||
float * yf = (float *) y;
|
||||
float variance = 0;
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
mean = -mean;
|
||||
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
||||
vDSP_measqv(y, 1, &variance, ne00);
|
||||
mean = -mean;
|
||||
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
|
||||
vDSP_measqv(yf, 1, &variance, ne00);
|
||||
#else
|
||||
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
||||
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
|
||||
#endif //GGML_USE_ACCELERATE
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, yf, scale);
|
||||
} else {
|
||||
float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += *(const float *) (x + i00*nb00);
|
||||
}
|
||||
const float mean = sum/ne00;
|
||||
|
||||
float variance = 0.0f;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float v = *(const float *) (x + i00*nb00) - mean;
|
||||
*(float *) (y + i00*nb0) = v;
|
||||
variance += v * v;
|
||||
}
|
||||
variance /= ne00;
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
*(float *) (y + i00*nb0) *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
sum += (ggml_float)(xi * xi);
|
||||
}
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
|
||||
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
ggml_vec_scale_f32(ne00, (float *) y, scale);
|
||||
} else {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
*(float *) (y + i00*nb0) = xi * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5334,7 +5334,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return true;
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
break;
|
||||
|
||||
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
|
||||
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
|
||||
@@ -108,6 +108,9 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_CHECK_RESULTS)
|
||||
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
|
||||
# the result-checking path computes a CPU reference graph via
|
||||
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN_DEBUG)
|
||||
@@ -129,6 +132,8 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_RUN_TESTS)
|
||||
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
|
||||
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
# Set up toolchain for host compilation whether cross-compiling or not
|
||||
|
||||
@@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state {
|
||||
}
|
||||
};
|
||||
|
||||
struct vk_conv3d_pipeline_state {
|
||||
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
|
||||
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
|
||||
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
|
||||
|
||||
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
|
||||
uint32_t aligned;
|
||||
|
||||
bool operator<(const vk_conv3d_pipeline_state &b) const {
|
||||
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
|
||||
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
|
||||
}
|
||||
};
|
||||
|
||||
struct vk_solve_tri_pipeline_state {
|
||||
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
|
||||
: N(N), K(K) {}
|
||||
@@ -777,6 +791,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_back_f32;
|
||||
vk_pipeline pipeline_acc_f32;
|
||||
vk_pipeline pipeline_set_f32;
|
||||
|
||||
@@ -801,14 +816,10 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64;
|
||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
|
||||
vk_pipeline pipeline_scale_f32;
|
||||
vk_pipeline pipeline_sqr_f32;
|
||||
vk_pipeline pipeline_sqrt_f32;
|
||||
vk_pipeline pipeline_sin_f32;
|
||||
vk_pipeline pipeline_cos_f32;
|
||||
vk_pipeline pipeline_log[2];
|
||||
vk_pipeline pipeline_tri[2];
|
||||
vk_pipeline pipeline_diag[2];
|
||||
vk_pipeline pipeline_clamp_f32;
|
||||
vk_pipeline pipeline_clamp[2];
|
||||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_roll_f32;
|
||||
vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
|
||||
@@ -840,6 +851,10 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_gelu_quick[2];
|
||||
vk_pipeline pipeline_silu[2];
|
||||
vk_pipeline pipeline_relu[2];
|
||||
vk_pipeline pipeline_sqr[2];
|
||||
vk_pipeline pipeline_sqrt[2];
|
||||
vk_pipeline pipeline_sin[2];
|
||||
vk_pipeline pipeline_cos[2];
|
||||
vk_pipeline pipeline_xielu[2];
|
||||
vk_pipeline pipeline_neg[2];
|
||||
vk_pipeline pipeline_tanh[2];
|
||||
@@ -871,7 +886,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_geglu_erf[2];
|
||||
vk_pipeline pipeline_geglu_quick[2];
|
||||
|
||||
vk_pipeline pipeline_leaky_relu_f32;
|
||||
vk_pipeline pipeline_leaky_relu[2];
|
||||
vk_pipeline pipeline_silu_back_f32;
|
||||
vk_pipeline pipeline_diag_mask_inf_f32;
|
||||
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
||||
@@ -924,6 +939,8 @@ struct vk_device_struct {
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
|
||||
|
||||
@@ -1669,6 +1686,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
|
||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
||||
}
|
||||
|
||||
struct vk_op_conv3d_push_constants {
|
||||
uint32_t OC;
|
||||
uint32_t IC;
|
||||
uint32_t N;
|
||||
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t OW;
|
||||
uint32_t OH;
|
||||
uint32_t OD;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
|
||||
uint32_t nb1;
|
||||
uint32_t nb2;
|
||||
uint32_t nb3;
|
||||
|
||||
uint32_t OWmp; uint32_t OWL;
|
||||
uint32_t OWOHmp; uint32_t OWOHL;
|
||||
uint32_t OWOHODmp; uint32_t OWOHODL;
|
||||
};
|
||||
|
||||
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
|
||||
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
|
||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
||||
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
|
||||
}
|
||||
|
||||
struct vk_op_conv2d_dw_push_constants {
|
||||
uint32_t ne;
|
||||
uint32_t batches;
|
||||
@@ -4074,19 +4126,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
#endif
|
||||
|
||||
auto const &ggml_vk_mul_mm_spec = [](std::vector<uint32_t> spec, bool aligned) {
|
||||
spec.push_back(aligned ? 1u : 0u);
|
||||
return spec;
|
||||
};
|
||||
|
||||
const int mul_mat_id_param_count = 5;
|
||||
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector<uint32_t> spec, bool aligned, bool mul_mat_id) {
|
||||
if (mul_mat_id && spec.size() > 5) {
|
||||
spec.insert(spec.begin() + 5, aligned ? 1u : 0u);
|
||||
} else {
|
||||
spec.push_back(aligned ? 1u : 0u);
|
||||
}
|
||||
if (mul_mat_id && spec.size() == 6) {
|
||||
spec.push_back(32);
|
||||
}
|
||||
return spec;
|
||||
};
|
||||
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
@@ -4161,17 +4229,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
@@ -4284,32 +4352,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
// bf16 scalar path promotes to f32, no dot2 variant
|
||||
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
|
||||
@@ -4474,17 +4542,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l_int[TYPE]) \
|
||||
@@ -4879,6 +4947,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
@@ -4903,7 +4972,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
|
||||
@@ -5023,11 +5092,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
@@ -5037,8 +5101,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -5058,6 +5120,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
CREATE_UNARY(gelu_quick)
|
||||
CREATE_UNARY(silu)
|
||||
CREATE_UNARY(relu)
|
||||
CREATE_UNARY(sqr)
|
||||
CREATE_UNARY(sqrt)
|
||||
CREATE_UNARY(sin)
|
||||
CREATE_UNARY(cos)
|
||||
CREATE_UNARY(clamp)
|
||||
CREATE_UNARY(leaky_relu)
|
||||
CREATE_UNARY(xielu)
|
||||
CREATE_UNARY(neg)
|
||||
CREATE_UNARY(tanh)
|
||||
@@ -5097,7 +5165,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
CREATE_GLU(geglu_quick)
|
||||
#undef CREATE_GLU
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
|
||||
@@ -5314,7 +5381,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
// conv2d, conv_transpose_2d
|
||||
// conv2d, conv_transpose_2d, conv3d
|
||||
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
|
||||
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
|
||||
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
|
||||
@@ -5377,8 +5444,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
|
||||
};
|
||||
|
||||
// coopmat1 needs to store the output through shared memory, so check up front
|
||||
// whether it'll fit and disable it before applying coopmat1 parameters.
|
||||
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
|
||||
// layout. cm1 needs Csh for output, so check before applying cm1 params.
|
||||
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
|
||||
conv2d_use_cm1 = false;
|
||||
}
|
||||
@@ -5470,6 +5537,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
#undef CREATE_CONV
|
||||
#undef CREATE_CONVS
|
||||
|
||||
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
|
||||
#define CREATE_CONV3D(type_suffix, spv_suffix) \
|
||||
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
|
||||
const vk_conv3d_pipeline_state &state = c.first; \
|
||||
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
|
||||
spec_constants_cpy.push_back(state.s0); \
|
||||
spec_constants_cpy.push_back(state.s1); \
|
||||
spec_constants_cpy.push_back(state.s2); \
|
||||
spec_constants_cpy.push_back(state.p0); \
|
||||
spec_constants_cpy.push_back(state.p1); \
|
||||
spec_constants_cpy.push_back(state.p2); \
|
||||
spec_constants_cpy.push_back(state.d0); \
|
||||
spec_constants_cpy.push_back(state.d1); \
|
||||
spec_constants_cpy.push_back(state.d2); \
|
||||
spec_constants_cpy.push_back(state.KW); \
|
||||
spec_constants_cpy.push_back(state.KH); \
|
||||
spec_constants_cpy.push_back(state.KD); \
|
||||
spec_constants_cpy.push_back(state.aligned); \
|
||||
spec_constants_cpy.push_back(conv2d_csh_store); \
|
||||
spec_constants_cpy.push_back(conv2d_WM); \
|
||||
spec_constants_cpy.push_back(conv2d_WN); \
|
||||
ggml_vk_create_pipeline( \
|
||||
device, c.second, "conv3d" #type_suffix, \
|
||||
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
|
||||
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
|
||||
}
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
CREATE_CONV3D(_f32, _cm2)
|
||||
CREATE_CONV3D(_f16_f32, _cm2)
|
||||
} else
|
||||
#endif
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (conv2d_use_cm1) {
|
||||
CREATE_CONV3D(_f32, _cm1)
|
||||
CREATE_CONV3D(_f16_f32, _cm1)
|
||||
} else
|
||||
#endif
|
||||
if (conv2d_UNROLL) {
|
||||
CREATE_CONV3D(_f32, _unroll)
|
||||
CREATE_CONV3D(_f16_f32, _unroll)
|
||||
} else {
|
||||
CREATE_CONV3D(_f32, )
|
||||
CREATE_CONV3D(_f16_f32, )
|
||||
}
|
||||
#undef CREATE_CONV3D
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -10294,6 +10408,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_get_rows_f32[src0->type];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_get_rows_back_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ACC:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_acc_f32;
|
||||
@@ -10400,23 +10519,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SQR:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sqr_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SQRT:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sqrt_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SIN:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sin_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_COS:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_cos_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LOG:
|
||||
@@ -10438,8 +10561,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CLAMP:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_clamp_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_PAD:
|
||||
@@ -10807,8 +10931,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_leaky_relu_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_2D:
|
||||
@@ -10885,6 +11010,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_3D:
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
|
||||
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
|
||||
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
|
||||
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
|
||||
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
|
||||
|
||||
const uint32_t KW = (uint32_t)src0->ne[0];
|
||||
const uint32_t KH = (uint32_t)src0->ne[1];
|
||||
const uint32_t KD = (uint32_t)src0->ne[2];
|
||||
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
||||
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
|
||||
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
|
||||
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
|
||||
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
|
||||
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
|
||||
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
|
||||
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
|
||||
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
|
||||
|
||||
const uint32_t CRS = IC * KW * KH * KD;
|
||||
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
|
||||
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
|
||||
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
|
||||
const uint32_t aligned = ((OC % BS_K == 0) &&
|
||||
(CRS % BS_CRS == 0) &&
|
||||
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
|
||||
|
||||
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
|
||||
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
|
||||
auto it = pipelines->find(conv3d_pipeline_state);
|
||||
if (it != pipelines->end()) {
|
||||
pipeline = it->second;
|
||||
} else {
|
||||
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
|
||||
}
|
||||
}
|
||||
|
||||
return pipeline;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ADD1:
|
||||
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_add1_f16_f16;
|
||||
@@ -11135,6 +11315,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
GGML_ASSERT(0);
|
||||
break;
|
||||
@@ -11220,6 +11404,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
GGML_ABORT("invalid push constant type for CONV_2D");
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
|
||||
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
|
||||
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
|
||||
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
|
||||
|
||||
elements = { pc.OC, NPQ_blocks, 1 };
|
||||
if (elements[1] > 512) {
|
||||
elements[2] = CEIL_DIV(elements[1], 512);
|
||||
elements[1] = 512;
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("invalid push constant type for CONV_3D");
|
||||
}
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
@@ -11236,6 +11435,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_REPEAT:
|
||||
@@ -11380,6 +11580,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f, 0,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
@@ -12087,8 +12302,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
|
||||
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
@@ -13118,6 +13335,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
|
||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
|
||||
const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
|
||||
vk_op_conv3d_push_constants p{};
|
||||
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
|
||||
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
|
||||
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
|
||||
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
|
||||
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
|
||||
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
|
||||
|
||||
p.IW = static_cast<uint32_t>(ne10);
|
||||
p.IH = static_cast<uint32_t>(ne11);
|
||||
p.ID = static_cast<uint32_t>(ne12);
|
||||
p.OW = static_cast<uint32_t>(ne0);
|
||||
p.OH = static_cast<uint32_t>(ne1);
|
||||
p.OD = static_cast<uint32_t>(ne2);
|
||||
|
||||
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
|
||||
// total input element count must fit in a uint32.
|
||||
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
|
||||
|
||||
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
|
||||
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
|
||||
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
|
||||
|
||||
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
|
||||
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
|
||||
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
|
||||
|
||||
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
|
||||
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
|
||||
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
vk_op_conv2d_dw_push_constants p{};
|
||||
p.ne = ggml_nelements(dst);
|
||||
@@ -13144,7 +13406,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
const float * op_params = (const float *)dst->op_params;
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p));
|
||||
}
|
||||
|
||||
#ifdef GGML_VULKAN_RUN_TESTS
|
||||
@@ -14247,6 +14512,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
@@ -14515,6 +14784,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
|
||||
@@ -16964,6 +17237,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
switch (op->type) {
|
||||
@@ -17060,12 +17335,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_RMS_NORM:
|
||||
return true;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous_rows(op->src[0]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -17084,8 +17358,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->type == op->src[0]->type;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
@@ -17285,6 +17560,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
ggml_is_contiguous(op));
|
||||
}
|
||||
case GGML_OP_CONV_3D:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
ggml_is_contiguous(op);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -18128,6 +18410,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
const int32_t d0 = tensor->op_params[4];
|
||||
const int32_t d1 = tensor->op_params[5];
|
||||
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
|
||||
} else if (tensor->op == GGML_OP_CONV_3D) {
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t s1 = tensor->op_params[1];
|
||||
const int32_t s2 = tensor->op_params[2];
|
||||
const int32_t p0 = tensor->op_params[3];
|
||||
const int32_t p1 = tensor->op_params[4];
|
||||
const int32_t p2 = tensor->op_params[5];
|
||||
const int32_t d0 = tensor->op_params[6];
|
||||
const int32_t d1 = tensor->op_params[7];
|
||||
const int32_t d2 = tensor->op_params[8];
|
||||
const int32_t IC = tensor->op_params[9];
|
||||
const int32_t N = tensor->op_params[10];
|
||||
const int32_t OC = tensor->op_params[11];
|
||||
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
|
||||
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t s1 = tensor->op_params[1];
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
|
||||
}
|
||||
@@ -0,0 +1,431 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#ifdef COOPMAT2
|
||||
#extension GL_NV_cooperative_matrix2 : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
|
||||
layout(binding = 0) readonly buffer A {
|
||||
A_TYPE knl_data[];
|
||||
}; // src0 - kernel: [KW, KH, KD, IC*OC]
|
||||
|
||||
layout(binding = 1) readonly buffer B {
|
||||
B_TYPE src_data[];
|
||||
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
|
||||
|
||||
layout(binding = 2) writeonly buffer D {
|
||||
D_TYPE dst_data[];
|
||||
}; // dst - result: [OW, OH, OD, OC*N]
|
||||
|
||||
layout(push_constant) uniform parameter {
|
||||
// I/O channels, batch size
|
||||
uint32_t OC;
|
||||
uint32_t IC;
|
||||
uint32_t N;
|
||||
|
||||
// Tensor spatial sizes: input, output
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t OW;
|
||||
uint32_t OH;
|
||||
uint32_t OD;
|
||||
|
||||
// Strides in elements
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
|
||||
uint32_t nb1;
|
||||
uint32_t nb2;
|
||||
uint32_t nb3;
|
||||
|
||||
// fastdiv helper values
|
||||
uint32_t OWmp; uint32_t OWL;
|
||||
uint32_t OWOHmp; uint32_t OWOHL;
|
||||
uint32_t OWOHODmp; uint32_t OWOHODL;
|
||||
}
|
||||
|
||||
p;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
// Blocktile sizes
|
||||
layout(constant_id = 1) const uint BS_K = 128;
|
||||
layout(constant_id = 2) const uint BS_CRS = 16;
|
||||
layout(constant_id = 3) const uint BS_NPQ = 128;
|
||||
// Thread-tile sizes
|
||||
layout(constant_id = 4) const uint TS_K = 8;
|
||||
layout(constant_id = 5) const uint SHMEM_PAD = 4;
|
||||
// Stride, padding, dilation
|
||||
layout(constant_id = 6) const uint s0 = 1;
|
||||
layout(constant_id = 7) const uint s1 = 1;
|
||||
layout(constant_id = 8) const uint s2 = 1;
|
||||
layout(constant_id = 9) const uint p0 = 0;
|
||||
layout(constant_id = 10) const uint p1 = 0;
|
||||
layout(constant_id = 11) const uint p2 = 0;
|
||||
layout(constant_id = 12) const uint d0 = 1;
|
||||
layout(constant_id = 13) const uint d1 = 1;
|
||||
layout(constant_id = 14) const uint d2 = 1;
|
||||
// Kernel spatial sizes
|
||||
layout(constant_id = 15) const uint KW = 1;
|
||||
layout(constant_id = 16) const uint KH = 1;
|
||||
layout(constant_id = 17) const uint KD = 1;
|
||||
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
|
||||
layout(constant_id = 18) const uint aligned = 0;
|
||||
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
|
||||
layout(constant_id = 19) const uint csh_store = 0;
|
||||
|
||||
#ifdef COOPMAT
|
||||
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
|
||||
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
|
||||
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
|
||||
layout(constant_id = 20) const uint WM = 32;
|
||||
layout(constant_id = 21) const uint WN = 32;
|
||||
const uint TM = 16;
|
||||
const uint TN = 16;
|
||||
const uint TK = 16;
|
||||
const uint cms_per_row = WM / TM;
|
||||
const uint cms_per_col = WN / TN;
|
||||
const uint warps_M = BS_K / WM;
|
||||
const uint warps_N = BS_NPQ / WN;
|
||||
#endif
|
||||
|
||||
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
|
||||
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
|
||||
|
||||
uint32_t tid = gl_LocalInvocationID.x;
|
||||
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
|
||||
|
||||
uint splitWork(uint work_size, uint block_size) {
|
||||
return (block_size + work_size - 1) / block_size;
|
||||
}
|
||||
|
||||
uint32_t K = p.OC;
|
||||
uint32_t CRS = p.IC * KD * KH * KW;
|
||||
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
|
||||
|
||||
// Number of blocktiles per input
|
||||
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
|
||||
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
#define SHMEM_TYPE float16_t
|
||||
#else
|
||||
#define SHMEM_TYPE float
|
||||
#endif
|
||||
|
||||
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
|
||||
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
|
||||
|
||||
const uint32_t Ash_len = BS_K * Ash_stride;
|
||||
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
|
||||
|
||||
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
|
||||
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
|
||||
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
|
||||
const uint32_t Csh_stride = BS_NPQ;
|
||||
#ifdef COOPMAT
|
||||
const uint32_t Csh_len = BS_K * Csh_stride;
|
||||
#else
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
|
||||
#endif
|
||||
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
|
||||
#endif
|
||||
|
||||
// Threadtile sizes
|
||||
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
|
||||
|
||||
// Number of threadtiles per blocktile
|
||||
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
|
||||
|
||||
/*
|
||||
Compute
|
||||
KxCRS @ CRSxNPQ = K x NPQ
|
||||
K=OC
|
||||
C=IC
|
||||
D,R,S=KD,KH,KW
|
||||
Z,P,Q=OD,OH,OW
|
||||
*/
|
||||
|
||||
uint32_t B_idx_K = gl_WorkGroupID.x;
|
||||
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
|
||||
|
||||
uint32_t T_y = tid / NT_NPQ;
|
||||
uint32_t T_x = tid % NT_NPQ;
|
||||
|
||||
uint32_t Ar = tid / BS_CRS;
|
||||
uint32_t Ac = tid % BS_CRS;
|
||||
const uint32_t ArpWg = WG_SIZE / BS_CRS;
|
||||
|
||||
uint32_t Br = tid / BS_NPQ;
|
||||
uint32_t Bc = tid % BS_NPQ;
|
||||
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
|
||||
|
||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||
uint fastdiv(uint n, uint mp, uint L) {
|
||||
uint msbs, lsbs;
|
||||
// msbs = mulhi(n, mp)
|
||||
umulExtended(n, mp, msbs, lsbs);
|
||||
return (msbs + n) >> L;
|
||||
}
|
||||
|
||||
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
|
||||
const uint32_t KHKW = KH * KW;
|
||||
const uint32_t KDKHKW = KD * KHKW;
|
||||
ic = crs_idx / KDKHKW;
|
||||
uint32_t rem = crs_idx - ic * KDKHKW;
|
||||
kd = rem / KHKW;
|
||||
rem = rem - kd * KHKW;
|
||||
kh = rem / KW;
|
||||
kw = rem - kh * KW;
|
||||
}
|
||||
|
||||
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
|
||||
const uint32_t OWOH = p.OW * p.OH;
|
||||
n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
|
||||
uint32_t rem = npq_idx - n * p.OD * OWOH;
|
||||
od = fastdiv(rem, p.OWOHmp, p.OWOHL);
|
||||
rem = rem - od * OWOH;
|
||||
oh = fastdiv(rem, p.OWmp, p.OWL);
|
||||
ow = rem - oh * p.OW;
|
||||
}
|
||||
|
||||
#ifdef COOPMAT2
|
||||
#define ACC_TYPE float16_t
|
||||
|
||||
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
|
||||
{
|
||||
uint32_t K_idx = B_idx_K * BS_K + r;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
if (B_idx_NPQ * BS_NPQ >= NPQ) {
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef COOPMAT2
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
|
||||
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
|
||||
#elif defined(COOPMAT)
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
|
||||
}
|
||||
const uint warp_r = gl_SubgroupID / warps_N;
|
||||
const uint warp_c = gl_SubgroupID % warps_N;
|
||||
#else
|
||||
float regC[TS_K][TS_NPQ];
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regC[T_ly][T_lx] = 0.0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
/* Advance block in CRS dim */
|
||||
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
|
||||
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
|
||||
uint32_t IC_idx_a;
|
||||
uint32_t KD_idx_a;
|
||||
uint32_t KH_idx_a;
|
||||
uint32_t KW_idx_a;
|
||||
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
|
||||
|
||||
/* Load kernel to A_block: (BS_K x BS_CRS)*/
|
||||
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
|
||||
uint32_t B_ly = r_offset + Ar;
|
||||
uint32_t B_lx = Ac;
|
||||
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
|
||||
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
|
||||
if (aligned == 0) {
|
||||
knl_idx = min(knl_idx, K * CRS - 1);
|
||||
}
|
||||
float val = knl_data[knl_idx];
|
||||
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
|
||||
val = 0.0;
|
||||
}
|
||||
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
|
||||
}
|
||||
/* Load input to B_block: (BS_CRS x BS_NPQ) */
|
||||
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
|
||||
uint32_t B_ly = r_offset + Br; /* Row index of B block */
|
||||
uint32_t B_lx = Bc;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
|
||||
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
|
||||
uint32_t IC_idx_b;
|
||||
uint32_t KD_idx_b;
|
||||
uint32_t KH_idx_b;
|
||||
uint32_t KW_idx_b;
|
||||
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
|
||||
|
||||
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
|
||||
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
|
||||
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
|
||||
|
||||
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
|
||||
// skip clamp when address can't go OOB
|
||||
if (aligned == 0 || !dhw_in_bounds) {
|
||||
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
|
||||
}
|
||||
float val = src_data[src_idx];
|
||||
bool oob = false;
|
||||
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
|
||||
oob = true;
|
||||
}
|
||||
// also catches lower-bound underflow (idx wraps to 0x80000000+)
|
||||
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
|
||||
oob = true;
|
||||
}
|
||||
if (oob) {
|
||||
val = 0.0;
|
||||
}
|
||||
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
|
||||
}
|
||||
barrier();
|
||||
#ifdef COOPMAT2
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
|
||||
|
||||
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
matC = coopMatMulAdd(matA, matB, matC);
|
||||
#elif defined(COOPMAT)
|
||||
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
|
||||
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
|
||||
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
|
||||
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (T_y * TS_K < K) {
|
||||
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
|
||||
float regA[TS_K];
|
||||
float regB[TS_NPQ];
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
|
||||
}
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
|
||||
}
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
barrier();
|
||||
}
|
||||
/* Save C* */
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
|
||||
#ifdef COOPMAT
|
||||
const bool use_staged_store = true;
|
||||
#else
|
||||
const bool use_staged_store = (csh_store != 0);
|
||||
#endif
|
||||
if (use_staged_store) {
|
||||
#ifdef COOPMAT
|
||||
// cm1: each subgroup stores its fragment grid into its Csh slot
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
}
|
||||
#else
|
||||
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
#endif
|
||||
barrier();
|
||||
|
||||
// cooperative shmem->global: WG threads spread across BS_NPQ (the
|
||||
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
|
||||
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
|
||||
const uint32_t store_iters = BS_K / store_rows_per_iter;
|
||||
const uint32_t k_thread_offset = tid / BS_NPQ;
|
||||
const uint32_t npq_thread = tid % BS_NPQ;
|
||||
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
|
||||
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
|
||||
uint32_t K_idx = B_idx_K * BS_K + k_local;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef COOPMAT2
|
||||
else {
|
||||
coopMatPerElementNV(matC, matC, perElemOpStore);
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
if (T_y * TS_K < K) {
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
|
||||
}
|
||||
@@ -463,6 +463,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(Sf[r][c]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
|
||||
@@ -352,6 +352,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// Compute max across the row
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_binary_head.glsl"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint col = gl_GlobalInvocationID.x;
|
||||
|
||||
if (col >= p.ne20) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
|
||||
float sum = 0.0f;
|
||||
for (uint i = 0; i < p.ne10; ++i) {
|
||||
if (data_b[get_boffset() + i*p.nb10] == int(row)) {
|
||||
sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00];
|
||||
}
|
||||
}
|
||||
|
||||
data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum;
|
||||
}
|
||||
}
|
||||
@@ -14,16 +14,13 @@ void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint i3 = row / (p.ne11 * p.ne12);
|
||||
const uint i3_offset = i3 * p.ne12 * p.ne11;
|
||||
const uint i2 = (row - i3_offset) / p.ne11;
|
||||
const uint i2_offset = i2 * p.ne11;
|
||||
const uint i1 = row - i3_offset - i2_offset;
|
||||
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
|
||||
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]);
|
||||
sum[tid] += xi * xi;
|
||||
}
|
||||
|
||||
@@ -39,6 +36,6 @@ void main() {
|
||||
const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
|
||||
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
|
||||
data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float val = float(data_a[i]);
|
||||
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
|
||||
}
|
||||
@@ -38,17 +38,7 @@
|
||||
#define LOAD_VEC_B 1
|
||||
#endif
|
||||
|
||||
// Load 2 values at once without affecting index calculations through LOAD_VEC
|
||||
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
|
||||
#define LOAD_VEC_BATCH_A 2
|
||||
#else
|
||||
#define LOAD_VEC_BATCH_A 1
|
||||
#endif
|
||||
#if !defined(ALIGNED)
|
||||
#define LOAD_VEC_BATCH_B 2
|
||||
#else
|
||||
#define LOAD_VEC_BATCH_B 1
|
||||
#endif
|
||||
layout (constant_id = 11) const uint ALIGNED = 0;
|
||||
|
||||
#if !defined(TO_FLOAT_TYPE)
|
||||
#define TO_FLOAT_TYPE FLOAT_TYPE
|
||||
@@ -57,6 +47,13 @@
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];};
|
||||
#elif defined(DATA_A_F16)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];};
|
||||
#elif defined(DATA_A_BF16)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
@@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -194,13 +192,23 @@ void main() {
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||
const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1;
|
||||
const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2;
|
||||
#else
|
||||
const uint LOAD_VEC_A_EFF = LOAD_VEC_A;
|
||||
const uint LOAD_VEC_BATCH_A = 1;
|
||||
#endif
|
||||
const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1;
|
||||
const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2;
|
||||
|
||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
|
||||
|
||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK;
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
@@ -239,15 +247,15 @@ void main() {
|
||||
|
||||
uint pos_a =
|
||||
#ifdef MUL_MAT_ID
|
||||
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
|
||||
expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) +
|
||||
#else
|
||||
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
|
||||
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) +
|
||||
#endif
|
||||
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
|
||||
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF;
|
||||
#ifdef MUL_MAT_ID
|
||||
uint pos_b = 0;
|
||||
#else
|
||||
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
|
||||
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
@@ -287,8 +295,8 @@ void main() {
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
pos_a += BK / LOAD_VEC_A_EFF;
|
||||
pos_b += BK / LOAD_VEC_B_EFF;
|
||||
|
||||
#ifdef COOPMAT
|
||||
[[unroll]] for (uint i = 0; i < BK; i += TK) {
|
||||
|
||||
@@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
|
||||
layout (constant_id = 4) const bool enable_smaller_matrices = false;
|
||||
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
|
||||
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
|
||||
layout (constant_id = 5) const uint ALIGNED = 0;
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
@@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
};
|
||||
|
||||
uint _ne1;
|
||||
layout (constant_id = 5) const uint subgroup_size = 32;
|
||||
layout (constant_id = 6) const uint subgroup_size = 32;
|
||||
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
|
||||
|
||||
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
@@ -297,12 +298,12 @@ void main() {
|
||||
|
||||
// Hint to the compiler that values are aligned (want 16B alignment).
|
||||
// Quants are always block-aligned, no alignment needed.
|
||||
#if ALIGNED
|
||||
if (ALIGNED != 0) {
|
||||
#if QUANT_K == 1
|
||||
stride_a &= ~7;
|
||||
#endif
|
||||
stride_b &= ~7;
|
||||
stride_a &= ~7;
|
||||
#endif
|
||||
stride_b &= ~7;
|
||||
}
|
||||
|
||||
// Create layouts for both clamped and unclamped accesses
|
||||
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
|
||||
|
||||
@@ -1,50 +1,57 @@
|
||||
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
#if LOAD_VEC_A == 8
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
buf_a[buf_idx + 3] = aa[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
buf_a[buf_idx + 3] = aa[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
|
||||
data_a[idx + 1]);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx],
|
||||
data_a_scalar[idx + 1]);
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_BF16)
|
||||
#if LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
|
||||
TO_FLOAT_TYPE(data_a[idx + 1]));
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_a_scalar[idx + 1]));
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
#if !defined(MUL_MAT_ID)
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_B == 4
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_BATCH_B == 2
|
||||
const uint idx = pos_b + col * p.stride_b + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
|
||||
} else if (idx_n < p.N && block + row * 2 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_B == 4
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
if (ALIGNED != 0) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_BATCH_B == 2
|
||||
const uint row_i = ic * BN + col;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
|
||||
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared vec2 sum[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
|
||||
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
|
||||
|
||||
sum[tid] = vec2(0.0f, 0.0f);
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
const float xi = float(data_a[row*p.KX + col]);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const float xi = float(data_a[a_base + i0*p.nb00]);
|
||||
sum[tid].x += xi;
|
||||
sum[tid].y += xi * xi;
|
||||
}
|
||||
@@ -34,11 +34,11 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
const float mean = sum[0].x / p.KX;
|
||||
const float var = sum[0].y / p.KX - mean * mean;
|
||||
const float mean = sum[0].x / p.ne00;
|
||||
const float var = sum[0].y / p.ne00 - mean * mean;
|
||||
const float inv_std = inversesqrt(var + p.param1);
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
|
||||
}
|
||||
@@ -17,6 +17,30 @@ float op_neg(float x) {
|
||||
return -x;
|
||||
}
|
||||
|
||||
float op_sqr(float x) {
|
||||
return x * x;
|
||||
}
|
||||
|
||||
float op_sqrt(float x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
|
||||
float op_sin(float x) {
|
||||
return sin(x);
|
||||
}
|
||||
|
||||
float op_cos(float x) {
|
||||
return cos(x);
|
||||
}
|
||||
|
||||
float op_clamp(float x) {
|
||||
return clamp(x, p.param1, p.param2);
|
||||
}
|
||||
|
||||
float op_leaky_relu(float x) {
|
||||
return max(x, 0.0f) + min(x, 0.0f) * p.param1;
|
||||
}
|
||||
|
||||
float op_step(float x) {
|
||||
return x >= 0.0f ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <future>
|
||||
#include <queue>
|
||||
#include <condition_variable>
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
@@ -34,6 +35,9 @@
|
||||
|
||||
std::mutex lock;
|
||||
std::vector<std::pair<std::string, std::string>> shader_fnames;
|
||||
// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the
|
||||
// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393)
|
||||
static std::atomic<bool> compile_failed{false};
|
||||
std::locale c_locale("C");
|
||||
|
||||
std::string GLSLC = "glslc";
|
||||
@@ -78,7 +82,7 @@ enum MatMulIdType {
|
||||
|
||||
namespace {
|
||||
|
||||
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
int execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
#ifdef _WIN32
|
||||
HANDLE stdout_read, stdout_write;
|
||||
HANDLE stderr_read, stderr_write;
|
||||
@@ -127,8 +131,11 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
|
||||
CloseHandle(stdout_read);
|
||||
CloseHandle(stderr_read);
|
||||
WaitForSingleObject(pi.hProcess, INFINITE);
|
||||
DWORD exit_code = 1;
|
||||
GetExitCodeProcess(pi.hProcess, &exit_code);
|
||||
CloseHandle(pi.hProcess);
|
||||
CloseHandle(pi.hThread);
|
||||
return (int)exit_code;
|
||||
#else
|
||||
int stdout_pipe[2];
|
||||
int stderr_pipe[2];
|
||||
@@ -175,7 +182,9 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
|
||||
|
||||
close(stdout_pipe[0]);
|
||||
close(stderr_pipe[0]);
|
||||
waitpid(pid, nullptr, 0);
|
||||
int status = 0;
|
||||
waitpid(pid, &status, 0);
|
||||
return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -372,13 +381,14 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
|
||||
execute_command(cmd, stdout_str, stderr_str);
|
||||
if (!stderr_str.empty()) {
|
||||
std::cerr << "cannot compile " << name << "\n\n";
|
||||
int exit_code = execute_command(cmd, stdout_str, stderr_str);
|
||||
if (exit_code != 0 || !stderr_str.empty()) {
|
||||
std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n";
|
||||
for (const auto& part : cmd) {
|
||||
std::cerr << part << " ";
|
||||
}
|
||||
std::cerr << "\n\n" << stderr_str << std::endl;
|
||||
compile_failed = true;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -398,6 +408,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
shader_fnames.push_back(std::make_pair(name, out_path));
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
|
||||
compile_failed = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -539,11 +550,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
};
|
||||
|
||||
// Shaders with f16 B_TYPE
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
// bf16
|
||||
{
|
||||
@@ -565,8 +574,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
#endif
|
||||
{
|
||||
if (!dot2) {
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -583,8 +591,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
}
|
||||
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
||||
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
|
||||
// For aligned matmul loads
|
||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||
|
||||
@@ -597,13 +603,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
|
||||
// don't generate f32 variants for coopmat2
|
||||
if (!coopmat2) {
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
if (tname != "f16" && tname != "f32") {
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
@@ -850,21 +854,12 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
|
||||
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
|
||||
|
||||
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}});
|
||||
@@ -891,6 +886,18 @@ void process_shaders() {
|
||||
string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}});
|
||||
string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}});
|
||||
string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}});
|
||||
string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}});
|
||||
string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}});
|
||||
string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}});
|
||||
string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqrt"}});
|
||||
string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}});
|
||||
string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}});
|
||||
string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}});
|
||||
string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}});
|
||||
string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}});
|
||||
string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}});
|
||||
string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}});
|
||||
string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}});
|
||||
string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}});
|
||||
string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}});
|
||||
string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}});
|
||||
@@ -948,7 +955,6 @@ void process_shaders() {
|
||||
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
@@ -1060,6 +1066,31 @@ void process_shaders() {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto unroll : {false, true}) {
|
||||
for (auto a_f16 : {false, true}) {
|
||||
std::map<std::string, std::string> defines = {
|
||||
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
|
||||
{"UNROLL", unroll ? "[[unroll]]" : ""},
|
||||
};
|
||||
std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32";
|
||||
string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines);
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (unroll) {
|
||||
auto cm2_defines = defines;
|
||||
cm2_defines["COOPMAT2"] = "1";
|
||||
string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (unroll) {
|
||||
auto cm1_defines = defines;
|
||||
cm1_defines["COOPMAT"] = "1";
|
||||
string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
@@ -1251,6 +1282,11 @@ int main(int argc, char** argv) {
|
||||
|
||||
process_shaders();
|
||||
|
||||
if (compile_failed) {
|
||||
std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
write_output_files();
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
|
||||
@@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
int vectorized;
|
||||
uint32_t num_cols;
|
||||
bool use_mmvq;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
||||
use_mmvq == other.use_mmvq;
|
||||
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
||||
return seed;
|
||||
}
|
||||
@@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
uint32_t n_experts;
|
||||
uint32_t num_cols;
|
||||
int vectorized;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
|
||||
vectorized == other.vectorized;
|
||||
num_cols == other.num_cols && vectorized == other.vectorized;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.n_experts);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
return seed;
|
||||
}
|
||||
@@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
bool supports_dot_product,
|
||||
const std::string & vendor) {
|
||||
if (src1->ne[1] == 1) {
|
||||
if (src1->ne[1] <= 4) {
|
||||
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
||||
if (supports_dp4a && supports_dot_product) {
|
||||
switch (src1->type) {
|
||||
@@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib {
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.num_cols = context.dst->ne[1];
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
@@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
|
||||
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
||||
@@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=1"));
|
||||
|
||||
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
|
||||
|
||||
|
||||
@@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
const size_t q8_src1_align_offset = ROUNDUP_POW2(
|
||||
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t q8_src1_binding_size =
|
||||
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const size_t q8_src1_binding_size = ROUNDUP_POW2(
|
||||
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
std::vector<uint32_t> q8_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
@@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
uint32_t q8_wg_x = 1;
|
||||
uint32_t q8_wg_y = 1;
|
||||
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
|
||||
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
|
||||
|
||||
@@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
// Determine if this is a mat-vec operation
|
||||
bool is_vec = (dst->ne[1] == 1);
|
||||
bool use_mat_vec = (dst->ne[1] <= 4);
|
||||
|
||||
// use MMVQ path for mat-vec
|
||||
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
|
||||
@@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
webgpu_pipeline pipeline;
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
if (use_mmvq) {
|
||||
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
|
||||
}
|
||||
@@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
@@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
|
||||
ctx->webgpu_global_ctx->vendor);
|
||||
if (use_mmvq) {
|
||||
const size_t q8_src1_size =
|
||||
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
|
||||
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
res = ROUNDUP_POW2(res + q8_src1_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
@@ -4268,7 +4270,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0);
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
||||
|
||||
@@ -103,7 +103,7 @@ fn main(
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
let subgroup_total = subgroupAdd(acc[0][row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
@@ -126,7 +126,7 @@ fn main(
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
partial_sums[partial_index(row, thread_id)] = acc[0][row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -91,61 +91,67 @@ fn main(
|
||||
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
|
||||
|
||||
#ifdef MMVQ
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
|
||||
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
|
||||
#else
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
|
||||
#endif
|
||||
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[col][row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
workgroupBarrier();
|
||||
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + row] = row_total;
|
||||
}
|
||||
}
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + col * params.m + row] = row_total;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
}
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[col][row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_0
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
#define BLOCK_SIZE_BYTES 20
|
||||
@@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
f32(load_f16_at_src0(block_byte_base + 2u))
|
||||
);
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_1
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
#define BLOCK_SIZE_BYTES 34
|
||||
@@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds);
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q8_0
|
||||
|
||||
#ifdef LEGACY_QUANTS
|
||||
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
var row_sum = 0;
|
||||
let a_repacked = repack_a(a_byte_base, b_inner_id);
|
||||
|
||||
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
|
||||
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
|
||||
}
|
||||
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
#if defined(LEGACY_QUANTS)
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let b_inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
let b_block_idx = src1q_idx_base + block;
|
||||
|
||||
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
|
||||
let b_ds = repack_b_dm(b_block_idx);
|
||||
|
||||
let inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
|
||||
let a_repacked = repack_a(block_byte_base, inner_id);
|
||||
let da = get_dm(block_byte_base);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
|
||||
let b_repacked = repack_b_qs(src1q_idx, inner_id);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
#if defined(MUL_ACC_Q4_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_0
|
||||
|
||||
#if defined(MUL_ACC_Q4_1)
|
||||
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_1
|
||||
|
||||
#if defined(MUL_ACC_Q8_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds);
|
||||
#endif // MUL_ACC_Q8_0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
#endif // LEGACY_QUANTS
|
||||
|
||||
#ifdef MUL_ACC_Q2_K
|
||||
#define BLOCK_SIZE_BYTES 84
|
||||
@@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
|
||||
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q2_K
|
||||
|
||||
#ifdef MUL_ACC_Q4_K
|
||||
#define BLOCK_SIZE_BYTES 144
|
||||
@@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
|
||||
return vec2<f32>(scale, min_val);
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_K
|
||||
|
||||
#ifdef K_QUANTS
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
|
||||
let a_repacked = repack_a(block_byte_base, tid);
|
||||
let dm = get_dm(block_byte_base);
|
||||
let scale_min = get_scale_min(block_byte_base, tid);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
#if defined(MUL_ACC_Q2_K)
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
#endif // MUL_ACC_Q2_K
|
||||
|
||||
#if defined(MUL_ACC_Q4_K)
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
#endif // MUL_ACC_Q4_K
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
#endif // K_QUANTS
|
||||
|
||||
@@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product;
|
||||
|
||||
struct Params {
|
||||
offset_src1: u32,
|
||||
stride_11: u32,
|
||||
stride_12: u32,
|
||||
stride_13: u32,
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
@@ -57,25 +59,28 @@ fn main(
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||
) {
|
||||
let thread_id = local_id.x;
|
||||
let num_vec4 = params.ne0 / 4u;
|
||||
let ne0_vec4 = params.ne0 / 4u;
|
||||
|
||||
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne2 * params.ne3;
|
||||
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
|
||||
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
if (wg_linear >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
|
||||
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let vec_idx = wg_linear / wg_per_vec;
|
||||
let src13_idx = vec_idx / (params.ne2 * params.ne1);
|
||||
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
|
||||
let src12_idx = vec_ne12_num / params.ne1;
|
||||
let src11_idx = vec_ne12_num % params.ne1;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
|
||||
let src1_idx_vec4_base = src1_idx_base / 4u;
|
||||
|
||||
let blocks_per_row = params.ne0 / 32u;
|
||||
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
|
||||
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
|
||||
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
|
||||
let qs_idx = thread_id % 8u;
|
||||
|
||||
@@ -85,7 +90,7 @@ fn main(
|
||||
var thread_amax = 0.0;
|
||||
|
||||
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
|
||||
let is_valid = src11_vec4_idx < num_vec4;
|
||||
let is_valid = src11_vec4_idx < ne0_vec4;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
|
||||
|
||||
@@ -359,6 +359,7 @@ class Keys:
|
||||
CHUNK_SIZE = "clip.audio.chunk_size"
|
||||
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
|
||||
MAX_POS_EMB = "clip.audio.max_pos_emb"
|
||||
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.audio.attention.head_count"
|
||||
|
||||
@@ -1310,6 +1310,9 @@ class GGUFWriter:
|
||||
def add_audio_max_pos_emb(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
|
||||
|
||||
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
|
||||
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
|
||||
|
||||
def add_audio_projector_window_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
|
||||
|
||||
|
||||
+14
-4
@@ -190,7 +190,15 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
|
||||
auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
|
||||
auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
|
||||
|
||||
bx = ggml_concat(ctx0, conv, bx, 0);
|
||||
// causal prepends the state, non-causal pads symmetrically for a centered window
|
||||
if (hparams.causal_attn) {
|
||||
bx = ggml_concat(ctx0, conv, bx, 0);
|
||||
} else {
|
||||
const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2;
|
||||
auto * left = ggml_cont(ctx0,
|
||||
ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0]));
|
||||
bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0);
|
||||
}
|
||||
GGML_ASSERT(bx->ne[0] > conv->ne[0]);
|
||||
|
||||
// last d_conv columns is a new conv state
|
||||
@@ -266,10 +274,12 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
cur = build_lora_mm(model.output, cur, model.output_s);
|
||||
cb(cur, "result_output", -1);
|
||||
if (!cparams.embeddings) {
|
||||
cur = build_lora_mm(model.output, cur, model.output_s);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
res->t_logits = cur;
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
@@ -3298,21 +3298,29 @@ struct test_norm : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const bool v; // whether a is a non-contiguous view
|
||||
const float eps;
|
||||
const bool noncontig_rows;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, v, eps);
|
||||
return VARS_TO_STR5(type, ne, v, eps, noncontig_rows);
|
||||
}
|
||||
|
||||
test_norm(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 5, 4, 3},
|
||||
bool v = false,
|
||||
float eps = 1e-6f)
|
||||
: type(type), ne(ne), v(v), eps(eps) {}
|
||||
float eps = 1e-6f,
|
||||
bool noncontig_rows = false)
|
||||
: type(type), ne(ne), v(v), eps(eps), noncontig_rows(noncontig_rows) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
const std::array<int64_t, 4> ne_a = noncontig_rows ?
|
||||
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
if (noncontig_rows) {
|
||||
a = ggml_permute(ctx, a, 1, 0, 2, 3);
|
||||
ggml_set_name(a, "permuted a");
|
||||
}
|
||||
if (v) {
|
||||
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
|
||||
ggml_set_name(a, "view of a");
|
||||
@@ -6193,21 +6201,29 @@ struct test_l2_norm : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float eps;
|
||||
bool v;
|
||||
bool noncontig_rows;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, eps, v);
|
||||
return VARS_TO_STR5(type, ne, eps, v, noncontig_rows);
|
||||
}
|
||||
|
||||
test_l2_norm(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 64, 320, 1},
|
||||
float eps = 1e-12f,
|
||||
bool v = false)
|
||||
: type(type), ne(ne), eps(eps), v(v) {}
|
||||
bool v = false,
|
||||
bool noncontig_rows = false)
|
||||
: type(type), ne(ne), eps(eps), v(v), noncontig_rows(noncontig_rows) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
const std::array<int64_t, 4> ne_a = noncontig_rows ?
|
||||
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
if (noncontig_rows) {
|
||||
a = ggml_permute(ctx, a, 1, 0, 2, 3);
|
||||
ggml_set_name(a, "permuted a");
|
||||
}
|
||||
if (v) {
|
||||
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
|
||||
ggml_set_name(a, "view of a");
|
||||
@@ -8282,9 +8298,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
|
||||
}
|
||||
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, false, eps, true));
|
||||
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false, true));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8433,6 +8451,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
|
||||
@@ -8449,6 +8468,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
@@ -9270,6 +9290,34 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
}
|
||||
}
|
||||
|
||||
struct conv3d_perf_case {
|
||||
int N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1, s2, p0, p1, p2, d0, d1, d2;
|
||||
};
|
||||
|
||||
const std::vector<conv3d_perf_case> conv3d_cases = {
|
||||
{1, 320, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 1280, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 320, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 1280, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 320, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
#if 0
|
||||
// too slow on some devices
|
||||
{1, 1280, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 320, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 640, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
#endif
|
||||
};
|
||||
|
||||
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
for (const conv3d_perf_case & c : conv3d_cases) {
|
||||
test_cases.emplace_back(new test_conv_3d(
|
||||
c.N, c.IC, c.ID, c.IH, c.IW,
|
||||
c.OC, c.KD, c.KH, c.KW,
|
||||
c.s0, c.s1, c.s2, c.p0, c.p1, c.p2, c.d0, c.d1, c.d2,
|
||||
kernel_type));
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
|
||||
|
||||
|
||||
+99
-24
@@ -1562,37 +1562,112 @@ static void test_msgs_oaicompat_json_conversion() {
|
||||
}
|
||||
}
|
||||
|
||||
static void test_split_by_role() {
|
||||
static void test_msg_token_delimiters_split() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
|
||||
// Delimiters that share a leading token, distinguished by the second token,
|
||||
// to exercise the per-position token matching.
|
||||
const common_chat_msg_delimiters delims = {
|
||||
{ { COMMON_CHAT_ROLE_USER, "", { 10, 11 } },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } }
|
||||
};
|
||||
|
||||
// Empty inputs
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({}).spans.size());
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size());
|
||||
assert_equals<size_t>(0, delims.split({}).spans.size());
|
||||
|
||||
// Multi-role conversation, no leading/trailing content
|
||||
// No delimiters match -> no spans
|
||||
assert_equals<size_t>(0, delims.split({ 100, 101, 102 }).spans.size());
|
||||
|
||||
// Multi-role conversation: <user>Hi<assistant>Hello<user>Bye
|
||||
{
|
||||
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
|
||||
const auto splits = common_chat_split_by_role(prompt, {
|
||||
{ "user", "<|user|>" },
|
||||
{ "assistant", "<|assistant|>" },
|
||||
});
|
||||
assert_equals<size_t>(3, splits.size());
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
100, 101, // Hi
|
||||
10, 12, // <assistant>
|
||||
200, 201, 202, // Hello
|
||||
10, 11, // <user>
|
||||
300, 301, // Bye
|
||||
};
|
||||
|
||||
assert_equals<std::string>("user", splits[0].role);
|
||||
assert_equals<size_t>(0, splits[0].pos);
|
||||
assert_equals<size_t>(10, splits[0].len);
|
||||
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
|
||||
const auto result = delims.split(tokens);
|
||||
const auto & spans = result.spans;
|
||||
assert_equals<size_t>(3, spans.size());
|
||||
|
||||
assert_equals<std::string>("assistant", splits[1].role);
|
||||
assert_equals<size_t>(10, splits[1].pos);
|
||||
assert_equals<size_t>(18, splits[1].len);
|
||||
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(4, spans[0].len);
|
||||
|
||||
assert_equals<std::string>("user", splits[2].role);
|
||||
assert_equals<size_t>(28, splits[2].pos);
|
||||
assert_equals<size_t>(11, splits[2].len);
|
||||
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(4, spans[1].pos);
|
||||
assert_equals<size_t>(5, spans[1].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role);
|
||||
assert_equals<size_t>(9, spans[2].pos);
|
||||
assert_equals<size_t>(4, spans[2].len);
|
||||
|
||||
// is_user_start() is true at the token position where a user span begins
|
||||
assert_equals(true, result.is_user_start(0));
|
||||
assert_equals(false, result.is_user_start(4)); // assistant span
|
||||
assert_equals(true, result.is_user_start(9));
|
||||
}
|
||||
|
||||
// Content before the first delimiter is not captured as a span
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
500, 501, // leading content (dropped)
|
||||
10, 11, // <user>
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const auto spans = delims.split(tokens).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(2, spans[0].pos);
|
||||
assert_equals<size_t>(3, spans[0].len);
|
||||
}
|
||||
|
||||
// Skipped regions (media chunks) are jumped over but still count as span content
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
LLAMA_TOKEN_NULL, // media chunk (3 tokens)
|
||||
LLAMA_TOKEN_NULL,
|
||||
LLAMA_TOKEN_NULL,
|
||||
100, // Hi
|
||||
10, 12, // <assistant>
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 3 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(2, spans.size());
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(6, spans[0].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(6, spans[1].pos);
|
||||
assert_equals<size_t>(2, spans[1].len);
|
||||
}
|
||||
|
||||
// A delimiter sequence inside a skipped region is not matched
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
10, 12, // skipped region that happens to contain delimiter tokens
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 2 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(5, spans[0].len);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5857,7 +5932,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
test_msg_diffs_compute();
|
||||
test_msgs_oaicompat_json_conversion();
|
||||
test_split_by_role();
|
||||
test_msg_token_delimiters_split();
|
||||
test_tools_oaicompat_json_conversion();
|
||||
test_convert_responses_to_chatcmpl();
|
||||
test_developer_role_to_system_workaround();
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
#define KEY_N_HEAD "clip.%s.attention.head_count"
|
||||
#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv"
|
||||
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
|
||||
#define KEY_FEATURE_LAYERS "clip.%s.feature_layer"
|
||||
|
||||
// vision-specific
|
||||
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
|
||||
@@ -54,7 +55,6 @@
|
||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||
#define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side"
|
||||
#define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side"
|
||||
|
||||
@@ -91,7 +91,7 @@ struct clip_hparams {
|
||||
|
||||
float eps = 1e-6;
|
||||
float rope_theta = 0.0;
|
||||
std::vector<int32_t> vision_feature_layer;
|
||||
std::vector<int32_t> feature_layers;
|
||||
int32_t attn_window_size = 0;
|
||||
int32_t n_wa_pattern = 0;
|
||||
std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
|
||||
@@ -165,8 +165,8 @@ struct clip_hparams {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_vision_feature_layer(int32_t layer) const {
|
||||
return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end();
|
||||
bool is_feature_layer(int32_t layer) const {
|
||||
return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
+9
-10
@@ -1264,12 +1264,10 @@ struct clip_model_loader {
|
||||
}
|
||||
}
|
||||
|
||||
// Load the vision feature layer indices if they are explicitly provided;
|
||||
// if multiple vision feature layers are present, the values will be concatenated
|
||||
// to form the final visual features.
|
||||
// Load the vision/audio feature layer indices if they are explicitly provided
|
||||
// NOTE: gguf conversions should standardize the values of the vision feature layer to
|
||||
// be non-negative, since we use -1 to mark values as unset here.
|
||||
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false);
|
||||
get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false);
|
||||
|
||||
// model-specific params
|
||||
switch (model.proj_type) {
|
||||
@@ -1651,6 +1649,7 @@ struct clip_model_loader {
|
||||
get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size);
|
||||
get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate);
|
||||
get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count);
|
||||
// NOTE: feature layers loaded above in common path
|
||||
} break;
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
{
|
||||
@@ -1663,11 +1662,11 @@ struct clip_model_loader {
|
||||
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW;
|
||||
hparams.image_resize_pad = PAD_CEIL;
|
||||
|
||||
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer);
|
||||
// NOTE: feature_layers loaded in common path as optional
|
||||
get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets);
|
||||
if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) {
|
||||
throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d",
|
||||
hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size()));
|
||||
if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) {
|
||||
throw std::runtime_error(string_format("%s: feature_layers.size() %d != proj_spatial_offsets.size() %d",
|
||||
hparams.feature_layers.size(), hparams.proj_spatial_offsets.size()));
|
||||
}
|
||||
|
||||
get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side);
|
||||
@@ -2740,7 +2739,7 @@ struct clip_model_loader {
|
||||
model.image_newline = get_tensor(TN_IMAGE_NEWLINE);
|
||||
|
||||
// Load separate layerwise and spatial projector tensors
|
||||
const auto projector_count = hparams.vision_feature_layer.size();
|
||||
const auto projector_count = hparams.feature_layers.size();
|
||||
model.qf_proj_blocks.resize(projector_count);
|
||||
for (size_t bid = 0; bid < projector_count; ++bid) {
|
||||
auto & b = model.qf_proj_blocks[bid];
|
||||
@@ -4388,7 +4387,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32
|
||||
|
||||
// Stage 1b only uses block 0's permutations; future stages
|
||||
// will upload all blocks.
|
||||
for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) {
|
||||
for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) {
|
||||
const std::string prefix = "g4v_blk" + std::to_string(bid) + "_";
|
||||
upload(prefix + "win_idx", make_win_idx(image_side, window_side));
|
||||
upload(prefix + "qwin_idx", make_win_idx(new_side, query_side));
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
const int n_frames = img.nx();
|
||||
const int context_size = hparams.audio_chunk_size;
|
||||
@@ -11,6 +13,10 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
const int padded_len = num_blocks * context_size;
|
||||
const int remainder = n_frames % context_size;
|
||||
|
||||
// Calculate projector input dimension based on feature layers
|
||||
const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1);
|
||||
const bool use_feature_concat = !hparams.feature_layers.empty();
|
||||
|
||||
ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size);
|
||||
ggml_set_name(attn_dists, "attn_dists");
|
||||
ggml_set_input(attn_dists);
|
||||
@@ -31,6 +37,15 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
cur = ggml_add(ctx0, cur, model.inp_proj_b);
|
||||
cb(cur, "inp_linear", -1);
|
||||
|
||||
// Capture layer 0 if requested (after input_linear)
|
||||
ggml_tensor * concat_result = nullptr;
|
||||
if (use_feature_concat) {
|
||||
if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) {
|
||||
concat_result = cur;
|
||||
cb(concat_result, "feature_layer_0", -1);
|
||||
}
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
const auto & layer = model.layers[il];
|
||||
auto * residual = cur;
|
||||
@@ -168,6 +183,18 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
NORM_TYPE_NORMAL, eps, il);
|
||||
cb(cur, "layer_out", il);
|
||||
|
||||
// Capture intermediate layer (il + 1) if requested
|
||||
if (use_feature_concat) {
|
||||
if (hparams.is_feature_layer(il + 1)) {
|
||||
if (concat_result == nullptr) {
|
||||
concat_result = cur;
|
||||
} else {
|
||||
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
|
||||
}
|
||||
cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il);
|
||||
}
|
||||
}
|
||||
|
||||
// CTC branch
|
||||
if (il + 1 == ctc_layer) {
|
||||
auto * mid = build_mm(model.ctc_out_w, cur);
|
||||
@@ -180,6 +207,13 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
}
|
||||
}
|
||||
|
||||
// Append final output to concatenated features if using feature concatenation
|
||||
if (use_feature_concat && concat_result != nullptr) {
|
||||
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
|
||||
cb(concat_result, "concat_final", -1);
|
||||
cur = concat_result;
|
||||
}
|
||||
|
||||
cb(cur, "encoder_out", -1);
|
||||
|
||||
// QFormer projector
|
||||
@@ -197,7 +231,7 @@ ggml_cgraph * clip_graph_granite_speech::build() {
|
||||
cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0);
|
||||
}
|
||||
|
||||
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj);
|
||||
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj);
|
||||
|
||||
ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query,
|
||||
model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b,
|
||||
|
||||
@@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() {
|
||||
}
|
||||
|
||||
// --- Stage 1b/1c: WindowQFormer blocks ---
|
||||
const int projector_count = hparams.vision_feature_layer.size();
|
||||
const int projector_count = hparams.feature_layers.size();
|
||||
const float qformer_eps = 1e-12f;
|
||||
|
||||
ggml_tensor * mmproj = nullptr;
|
||||
for (int bid = 0; bid < projector_count; ++bid) {
|
||||
const auto & blk = model.qf_proj_blocks[bid];
|
||||
|
||||
int vlayer = hparams.vision_feature_layer[bid];
|
||||
int vlayer = hparams.feature_layers[bid];
|
||||
GGML_ASSERT(vlayer >= 0 && vlayer < n_layer);
|
||||
ggml_tensor * h = layer_outs[vlayer];
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
|
||||
// If we set explicit vision feature layers, only go up to the deepest one
|
||||
// NOTE: only used by granite-vision models for now
|
||||
for (const auto & feature_layer : hparams.vision_feature_layer) {
|
||||
for (const auto & feature_layer : hparams.feature_layers) {
|
||||
if (feature_layer > deepest_feature_layer) {
|
||||
deepest_feature_layer = feature_layer;
|
||||
}
|
||||
@@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
|
||||
// If this is an embedding feature layer, save the output.
|
||||
// NOTE: 0 index here refers to the input to the encoder.
|
||||
if (hparams.is_vision_feature_layer(il)) {
|
||||
if (hparams.is_feature_layer(il)) {
|
||||
embedding_stack.push_back(cur);
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
// process vision feature layers (used by granite)
|
||||
{
|
||||
// final layer is a vision feature layer
|
||||
if (hparams.is_vision_feature_layer(max_feature_layer)) {
|
||||
if (hparams.is_feature_layer(max_feature_layer)) {
|
||||
embedding_stack.push_back(inpL);
|
||||
}
|
||||
|
||||
|
||||
@@ -518,6 +518,14 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const {
|
||||
return max_idx; // all tokens are equal
|
||||
}
|
||||
|
||||
common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const {
|
||||
std::map<size_t, size_t> skips;
|
||||
for (const auto & it : map_idx_to_media) {
|
||||
skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get());
|
||||
}
|
||||
return delims.split(tokens, skips);
|
||||
}
|
||||
|
||||
bool server_tokens::validate(const struct llama_context * ctx) const {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
@@ -1104,15 +1112,7 @@ json oaicompat_chat_params_parse(
|
||||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
llama_params["message_spans"] = json::array();
|
||||
|
||||
for (const auto & span : chat_params.message_spans) {
|
||||
llama_params["message_spans"].push_back({
|
||||
{ "role", span.role },
|
||||
{ "pos", span.pos },
|
||||
{ "len", span.len },
|
||||
});
|
||||
}
|
||||
llama_params["message_delimiters"] = chat_params.message_delimiters.to_json();
|
||||
|
||||
// Reasoning budget: pass parameters through to sampling layer
|
||||
{
|
||||
|
||||
@@ -218,6 +218,9 @@ public:
|
||||
|
||||
size_t get_common_prefix(const server_tokens & b) const;
|
||||
|
||||
// split the tokens into message spans, skipping over media chunks
|
||||
common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const;
|
||||
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context * ctx) const;
|
||||
|
||||
|
||||
@@ -89,7 +89,9 @@ struct server_batch {
|
||||
}
|
||||
|
||||
~server_batch() {
|
||||
llama_batch_free(batch);
|
||||
if (batch.token != nullptr) {
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
}
|
||||
|
||||
void init(int32_t n_tokens_alloc) {
|
||||
@@ -1215,6 +1217,10 @@ private:
|
||||
cparams.ctx_other = ctx_tgt;
|
||||
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
if (ctx_dft == nullptr) {
|
||||
SRV_ERR("%s", "failed to create draft context\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
@@ -3436,8 +3442,8 @@ private:
|
||||
has_mtmd = true;
|
||||
}
|
||||
|
||||
const int32_t n_before_user = slot.task->params.n_before_user;
|
||||
const bool n_before_user_known = n_before_user > 0;
|
||||
const auto & spans = slot.task->params.message_spans;
|
||||
const auto last_user_pos = spans.last_user_message_pos();
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) {
|
||||
@@ -3466,10 +3472,8 @@ private:
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
||||
// stop the prompt batch exactly before the latest user input, so a checkpoint
|
||||
// can be created after the previous messages
|
||||
if (n_before_user_known &&
|
||||
slot.prompt.n_tokens() == n_before_user) {
|
||||
// stop the prompt batch exactly before a user message
|
||||
if (spans.is_user_start(slot.prompt.n_tokens())) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -3498,8 +3502,13 @@ private:
|
||||
// the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_cur = batch.size() - n_tokens_prev;
|
||||
|
||||
const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
|
||||
|
||||
const bool is_user_start = spans.is_user_start(n_tokens_start);
|
||||
const bool is_last_user_message = n_tokens_start == last_user_pos;
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
@@ -3514,8 +3523,9 @@ private:
|
||||
|
||||
slot.init_sampler();
|
||||
} else {
|
||||
// skip ordinary mid-prompt checkpoints
|
||||
if (!n_before_user_known && !near_prompt_end) {
|
||||
// skip ordinary mid-prompt checkpoints, unless the batch starts a user
|
||||
// message or we are near the end of the prompt
|
||||
if (!is_user_start && !near_prompt_end) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
@@ -3523,29 +3533,6 @@ private:
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
|
||||
|
||||
// checkpoints are created before the current batch is decoded, so
|
||||
// their token position is the batch start rather than the prompt end
|
||||
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
{
|
||||
const bool is_on_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start == n_before_user;
|
||||
|
||||
const bool is_after_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start > n_before_user;
|
||||
|
||||
const bool is_allowed =
|
||||
!n_before_user_known ||
|
||||
is_on_user ||
|
||||
(is_after_user && near_prompt_end);
|
||||
|
||||
if (do_checkpoint && !is_allowed) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
|
||||
// nothing to checkpoint yet
|
||||
// TODO: is this check needed?
|
||||
if (do_checkpoint && pos_min < 0) {
|
||||
@@ -3555,8 +3542,8 @@ private:
|
||||
// do not checkpoint after mtmd chunks
|
||||
do_checkpoint = do_checkpoint && !has_mtmd;
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
// no need to create checkpoints that are too close together, unless it's the last user message
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
@@ -4055,54 +4042,6 @@ void server_context::set_state_callback(server_state_callback_t callback) {
|
||||
});
|
||||
}
|
||||
|
||||
// compute the number of tokens before the last user message in the prompt
|
||||
static int32_t prompt_get_n_before_user(
|
||||
const json & message_spans,
|
||||
const std::string & prompt,
|
||||
const std::vector<raw_buffer> & files,
|
||||
const llama_vocab * vocab,
|
||||
mtmd_context * mctx) {
|
||||
int32_t result = -1;
|
||||
int32_t byte_pos = -1;
|
||||
|
||||
for (const auto & span : message_spans) {
|
||||
const std::string role = json_value(span, "role", std::string());
|
||||
|
||||
if (role == "user") {
|
||||
byte_pos = json_value(span, "pos", -1);
|
||||
}
|
||||
}
|
||||
|
||||
if (byte_pos >= 0) {
|
||||
GGML_ASSERT((size_t) byte_pos <= prompt.size());
|
||||
|
||||
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
|
||||
|
||||
const std::string marker = get_media_marker();
|
||||
size_t n_prefix_media = 0;
|
||||
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
|
||||
n_prefix_media++;
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_prefix_media <= files.size());
|
||||
|
||||
if (mctx != nullptr && n_prefix_media > 0) {
|
||||
// TODO: this makes a copy - avoid it
|
||||
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
|
||||
|
||||
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
|
||||
} else {
|
||||
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
|
||||
}
|
||||
|
||||
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
|
||||
byte_pos, n_prefix_media, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// server_routes
|
||||
//
|
||||
@@ -4150,6 +4089,10 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
|
||||
// tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
|
||||
|
||||
// message delimiters for checkpointing
|
||||
auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array()));
|
||||
delimiters.tokenize(ctx_server.vocab);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
@@ -4163,16 +4106,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
meta->logit_bias_eog,
|
||||
data);
|
||||
|
||||
const auto message_spans = json_value(data, "message_spans", json::array());
|
||||
if (prompt.is_string() && message_spans.is_array()) {
|
||||
task.params.n_before_user =
|
||||
prompt_get_n_before_user(
|
||||
message_spans,
|
||||
prompt.get<std::string>(),
|
||||
files,
|
||||
ctx_server.vocab,
|
||||
ctx_server.mctx);
|
||||
}
|
||||
task.params.message_spans = task.tokens.find_message_spans(delimiters);
|
||||
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
|
||||
@@ -224,7 +224,7 @@ void server_model_meta::update_caps() {
|
||||
});
|
||||
params.offline = true;
|
||||
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
|
||||
if (params.mmproj.path.empty()) {
|
||||
multimodal = { false, false };
|
||||
} else {
|
||||
@@ -1393,7 +1393,9 @@ struct server_download_state : public common_download_callback {
|
||||
|
||||
bool run(common_params & params) {
|
||||
try {
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
|
||||
common_params_handle_models_params p;
|
||||
p.callback = this;
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = params.model.get_name();
|
||||
|
||||
@@ -62,9 +62,6 @@ struct task_params {
|
||||
|
||||
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
|
||||
|
||||
// number of prompt tokens before the latest user message
|
||||
int32_t n_before_user = -1;
|
||||
|
||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
@@ -92,6 +89,9 @@ struct task_params {
|
||||
// per-request parameters for chat parsing
|
||||
common_chat_parser_params chat_parser_params;
|
||||
|
||||
// message spans for checkpointing
|
||||
common_chat_msg_spans message_spans;
|
||||
|
||||
// Embeddings
|
||||
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
|
||||
|
||||
+12
-1
@@ -89,6 +89,17 @@ int llama_server(int argc, char ** argv) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// note: router mode also accepts -hf remote-preset, so we need to check that first
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
try {
|
||||
common_params_handle_models_params handle_params;
|
||||
handle_params.preset_only = true;
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params);
|
||||
} catch (const std::exception & e) {
|
||||
// ignored for now
|
||||
}
|
||||
}
|
||||
|
||||
// router server never loads a model and must not touch the GPU
|
||||
const bool is_router_server = params.model.path.empty()
|
||||
&& params.model.hf_repo.empty();
|
||||
@@ -263,7 +274,7 @@ int llama_server(int argc, char ** argv) {
|
||||
return child.run_download(params);
|
||||
} else if (!is_router_server) {
|
||||
// single-model mode (NOT spawned by router)
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -256,6 +256,25 @@ def test_router_reload_models():
|
||||
os.remove(preset_path)
|
||||
|
||||
|
||||
def test_router_remote_preset():
|
||||
global server
|
||||
server.model_hf_repo = "ggml-org/test-preset-ci"
|
||||
server.model_hf_file = None
|
||||
server.offline = False
|
||||
server.start()
|
||||
|
||||
# Should see preset models in GET /models
|
||||
res = server.make_request("GET", "/models")
|
||||
assert res.status_code == 200
|
||||
ids = {item["id"] for item in res.body.get("data", [])}
|
||||
assert "tinygemma3-preset" in ids
|
||||
assert "stories260K-test" in ids
|
||||
|
||||
# Should be able to load a preset model
|
||||
model_id = "tinygemma3-preset"
|
||||
_load_model_and_wait(model_id)
|
||||
|
||||
|
||||
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
|
||||
MODEL_DOWNLOAD_TIMEOUT = 30
|
||||
|
||||
|
||||
+1
-2
@@ -28,10 +28,9 @@ vite.config.ts.timestamp-*
|
||||
# PWA Artifacts
|
||||
apple-splash-*.png
|
||||
apple-touch-icon-*.png
|
||||
favicon.ico
|
||||
favicon-dark.ico
|
||||
maskable-icon-*.png
|
||||
pwa-*.png
|
||||
static/favicon*
|
||||
|
||||
# Storybook
|
||||
*storybook.log
|
||||
|
||||
Generated
+7
-7
@@ -35,7 +35,7 @@
|
||||
"bits-ui": "2.18.1",
|
||||
"clsx": "2.1.1",
|
||||
"dexie": "4.4.3",
|
||||
"dompurify": "3.4.5",
|
||||
"dompurify": "3.4.11",
|
||||
"eslint": "9.39.4",
|
||||
"eslint-config-prettier": "10.1.8",
|
||||
"eslint-plugin-storybook": "10.4.2",
|
||||
@@ -8653,9 +8653,9 @@
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/dompurify": {
|
||||
"version": "3.4.5",
|
||||
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.5.tgz",
|
||||
"integrity": "sha512-OrwIBKsdNSVEeubdJ1HBv/wNENRM9ytAVCv7YXt//A3vPdVMNuACRqK9mXCGCBW2ln7BT/A4X0jXHo2Gu89miA==",
|
||||
"version": "3.4.11",
|
||||
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.11.tgz",
|
||||
"integrity": "sha512-zhlUV12GsaRzMsf9q5M254YhA4+VuF0fG+QFqu6aYpoGlKtz+w8//jBcGVYBgQkR5GHjUomejY84AV+/uPbWdw==",
|
||||
"dev": true,
|
||||
"license": "(MPL-2.0 OR Apache-2.0)",
|
||||
"optionalDependencies": {
|
||||
@@ -10226,9 +10226,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.23",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.23.tgz",
|
||||
"integrity": "sha512-eIaZ9qDgu7XV0pxOCrg7/WhnQ6Ivm22UcxhXx/A3dcbqbbYgBEkc6e/J/s7j2tS96zoB0S9VBdLwQNCWwUo4LA==",
|
||||
"version": "4.12.26",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.26.tgz",
|
||||
"integrity": "sha512-uyZtpnYxM9CmQ7QsQknM4zN8EftNqhON1qYeIKM0Se67CCEe2c44xyGURwB0axX2fBDu1dqHrHAc1hmNT8ITkw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
"bits-ui": "2.18.1",
|
||||
"clsx": "2.1.1",
|
||||
"dexie": "4.4.3",
|
||||
"dompurify": "3.4.5",
|
||||
"dompurify": "3.4.11",
|
||||
"eslint": "9.39.4",
|
||||
"eslint-config-prettier": "10.1.8",
|
||||
"eslint-plugin-storybook": "10.4.2",
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
import { defineConfig } from '@vite-pwa/assets-generator/config';
|
||||
import { FAVICON_COLORS, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
|
||||
import { writeThemeFavicons } from './scripts/favicon-colorize';
|
||||
|
||||
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
});
|
||||
|
||||
export default defineConfig({
|
||||
headLinkOptions: {
|
||||
@@ -7,7 +13,8 @@ export default defineConfig({
|
||||
preset: {
|
||||
transparent: {
|
||||
sizes: [],
|
||||
favicons: [[48, 'favicon-dark.ico']]
|
||||
favicons: [[48, 'favicon-dark.ico']],
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
},
|
||||
maskable: {
|
||||
sizes: []
|
||||
|
||||
@@ -5,15 +5,32 @@ import {
|
||||
} from '@vite-pwa/assets-generator/config';
|
||||
import { readFileSync } from 'node:fs';
|
||||
import { resolve } from 'node:path';
|
||||
import { THEME_COLORS, PWA_GENERATOR_DEVICES, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
|
||||
import {
|
||||
THEME_COLORS,
|
||||
PWA_GENERATOR_DEVICES,
|
||||
PWA_ASSET_GENERATOR,
|
||||
FAVICON_COLORS
|
||||
} from './src/lib/constants/pwa';
|
||||
import { SplashOrientation } from './src/lib/enums/splash.enums';
|
||||
import { writeThemeFavicons } from './scripts/favicon-colorize';
|
||||
|
||||
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
});
|
||||
|
||||
export default defineConfig({
|
||||
headLinkOptions: {
|
||||
preset: PWA_ASSET_GENERATOR.LINK_PRESET
|
||||
},
|
||||
preset: combinePresetAndAppleSplashScreens(
|
||||
minimal2023Preset,
|
||||
{
|
||||
...minimal2023Preset,
|
||||
// tiny margin so favicon.ico / pwa-*.png breathe inside the canvas
|
||||
transparent: {
|
||||
...minimal2023Preset.transparent,
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
}
|
||||
},
|
||||
{
|
||||
padding: PWA_ASSET_GENERATOR.SPLASH_PADDING,
|
||||
resizeOptions: {
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
import { mkdirSync, readFileSync, writeFileSync } from 'node:fs';
|
||||
import { dirname, resolve } from 'node:path';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
|
||||
const HERE = dirname(fileURLToPath(import.meta.url));
|
||||
const PROJECT_ROOT = resolve(HERE, '..');
|
||||
|
||||
const DEFAULT_LOGO = resolve(PROJECT_ROOT, 'src/lib/assets/logo.svg');
|
||||
const DEFAULT_OUT_DIR = resolve(PROJECT_ROOT, 'static');
|
||||
const DEFAULT_OUT_LIGHT = resolve(DEFAULT_OUT_DIR, 'favicon.svg');
|
||||
const DEFAULT_OUT_DARK = resolve(DEFAULT_OUT_DIR, 'favicon-dark.svg');
|
||||
|
||||
const CURRENT_COLOR = 'currentColor';
|
||||
|
||||
export interface ColorizedFavicon {
|
||||
light: string;
|
||||
dark: string;
|
||||
}
|
||||
|
||||
export interface WriteThemeFaviconsOptions {
|
||||
sourcePath?: string;
|
||||
lightOutPath?: string;
|
||||
darkOutPath?: string;
|
||||
/**
|
||||
* Fraction of the icon (0..1) to leave as an even margin on each side.
|
||||
* Applied by wrapping the inner content in a `<g transform="...">` so the
|
||||
* source `src/lib/assets/logo.svg` is not modified. Pass 0 to disable.
|
||||
*/
|
||||
padding?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace every `currentColor` occurrence in the SVG with the given color.
|
||||
* Pure: no filesystem access, so it is straightforward to unit-test.
|
||||
*/
|
||||
export function colorizeFaviconSvg(
|
||||
svg: string,
|
||||
lightColor: string,
|
||||
darkColor: string
|
||||
): ColorizedFavicon {
|
||||
return {
|
||||
light: svg.replaceAll(CURRENT_COLOR, lightColor),
|
||||
dark: svg.replaceAll(CURRENT_COLOR, darkColor)
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Shrink the inner SVG content uniformly and re-center it so `padding` (a
|
||||
* 0..1 fraction) is reserved as equal margin on each side. Returns the input
|
||||
* unchanged for non-positive padding, missing/invalid `viewBox`, or unexpected
|
||||
* markup so the caller always gets a renderable SVG.
|
||||
*/
|
||||
export function padFaviconSvg(svg: string, padding: number): string {
|
||||
if (!(padding > 0) || padding >= 1) return svg;
|
||||
|
||||
const viewBoxMatch = svg.match(/viewBox\s*=\s*["']([^"']+)["']/i);
|
||||
if (!viewBoxMatch) return svg;
|
||||
|
||||
const parts = viewBoxMatch[1]
|
||||
.trim()
|
||||
.split(/[\s,]+/)
|
||||
.map(Number);
|
||||
if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) return svg;
|
||||
|
||||
const [, , width, height] = parts;
|
||||
if (width <= 0 || height <= 0) return svg;
|
||||
|
||||
const scale = 1 - padding;
|
||||
const translateX = (padding * width) / 2;
|
||||
const translateY = (padding * height) / 2;
|
||||
|
||||
const openTagStart = svg.search(/<svg\b/i);
|
||||
if (openTagStart === -1) return svg;
|
||||
const openTagEnd = svg.indexOf('>', openTagStart);
|
||||
if (openTagEnd === -1) return svg;
|
||||
const closeStart = svg.lastIndexOf('</svg');
|
||||
if (closeStart === -1 || closeStart <= openTagEnd) return svg;
|
||||
|
||||
const openTag = svg.slice(0, openTagEnd + 1);
|
||||
const inner = svg.slice(openTagEnd + 1, closeStart);
|
||||
const closeTag = svg.slice(closeStart);
|
||||
|
||||
const group = `<g transform="translate(${translateX} ${translateY}) scale(${scale})">`;
|
||||
return `${openTag}${group}${inner}</g>${closeTag}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read `src/lib/assets/logo.svg`, colorize it for both themes, and write
|
||||
* the results to the static directory so the PWA asset generator can consume
|
||||
* them. Paths can be overridden for tests.
|
||||
*/
|
||||
export function writeThemeFavicons(
|
||||
lightColor: string,
|
||||
darkColor: string,
|
||||
{
|
||||
sourcePath = DEFAULT_LOGO,
|
||||
lightOutPath = DEFAULT_OUT_LIGHT,
|
||||
darkOutPath = DEFAULT_OUT_DARK,
|
||||
padding = 0
|
||||
}: WriteThemeFaviconsOptions = {}
|
||||
): void {
|
||||
const source = readFileSync(sourcePath, 'utf-8');
|
||||
const { light, dark } = colorizeFaviconSvg(source, lightColor, darkColor);
|
||||
mkdirSync(dirname(lightOutPath), { recursive: true });
|
||||
writeFileSync(lightOutPath, padFaviconSvg(light, padding));
|
||||
writeFileSync(darkOutPath, padFaviconSvg(dark, padding));
|
||||
}
|
||||
@@ -48,6 +48,7 @@
|
||||
|
||||
--chat-form-area-height: 8rem;
|
||||
--chat-form-area-offset: 2rem;
|
||||
--chat-form-padding-top: 6rem;
|
||||
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
|
||||
}
|
||||
|
||||
@@ -55,6 +56,7 @@
|
||||
:root {
|
||||
--chat-form-area-height: 24rem;
|
||||
--chat-form-area-offset: 12rem;
|
||||
--chat-form-padding-top: 6rem;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +143,6 @@
|
||||
@apply bg-background text-foreground;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-gutter: stable;
|
||||
overflow: hidden; /* Added due to Mermaid rendering somehow causing the double scrollbar */
|
||||
}
|
||||
|
||||
/* Global scrollbar styling - visible only on hover */
|
||||
@@ -193,3 +194,7 @@
|
||||
scrollbar-width: none;
|
||||
}
|
||||
}
|
||||
|
||||
.mermaidTooltip {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
@@ -10,9 +10,9 @@ import { isElementInViewport } from '$lib/utils/viewport';
|
||||
*/
|
||||
export function fadeInView(
|
||||
node: HTMLElement,
|
||||
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
|
||||
options: { duration?: number; y?: number; delay?: number; skipIfVisible?: boolean } = {}
|
||||
) {
|
||||
const { duration = 300, y = 0, skipIfVisible = false } = options;
|
||||
const { duration = 300, y = 0, delay = 0, skipIfVisible = false } = options;
|
||||
|
||||
if (skipIfVisible && isElementInViewport(node)) {
|
||||
return;
|
||||
@@ -27,10 +27,12 @@ export function fadeInView(
|
||||
(entries) => {
|
||||
for (const entry of entries) {
|
||||
if (entry.isIntersecting) {
|
||||
requestAnimationFrame(() => {
|
||||
node.style.opacity = '1';
|
||||
node.style.transform = 'translateY(0)';
|
||||
});
|
||||
setTimeout(() => {
|
||||
requestAnimationFrame(() => {
|
||||
node.style.opacity = '1';
|
||||
node.style.transform = 'translateY(0)';
|
||||
});
|
||||
}, delay);
|
||||
observer.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M244.95 8C215.233 8 187.774 23.8591 172.923 49.5999L95.6009 183.625C60.2162 244.959 104.481 321.6 175.29 321.6H208L316.977 132.708C348.959 77.2719 308.95 8 244.95 8ZM208 321.6H351.947C415.982 321.6 456.013 390.91 424.013 446.377C409.155 472.132 381.681 488 351.947 488H271.29C200.481 488 156.216 411.359 191.601 350.026L208 321.6Z" fill="currentColor"/>
|
||||
<path d="M208 321.6H16L106.462 164.8L208 321.6Z" fill="currentColor"/>
|
||||
<path d="M388.923 8L208 321.6L253.6 8H388.923Z" fill="currentColor"/>
|
||||
<path d="M304 488H112L202.462 331.2L304 488Z" fill="currentColor"/>
|
||||
<path d="M496 321.6H208L419.399 454.4L496 321.6Z" fill="currentColor"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 771 B |
@@ -8,12 +8,13 @@
|
||||
ariaLabel?: string;
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
href?: string;
|
||||
icon: Component;
|
||||
iconSize?: string;
|
||||
onclick: (e?: MouseEvent) => void;
|
||||
onclick?: (e?: MouseEvent) => void;
|
||||
size?: ButtonSize;
|
||||
stopPropagationOnClick?: boolean;
|
||||
tooltip: string;
|
||||
tooltip?: string;
|
||||
variant?: ButtonVariant;
|
||||
tooltipSide?: TooltipSide;
|
||||
}
|
||||
@@ -22,6 +23,7 @@
|
||||
icon,
|
||||
tooltip,
|
||||
variant = 'ghost',
|
||||
href = '',
|
||||
size = 'sm',
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
@@ -31,34 +33,49 @@
|
||||
onclick,
|
||||
ariaLabel
|
||||
}: Props = $props();
|
||||
|
||||
let innerWidth = $state(0);
|
||||
const showTooltip = $derived(!!tooltip && innerWidth > 768);
|
||||
</script>
|
||||
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<!-- prevent another nested button element -->
|
||||
{#snippet child({ props })}
|
||||
<Button
|
||||
{...props}
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
onclick={(e: MouseEvent) => {
|
||||
if (stopPropagationOnClick) e.stopPropagation();
|
||||
{#snippet button(props = {})}
|
||||
<Button
|
||||
{...props}
|
||||
{href}
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
onclick={(e: MouseEvent) => {
|
||||
if (stopPropagationOnClick) e.stopPropagation();
|
||||
|
||||
onclick?.(e);
|
||||
}}
|
||||
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{#if icon}
|
||||
{@const IconComponent = icon}
|
||||
<IconComponent class={iconSize} />
|
||||
{/if}
|
||||
</Button>
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
onclick?.(e);
|
||||
}}
|
||||
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{#if icon}
|
||||
{@const IconComponent = icon}
|
||||
|
||||
<Tooltip.Content side={tooltipSide}>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
<IconComponent class={iconSize} />
|
||||
{/if}
|
||||
</Button>
|
||||
{/snippet}
|
||||
|
||||
{#if showTooltip}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<!-- prevent another nested button element -->
|
||||
{#snippet child({ props })}
|
||||
{@render button(props)}
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side={tooltipSide}>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
{@render button({ href })}
|
||||
{/if}
|
||||
|
||||
<svelte:window bind:innerWidth />
|
||||
|
||||
@@ -494,7 +494,7 @@
|
||||
/>
|
||||
|
||||
<div
|
||||
class="{INPUT_CLASSES} overflow-hidden rounded-3xl backdrop-blur-md {disabled
|
||||
class="{INPUT_CLASSES} overflow-hidden rounded-4xl md:rounded-3xl backdrop-blur-md {disabled
|
||||
? 'cursor-not-allowed opacity-60'
|
||||
: ''}"
|
||||
data-slot="input-area"
|
||||
@@ -510,7 +510,7 @@
|
||||
/>
|
||||
|
||||
<div
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
|
||||
class="flex-column relative min-h-12 items-center rounded-4xl md:rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:py-3!"
|
||||
onpaste={handlePaste}
|
||||
>
|
||||
<ChatFormTextarea
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
class="file-upload-button md:h-8 md:w-8 h-9 w-9 rounded-full p-0"
|
||||
{disabled}
|
||||
{onclick}
|
||||
variant="secondary"
|
||||
|
||||
+16
-3
@@ -15,6 +15,7 @@
|
||||
import { McpLogo } from '$lib/components/app';
|
||||
import { PencilRuler, ChevronDown, ChevronRight } from '@lucide/svelte';
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import { AttachmentAction } from '$lib/enums/attachment.enums';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -270,14 +271,22 @@
|
||||
</Collapsible.Root>
|
||||
{/if}
|
||||
|
||||
<button type="button" class={sheetItemClass} onclick={onSystemPromptClick}>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[AttachmentAction.SYSTEM_PROMPT_CLICK]()}
|
||||
>
|
||||
<MessageSquare class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>System Message</span>
|
||||
</button>
|
||||
|
||||
{#if hasMcpPromptsSupport}
|
||||
<button type="button" class={sheetItemClass} onclick={onMcpPromptClick}>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_PROMPT_CLICK]()}
|
||||
>
|
||||
<Zap class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>MCP Prompt</span>
|
||||
@@ -285,7 +294,11 @@
|
||||
{/if}
|
||||
|
||||
{#if hasMcpResourcesSupport}
|
||||
<button type="button" class={sheetItemClass} onclick={onMcpResourcesClick}>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_RESOURCES_CLICK]()}
|
||||
>
|
||||
<FolderOpen class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>MCP Resources</span>
|
||||
|
||||
+1
@@ -42,6 +42,7 @@
|
||||
{hasMcpPromptsSupport}
|
||||
{hasMcpResourcesSupport}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
{onMcpPromptClick}
|
||||
{onMcpResourcesClick}
|
||||
>
|
||||
|
||||
+1
-1
@@ -20,7 +20,7 @@
|
||||
type="submit"
|
||||
disabled={isDisabled}
|
||||
class={[
|
||||
'h-8 w-8 rounded-full p-0',
|
||||
'md:h-8 md:w-8 h-9 w-9 rounded-full p-0',
|
||||
showErrorState &&
|
||||
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
|
||||
]}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { autoResizeTextarea } from '$lib/utils';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
@@ -37,7 +38,9 @@
|
||||
}
|
||||
|
||||
export function focus() {
|
||||
textareaElement?.focus();
|
||||
if (isMobile.current) return;
|
||||
|
||||
textareaElement?.focus({ preventScroll: true });
|
||||
}
|
||||
|
||||
export function resetHeight() {
|
||||
|
||||
@@ -231,7 +231,7 @@
|
||||
editedContent = message.content;
|
||||
}
|
||||
|
||||
textareaElement?.focus();
|
||||
textareaElement?.focus({ preventScroll: true });
|
||||
editedExtras = message.extra ? [...message.extra] : [];
|
||||
editedUploadedFiles = [];
|
||||
|
||||
@@ -324,7 +324,7 @@
|
||||
}
|
||||
</script>
|
||||
|
||||
<div use:fadeInView>
|
||||
<div use:fadeInView class="chat-message">
|
||||
{#if message.role === MessageRole.SYSTEM}
|
||||
<ChatMessageSystem
|
||||
bind:textareaElement
|
||||
|
||||
+72
-5
@@ -180,6 +180,9 @@
|
||||
|
||||
let displayedModel = $derived(message.model ?? null);
|
||||
|
||||
// model being switched to while it loads, so the selector bar tracks it
|
||||
let pendingModel = $state<string | null>(null);
|
||||
|
||||
let isCurrentlyLoading = $derived(isLoading());
|
||||
let isStreaming = $derived(isChatStreaming());
|
||||
let hasNoContent = $derived(!message?.content?.trim());
|
||||
@@ -207,6 +210,42 @@
|
||||
isLastAssistantMessage
|
||||
);
|
||||
|
||||
let assistantEl: HTMLDivElement | undefined = $state();
|
||||
let lastUserMessageHeight = $state(0);
|
||||
let assistantMarginTop = $state(0);
|
||||
|
||||
$effect(() => {
|
||||
if (!assistantEl) return;
|
||||
|
||||
assistantMarginTop = Math.round(parseFloat(getComputedStyle(assistantEl).marginTop));
|
||||
|
||||
const chatMessageEl = assistantEl.closest('.chat-message');
|
||||
const previousChatMessage = chatMessageEl?.previousElementSibling;
|
||||
const userMessageEl = previousChatMessage?.querySelector(
|
||||
'.chat-message-user'
|
||||
) as HTMLElement | null;
|
||||
|
||||
if (!userMessageEl) {
|
||||
lastUserMessageHeight = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const updateHeight = () => {
|
||||
const rect = userMessageEl.getBoundingClientRect();
|
||||
const marginTop = Math.round(parseFloat(getComputedStyle(userMessageEl).marginTop));
|
||||
lastUserMessageHeight = Math.round(rect.height + marginTop);
|
||||
};
|
||||
|
||||
updateHeight();
|
||||
|
||||
const resizeObserver = new ResizeObserver(updateHeight);
|
||||
resizeObserver.observe(userMessageEl);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
});
|
||||
|
||||
function handleCopyModel() {
|
||||
void copyToClipboard(displayedModel ?? '');
|
||||
}
|
||||
@@ -219,12 +258,17 @@
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="text-md group w-full leading-7.5 {className}"
|
||||
bind:this={assistantEl}
|
||||
class="chat-message-assistant text-md group w-full leading-7.5 {className}"
|
||||
style:--last-user-message-height={lastUserMessageHeight > 0
|
||||
? `${lastUserMessageHeight}px`
|
||||
: undefined}
|
||||
style:--assistant-margin-top={assistantMarginTop > 0 ? `${assistantMarginTop}px` : undefined}
|
||||
role="group"
|
||||
aria-label="Assistant message with actions"
|
||||
>
|
||||
{#if showProcessingInfoTop}
|
||||
<div class="mt-6 w-full max-w-[48rem]" in:fade>
|
||||
<div class="mt-6 w-full max-w-3xl" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{modelLoadingText ??
|
||||
@@ -257,7 +301,7 @@
|
||||
{/if}
|
||||
|
||||
{#if showProcessingInfoBottom}
|
||||
<div class="mt-4 w-full max-w-[48rem]" in:fade>
|
||||
<div class="mt-4 w-full max-w-3xl" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{modelLoadingText ??
|
||||
@@ -277,13 +321,19 @@
|
||||
>
|
||||
{#if isRouter}
|
||||
<ModelsSelectorDropdown
|
||||
currentModel={displayedModel}
|
||||
currentModel={pendingModel ?? displayedModel}
|
||||
disabled={isLoading()}
|
||||
onModelChange={async (modelId: string, modelName: string) => {
|
||||
const status = modelsStore.getModelStatus(modelId);
|
||||
|
||||
if (status !== ServerModelStatus.LOADED) {
|
||||
await modelsStore.loadModel(modelId);
|
||||
pendingModel = modelId;
|
||||
|
||||
try {
|
||||
await modelsStore.loadModel(modelId);
|
||||
} finally {
|
||||
pendingModel = null;
|
||||
}
|
||||
}
|
||||
|
||||
onRegenerate(modelName);
|
||||
@@ -351,6 +401,23 @@
|
||||
</div>
|
||||
|
||||
<style>
|
||||
:global(.chat-message):last-child .chat-message-assistant {
|
||||
--assistant-min-height-offset: calc(
|
||||
var(--last-user-message-height, 19rem) + var(--chat-form-height, 6rem) +
|
||||
var(--chat-form-bottom-position, 0.5rem) + var(--chat-form-padding-top, 6rem) +
|
||||
var(--assistant-margin-top, 3rem)
|
||||
);
|
||||
min-height: calc(100dvh - var(--assistant-min-height-offset));
|
||||
|
||||
@media (width > 768px) {
|
||||
--assistant-min-height-offset: calc(
|
||||
var(--last-user-message-height, 18rem) + var(--chat-form-height, 6rem) +
|
||||
var(--chat-form-bottom-position, 1rem) + var(--chat-form-padding-top, 6rem) +
|
||||
var(--assistant-margin-top, 3rem)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
.processing-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
+1
-1
@@ -48,7 +48,7 @@
|
||||
|
||||
<div
|
||||
aria-label="User message with actions"
|
||||
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
|
||||
class="chat-message-user group flex flex-col items-end gap-3 md:gap-2 {className}"
|
||||
role="group"
|
||||
>
|
||||
{#if editCtx.isEditing}
|
||||
|
||||
+2
-2
@@ -19,7 +19,7 @@
|
||||
renderMarkdown = false,
|
||||
textColorClass = 'text-foreground',
|
||||
cardBgClass = 'dark:bg-primary/15',
|
||||
maxHeightStyle = 'max-height: var(--max-message-height);'
|
||||
maxHeightStyle = ''
|
||||
}: Props = $props();
|
||||
|
||||
let isMultiline = $state(false);
|
||||
@@ -59,7 +59,7 @@
|
||||
|
||||
{#if content.trim()}
|
||||
<Card
|
||||
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-[multiline]:py-2.5 {cardBgClass}"
|
||||
class="chat-message-user-bubble max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-multiline:py-2.5 {cardBgClass}"
|
||||
data-multiline={isMultiline ? '' : undefined}
|
||||
style="{maxHeightStyle} overflow-wrap: anywhere; word-break: break-word;"
|
||||
>
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
let allConversationMessages = $state<DatabaseMessage[]>([]);
|
||||
let isVisible = $state(false);
|
||||
let previousConversationId = $state<string | null>(null);
|
||||
let previousRouteId = $state<string | null>(null);
|
||||
|
||||
const currentConfig = config();
|
||||
|
||||
@@ -157,8 +158,9 @@
|
||||
});
|
||||
});
|
||||
|
||||
beforeNavigate(() => {
|
||||
beforeNavigate((navigation) => {
|
||||
isVisible = false;
|
||||
previousRouteId = navigation.from?.route.id ?? null;
|
||||
});
|
||||
|
||||
afterNavigate(() => {
|
||||
@@ -249,12 +251,13 @@
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="transition-opacity delay-300 duration-500 ease-out
|
||||
{isVisible ? 'opacity-100' : 'opacity-0'}"
|
||||
class="transition-opacity duration-500 ease-out
|
||||
{isVisible ? 'opacity-100' : 'opacity-0'}
|
||||
{previousRouteId === '/(chat)/chat/[id]' ? '' : 'delay-300'}"
|
||||
>
|
||||
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
|
||||
<ChatMessage
|
||||
class="mx-auto mt-12 w-full max-w-[48rem]"
|
||||
class="mx-auto mt-12 w-full max-w-3xl"
|
||||
{message}
|
||||
{toolMessages}
|
||||
{isLastAssistantMessage}
|
||||
|
||||
@@ -1,31 +1,28 @@
|
||||
<script lang="ts">
|
||||
import { Trash2 } from '@lucide/svelte';
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import {
|
||||
ChatScreenForm,
|
||||
ChatMessages,
|
||||
ChatScreenDragOverlay,
|
||||
ChatScreenProcessingInfo,
|
||||
ChatScreenActionScrollDown,
|
||||
DialogEmptyFileAlert,
|
||||
DialogFileUploadError,
|
||||
DialogChatError,
|
||||
ServerLoadingSplash,
|
||||
DialogConfirmation,
|
||||
ChatScreenServerError
|
||||
} from '$lib/components/app';
|
||||
import { setProcessingInfoContext } from '$lib/contexts';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import { useChatScreenActiveModel } from '$lib/hooks/use-chat-screen-active-model.svelte';
|
||||
import { useChatScreenDragAndDrop } from '$lib/hooks/use-chat-screen-drag-and-drop.svelte';
|
||||
import { useChatScreenFileUpload } from '$lib/hooks/use-chat-screen-file-upload.svelte';
|
||||
import { useChatScreenScroll } from '$lib/hooks/use-chat-screen-scroll.svelte';
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
import { device } from '$lib/stores/device.svelte';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import {
|
||||
chatStore,
|
||||
errorDialog,
|
||||
isLoading,
|
||||
isChatStreaming,
|
||||
isEditing,
|
||||
getAddFilesHandler,
|
||||
activeProcessingState
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import {
|
||||
@@ -34,138 +31,81 @@
|
||||
activeConversation
|
||||
} from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { serverLoading, serverError, isRouterMode } from '$lib/stores/server.svelte';
|
||||
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
|
||||
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
|
||||
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
|
||||
import { onMount } from 'svelte';
|
||||
import { serverLoading, serverError } from '$lib/stores/server.svelte';
|
||||
import { parseFilesToMessageExtras } from '$lib/utils/browser-only';
|
||||
import { onDestroy, onMount } from 'svelte';
|
||||
import ChatScreenGreeting from './ChatScreenGreeting.svelte';
|
||||
import ChatScreenActionScrollDown from './ChatScreenActionScrollDown.svelte';
|
||||
import ChatScreenDialogsAndAlerts from './ChatScreenDialogsAndAlerts.svelte';
|
||||
import { ROUTES } from '$lib/constants';
|
||||
|
||||
let { showCenteredEmpty = false } = $props();
|
||||
|
||||
const autoScroll = createAutoScrollController();
|
||||
|
||||
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll));
|
||||
let chatScrollContainer: HTMLDivElement | undefined = $state();
|
||||
let dragCounter = $state(0);
|
||||
let isDragOver = $state(false);
|
||||
let showFileErrorDialog = $state(false);
|
||||
let uploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
|
||||
let fileErrorData = $state<{
|
||||
generallyUnsupported: File[];
|
||||
modalityUnsupported: File[];
|
||||
modalityReasons: Record<string, string>;
|
||||
supportedTypes: string[];
|
||||
}>({
|
||||
generallyUnsupported: [],
|
||||
modalityUnsupported: [],
|
||||
modalityReasons: {},
|
||||
supportedTypes: []
|
||||
});
|
||||
|
||||
let showDeleteDialog = $state(false);
|
||||
|
||||
let showEmptyFileDialog = $state(false);
|
||||
|
||||
let processingInfoVisible = $state(false);
|
||||
|
||||
let emptyFileNames = $state<string[]>([]);
|
||||
|
||||
let initialMessage = $state('');
|
||||
|
||||
let isEmpty = $derived(
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
|
||||
let activeErrorDialog = $derived(errorDialog());
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
let hasPropsError = $derived(!!serverError());
|
||||
|
||||
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
|
||||
|
||||
let showProcessingInfo = $derived(
|
||||
isCurrentConversationLoading ||
|
||||
(config().keepStatsVisible && !!page.params.id) ||
|
||||
activeProcessingState() !== null
|
||||
);
|
||||
|
||||
let isRouter = $derived(isRouterMode());
|
||||
|
||||
let conversationModel = $derived(
|
||||
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
|
||||
);
|
||||
|
||||
let activeModelId = $derived.by(() => {
|
||||
const options = modelOptions();
|
||||
|
||||
if (!isRouter) {
|
||||
return options.length > 0 ? options[0].model : null;
|
||||
}
|
||||
|
||||
const selectedId = selectedModelId();
|
||||
if (selectedId) {
|
||||
const model = options.find((m) => m.id === selectedId);
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
if (conversationModel) {
|
||||
const model = options.find((m) => m.model === conversationModel);
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
let modelPropsVersion = $state(0);
|
||||
|
||||
setProcessingInfoContext({
|
||||
get showProcessingInfo() {
|
||||
return showProcessingInfo;
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (activeModelId) {
|
||||
const cached = modelsStore.getModelProps(activeModelId);
|
||||
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll) || isMobile.current);
|
||||
let isMobileUserScrolledUp = $state(false);
|
||||
let mobileScrollDownHint = $state(false);
|
||||
let mobileScrollDownHintLockedUntil = $state(0);
|
||||
let emptyFileNames = $state<string[]>([]);
|
||||
let initialMessage = $state('');
|
||||
let showDeleteDialog = $state(false);
|
||||
let showEmptyFileDialog = $state(false);
|
||||
let isEmpty = $derived(
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
let activeErrorDialog = $derived(errorDialog());
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
let hasPropsError = $derived(!!serverError());
|
||||
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
|
||||
let showProcessingInfo = $derived(
|
||||
isCurrentConversationLoading ||
|
||||
(config().keepStatsVisible && !!page.params.id) ||
|
||||
activeProcessingState() !== null
|
||||
);
|
||||
let chatFormBottomPosition = $derived.by(() => {
|
||||
if (!isMobile.current) return '1rem';
|
||||
if (device.isStandalone) return '1.5rem';
|
||||
if (device.isIOSSafari) return '0.25rem';
|
||||
return '0.5rem';
|
||||
});
|
||||
|
||||
if (!cached) {
|
||||
modelsStore.fetchModelProps(activeModelId).then(() => {
|
||||
modelPropsVersion++;
|
||||
});
|
||||
const autoScroll = createAutoScrollController();
|
||||
const scroll = useChatScreenScroll(autoScroll);
|
||||
const activeModel = useChatScreenActiveModel();
|
||||
const fileUpload = useChatScreenFileUpload({
|
||||
capabilities: () => ({
|
||||
hasVision: activeModel.hasVisionModality,
|
||||
hasAudio: activeModel.hasAudioModality,
|
||||
hasVideo: activeModel.hasVideoModality
|
||||
}),
|
||||
activeModelId: () => activeModel.activeModelId
|
||||
});
|
||||
const dragAndDrop = useChatScreenDragAndDrop({
|
||||
onDrop: fileUpload.handleFileUpload
|
||||
});
|
||||
const { handleKeydown } = useKeyboardShortcuts({
|
||||
deleteActiveConversation: () => {
|
||||
if (activeConversation()) {
|
||||
showDeleteDialog = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let hasAudioModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
function handleMobileScroll() {
|
||||
if (!isMobile.current) return;
|
||||
|
||||
return modelsStore.modelSupportsAudio(activeModelId);
|
||||
}
|
||||
const container = scroll.chatScrollContainer;
|
||||
if (!container) return;
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
let hasVideoModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
||||
return modelsStore.modelSupportsVideo(activeModelId);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
let hasVisionModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
||||
return modelsStore.modelSupportsVision(activeModelId);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
const distanceFromBottom =
|
||||
container.scrollHeight - container.clientHeight - container.scrollTop;
|
||||
isMobileUserScrolledUp = distanceFromBottom > 300;
|
||||
}
|
||||
|
||||
async function handleDeleteConfirm() {
|
||||
const conversation = activeConversation();
|
||||
@@ -177,27 +117,69 @@
|
||||
showDeleteDialog = false;
|
||||
}
|
||||
|
||||
function handleProcessingInfoVisibility(visible: boolean) {
|
||||
processingInfoVisible = visible;
|
||||
}
|
||||
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
|
||||
const plainFiles = files ? $state.snapshot(files) : undefined;
|
||||
const result = plainFiles
|
||||
? await parseFilesToMessageExtras(plainFiles, activeModel.activeModelId ?? undefined)
|
||||
: undefined;
|
||||
|
||||
function handleDragEnter(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
|
||||
dragCounter++;
|
||||
|
||||
if (event.dataTransfer?.types.includes('Files')) {
|
||||
isDragOver = true;
|
||||
if (result?.emptyFiles && result.emptyFiles.length > 0) {
|
||||
emptyFileNames = result.emptyFiles;
|
||||
showEmptyFileDialog = true;
|
||||
if (files) {
|
||||
const emptyFileNamesSet = new Set(result.emptyFiles);
|
||||
fileUpload.uploadedFiles = fileUpload.uploadedFiles.filter(
|
||||
(file) => !emptyFileNamesSet.has(file.name)
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
handleSendLikeScroll();
|
||||
|
||||
await chatStore.sendMessage(message, result?.extras);
|
||||
return true;
|
||||
}
|
||||
|
||||
function handleDragLeave(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
function handleSendLikeScroll() {
|
||||
if (!isMobile.current) {
|
||||
autoScroll.enable();
|
||||
}
|
||||
|
||||
dragCounter--;
|
||||
setTimeout(() => {
|
||||
const container = scroll.chatScrollContainer;
|
||||
if (!container) return;
|
||||
|
||||
if (dragCounter === 0) {
|
||||
isDragOver = false;
|
||||
const lastUserBubble = container.querySelector(
|
||||
'.chat-message:nth-last-child(2) .chat-message-user .chat-message-user-bubble'
|
||||
) as HTMLElement | null;
|
||||
|
||||
if (isMobile.current) {
|
||||
// Keep the last user message bubble just above the input on mobile
|
||||
const bubbleHeight = lastUserBubble?.scrollHeight ?? 0;
|
||||
const baseHeight = container.scrollHeight - innerHeight;
|
||||
|
||||
container.scrollTo({
|
||||
top: bubbleHeight > 0 ? baseHeight - bubbleHeight : baseHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
} else if (lastUserBubble) {
|
||||
// On desktop, place the last user message near the top of the viewport
|
||||
const topPadding = 24;
|
||||
const bubbleRect = lastUserBubble.getBoundingClientRect();
|
||||
container.scrollTo({
|
||||
top: Math.max(0, container.scrollTop + bubbleRect.top - topPadding),
|
||||
behavior: 'smooth'
|
||||
});
|
||||
} else {
|
||||
autoScroll.scrollToBottom();
|
||||
}
|
||||
}, 100);
|
||||
|
||||
if (isMobile.current) {
|
||||
autoScroll.setDisabled(disableAutoScroll);
|
||||
mobileScrollDownHint = true;
|
||||
mobileScrollDownHintLockedUntil = Date.now() + 500;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,273 +189,138 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
}
|
||||
|
||||
function handleDrop(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
|
||||
isDragOver = false;
|
||||
dragCounter = 0;
|
||||
|
||||
if (event.dataTransfer?.files) {
|
||||
const files = Array.from(event.dataTransfer.files);
|
||||
|
||||
if (isEditing()) {
|
||||
const handler = getAddFilesHandler();
|
||||
|
||||
if (handler) {
|
||||
handler(files);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
processFiles(files);
|
||||
}
|
||||
}
|
||||
|
||||
function handleFileRemove(fileId: string) {
|
||||
uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);
|
||||
}
|
||||
|
||||
function handleFileUpload(files: File[]) {
|
||||
processFiles(files);
|
||||
}
|
||||
|
||||
const { handleKeydown } = useKeyboardShortcuts({
|
||||
deleteActiveConversation: () => {
|
||||
if (activeConversation()) {
|
||||
showDeleteDialog = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
|
||||
if (draft.message || draft.files.length > 0) {
|
||||
chatStore.savePendingDraft(draft.message, draft.files);
|
||||
}
|
||||
|
||||
await chatStore.addSystemPrompt();
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
autoScroll.handleScroll();
|
||||
}
|
||||
|
||||
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
|
||||
const plainFiles = files ? $state.snapshot(files) : undefined;
|
||||
const result = plainFiles
|
||||
? await parseFilesToMessageExtras(plainFiles, activeModelId ?? undefined)
|
||||
: undefined;
|
||||
|
||||
if (result?.emptyFiles && result.emptyFiles.length > 0) {
|
||||
emptyFileNames = result.emptyFiles;
|
||||
showEmptyFileDialog = true;
|
||||
|
||||
if (files) {
|
||||
const emptyFileNamesSet = new Set(result.emptyFiles);
|
||||
uploadedFiles = uploadedFiles.filter((file) => !emptyFileNamesSet.has(file.name));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const extras = result?.extras;
|
||||
|
||||
// Enable autoscroll for user-initiated message sending
|
||||
autoScroll.enable();
|
||||
await chatStore.sendMessage(message, extras);
|
||||
autoScroll.scrollToBottom();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
async function processFiles(files: File[]) {
|
||||
const generallySupported: File[] = [];
|
||||
const generallyUnsupported: File[] = [];
|
||||
|
||||
for (const file of files) {
|
||||
if (isFileTypeSupported(file.name, file.type)) {
|
||||
generallySupported.push(file);
|
||||
} else {
|
||||
generallyUnsupported.push(file);
|
||||
}
|
||||
}
|
||||
|
||||
// Use model-specific capabilities for file validation
|
||||
const capabilities = {
|
||||
hasVision: hasVisionModality,
|
||||
hasAudio: hasAudioModality,
|
||||
hasVideo: hasVideoModality
|
||||
};
|
||||
const { supportedFiles, unsupportedFiles, modalityReasons } = filterFilesByModalities(
|
||||
generallySupported,
|
||||
capabilities
|
||||
);
|
||||
|
||||
const allUnsupportedFiles = [...generallyUnsupported, ...unsupportedFiles];
|
||||
|
||||
if (allUnsupportedFiles.length > 0) {
|
||||
const supportedTypes: string[] = ['text files', 'PDFs'];
|
||||
|
||||
if (hasVisionModality) supportedTypes.push('images');
|
||||
if (hasAudioModality) supportedTypes.push('audio files');
|
||||
if (hasVideoModality) supportedTypes.push('video files');
|
||||
|
||||
fileErrorData = {
|
||||
generallyUnsupported,
|
||||
modalityUnsupported: unsupportedFiles,
|
||||
modalityReasons,
|
||||
supportedTypes
|
||||
};
|
||||
showFileErrorDialog = true;
|
||||
}
|
||||
|
||||
if (supportedFiles.length > 0) {
|
||||
const processed = await processFilesToChatUploaded(
|
||||
supportedFiles,
|
||||
activeModelId ?? undefined
|
||||
);
|
||||
uploadedFiles = [...uploadedFiles, ...processed];
|
||||
}
|
||||
}
|
||||
|
||||
afterNavigate(() => {
|
||||
if (!disableAutoScroll) {
|
||||
$effect(() => {
|
||||
const shouldDisableAutoScroll =
|
||||
config().disableAutoScroll || (isMobile.current && isCurrentConversationLoading);
|
||||
autoScroll.setDisabled(shouldDisableAutoScroll);
|
||||
if (!shouldDisableAutoScroll) {
|
||||
autoScroll.enable();
|
||||
}
|
||||
});
|
||||
|
||||
function handleMessagesReady() {
|
||||
if (disableAutoScroll) return;
|
||||
|
||||
if (!autoScroll.userScrolledUp) {
|
||||
requestAnimationFrame(() => {
|
||||
autoScroll.scrollToBottom('instant');
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
if (pendingDraft) {
|
||||
initialMessage = pendingDraft.message;
|
||||
fileUpload.uploadedFiles = pendingDraft.files;
|
||||
}
|
||||
|
||||
autoScroll.startObserving();
|
||||
|
||||
if (!disableAutoScroll) {
|
||||
autoScroll.enable();
|
||||
}
|
||||
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
if (pendingDraft) {
|
||||
initialMessage = pendingDraft.message;
|
||||
uploadedFiles = pendingDraft.files;
|
||||
if (isMobile.current && isCurrentConversationLoading) {
|
||||
mobileScrollDownHint = true;
|
||||
mobileScrollDownHintLockedUntil = Date.now() + 500;
|
||||
}
|
||||
|
||||
handleMobileScroll();
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.setContainer(chatScrollContainer);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.setDisabled(disableAutoScroll);
|
||||
});
|
||||
onDestroy(() => autoScroll.destroy());
|
||||
</script>
|
||||
|
||||
{#if isDragOver}
|
||||
{#if dragAndDrop.isDragOver}
|
||||
<ChatScreenDragOverlay />
|
||||
{/if}
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} />
|
||||
<svelte:window
|
||||
onkeydown={handleKeydown}
|
||||
onscroll={(e) => {
|
||||
scroll.handleScroll(e);
|
||||
handleMobileScroll();
|
||||
if (e.isTrusted && Date.now() > mobileScrollDownHintLockedUntil) {
|
||||
mobileScrollDownHint = false;
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
{#if isServerLoading}
|
||||
<ServerLoadingSplash />
|
||||
{:else}
|
||||
<div
|
||||
bind:this={chatScrollContainer}
|
||||
aria-label="Chat interface with file drop zone"
|
||||
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
|
||||
ondragenter={handleDragEnter}
|
||||
ondragleave={handleDragLeave}
|
||||
ondragover={handleDragOver}
|
||||
ondrop={handleDrop}
|
||||
onscroll={handleScroll}
|
||||
class="chat-screen flex grow flex-col min-h-[calc(100dvh-1rem)] md:min-h-full px-4 md:py-0 pt-12 pb-48 md:pb-4"
|
||||
style:--chat-form-bottom-position={chatFormBottomPosition}
|
||||
ondragenter={dragAndDrop.dragHandlers.dragenter}
|
||||
ondragleave={dragAndDrop.dragHandlers.dragleave}
|
||||
ondragover={dragAndDrop.dragHandlers.dragover}
|
||||
ondrop={dragAndDrop.dragHandlers.drop}
|
||||
role="main"
|
||||
>
|
||||
<div class="flex grow flex-col pt-14">
|
||||
{#if !isEmpty}
|
||||
<ChatMessages
|
||||
messages={activeMessages()}
|
||||
onMessagesReady={handleMessagesReady}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
if (!autoScroll.userScrolledUp) {
|
||||
autoScroll.scrollToBottom();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
{#if !isEmpty}
|
||||
<ChatMessages
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
handleSendLikeScroll();
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none sticky right-4 left-4 mt-auto transition-all duration-200',
|
||||
isEmpty ? 'bottom-[calc(50dvh-7rem)]' : 'bottom-4 pt-24 md:pt-32'
|
||||
]}
|
||||
>
|
||||
<ChatScreenGreeting {isEmpty} />
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none md:sticky fixed mt-auto transition-all duration-200',
|
||||
device.isStandalone
|
||||
? 'bottom-6 right-4 left-4'
|
||||
: device.isIOSSafari
|
||||
? 'bottom-1 left-2 right-2'
|
||||
: 'bottom-2 right-2 left-2',
|
||||
isEmpty ? 'md:bottom-[calc(50dvh-7rem)] 2xl:bottom-[calc(50dvh-4rem)]' : 'md:bottom-4'
|
||||
]}
|
||||
style:padding-top={!isEmpty ? 'var(--chat-form-padding-top)' : undefined}
|
||||
>
|
||||
<ChatScreenGreeting {isEmpty} />
|
||||
|
||||
<ChatScreenActionScrollDown
|
||||
container={chatScrollContainer}
|
||||
hasProcessingInfoVisible={processingInfoVisible}
|
||||
/>
|
||||
<ChatScreenServerError />
|
||||
|
||||
<ChatScreenProcessingInfo onVisibilityChange={handleProcessingInfoVisibility} />
|
||||
|
||||
<ChatScreenServerError />
|
||||
|
||||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
|
||||
<ChatScreenForm
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
bind:uploadedFiles
|
||||
<div class="pointer-events-none flex flex-col gap-6 items-center w-full">
|
||||
{#if (isMobile.current ? mobileScrollDownHint || isMobileUserScrolledUp : autoScroll.userScrolledUp) && page.url.hash.includes(ROUTES.CHAT) && page.params.id}
|
||||
<ChatScreenActionScrollDown
|
||||
onclick={() => {
|
||||
mobileScrollDownHint = false;
|
||||
scroll.chatScrollContainer?.scrollTo({
|
||||
top: scroll.chatScrollContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if showProcessingInfo}
|
||||
<ChatScreenProcessingInfo />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<ChatScreenForm
|
||||
class="pointer-events-auto conversation-chat-form"
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={fileUpload.handleFileRemove}
|
||||
onFileUpload={fileUpload.handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
bind:uploadedFiles={fileUpload.uploadedFiles}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<DialogFileUploadError bind:open={showFileErrorDialog} {fileErrorData} />
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
variant="destructive"
|
||||
icon={Trash2}
|
||||
onConfirm={handleDeleteConfirm}
|
||||
onCancel={() => (showDeleteDialog = false)}
|
||||
/>
|
||||
|
||||
<DialogEmptyFileAlert
|
||||
bind:open={showEmptyFileDialog}
|
||||
emptyFiles={emptyFileNames}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
emptyFileNames = [];
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<DialogChatError
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
contextInfo={activeErrorDialog?.contextInfo}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
|
||||
<ChatScreenDialogsAndAlerts
|
||||
{showDeleteDialog}
|
||||
{handleDeleteConfirm}
|
||||
{showEmptyFileDialog}
|
||||
{emptyFileNames}
|
||||
{activeErrorDialog}
|
||||
{handleErrorDialogOpenChange}
|
||||
{fileUpload}
|
||||
/>
|
||||
|
||||
@@ -1,58 +1,18 @@
|
||||
<script lang="ts">
|
||||
import { ArrowDown } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import ActionIcon from '$lib/components/app/actions/ActionIcon.svelte';
|
||||
|
||||
interface Props {
|
||||
container: HTMLDivElement | undefined;
|
||||
hasProcessingInfoVisible: boolean;
|
||||
}
|
||||
|
||||
let { container, hasProcessingInfoVisible }: Props = $props();
|
||||
|
||||
let show = $state(false);
|
||||
|
||||
let buttonBottom = $derived(hasProcessingInfoVisible ? '2rem' : '0');
|
||||
|
||||
function checkVisibility() {
|
||||
if (!container) return;
|
||||
const { scrollTop, scrollHeight, clientHeight } = container;
|
||||
const distanceFromBottom = scrollHeight - clientHeight - scrollTop;
|
||||
show = distanceFromBottom > clientHeight * 0.5;
|
||||
}
|
||||
|
||||
function scrollToBottom() {
|
||||
if (container) {
|
||||
container.scrollTo({
|
||||
top: container.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
const c = container;
|
||||
if (c) {
|
||||
c.addEventListener('scroll', checkVisibility);
|
||||
checkVisibility();
|
||||
return () => {
|
||||
c.removeEventListener('scroll', checkVisibility);
|
||||
};
|
||||
}
|
||||
});
|
||||
let { onclick }: { onclick: (e?: MouseEvent) => void } = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative z-50 mx-auto mb-4 flex max-w-[48rem] justify-center">
|
||||
<Button
|
||||
onclick={scrollToBottom}
|
||||
variant="secondary"
|
||||
size="icon"
|
||||
disabled={!show}
|
||||
class="pointer-events-auto absolute h-10 w-10 rounded-full bg-background/80 shadow-lg backdrop-blur-sm transition-all duration-200 hover:bg-muted/80"
|
||||
style="bottom: {buttonBottom}; transform: translateY({show ? '0' : '2rem'}); opacity: {show
|
||||
? 1
|
||||
: 0};"
|
||||
aria-label="Scroll to bottom"
|
||||
>
|
||||
<ArrowDown class="h-4 w-4" />
|
||||
</Button>
|
||||
<div class="pointer-events-auto flex justify-center relative h-0">
|
||||
<ActionIcon
|
||||
icon={ArrowDown}
|
||||
{onclick}
|
||||
ariaLabel="Scroll to bottom"
|
||||
tooltip="Scroll to bottom"
|
||||
size="lg"
|
||||
iconSize="h-4 w-4"
|
||||
class="h-9 w-9 rounded-full bg-accent text-accent-foreground absolute bottom-4 shadow-md"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
<script lang="ts">
|
||||
import { Trash2 } from '@lucide/svelte';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import {
|
||||
DialogChatError,
|
||||
DialogConfirmation,
|
||||
DialogEmptyFileAlert,
|
||||
DialogFileUploadError
|
||||
} from '$lib/components/app';
|
||||
|
||||
let {
|
||||
showDeleteDialog,
|
||||
handleDeleteConfirm,
|
||||
showEmptyFileDialog,
|
||||
emptyFileNames,
|
||||
activeErrorDialog,
|
||||
handleErrorDialogOpenChange,
|
||||
fileUpload
|
||||
} = $props();
|
||||
</script>
|
||||
|
||||
<DialogFileUploadError
|
||||
bind:open={fileUpload.showFileErrorDialog}
|
||||
fileErrorData={fileUpload.fileErrorData}
|
||||
/>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
variant="destructive"
|
||||
icon={Trash2}
|
||||
onConfirm={handleDeleteConfirm}
|
||||
onCancel={() => (showDeleteDialog = false)}
|
||||
/>
|
||||
|
||||
<DialogEmptyFileAlert
|
||||
bind:open={showEmptyFileDialog}
|
||||
emptyFiles={emptyFileNames}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
emptyFileNames = [];
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<DialogChatError
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
contextInfo={activeErrorDialog?.contextInfo}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
|
||||
/>
|
||||
@@ -2,6 +2,7 @@
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { ChatForm } from '$lib/components/app';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { onMount } from 'svelte';
|
||||
import { useDraftMessages } from '$lib/hooks/use-draft-messages.svelte';
|
||||
|
||||
@@ -32,7 +33,30 @@
|
||||
}: Props = $props();
|
||||
|
||||
let chatFormRef: ChatForm | undefined = $state(undefined);
|
||||
let formWrapperEl: HTMLDivElement | undefined = $state();
|
||||
let chatId = $derived(page.params.id as string | undefined);
|
||||
|
||||
$effect(() => {
|
||||
if (!formWrapperEl) return;
|
||||
|
||||
const formEl = formWrapperEl.querySelector('form') as HTMLElement | null;
|
||||
if (!formEl) return;
|
||||
|
||||
const updateHeight = () => {
|
||||
const height = Math.round(formEl.getBoundingClientRect().height);
|
||||
document.documentElement.style.setProperty('--chat-form-height', `${height}px`);
|
||||
};
|
||||
|
||||
updateHeight();
|
||||
|
||||
const resizeObserver = new ResizeObserver(updateHeight);
|
||||
resizeObserver.observe(formEl);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
document.documentElement.style.removeProperty('--chat-form-height');
|
||||
};
|
||||
});
|
||||
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
|
||||
let message = $derived(initialMessage);
|
||||
let previousIsLoading = $derived(isLoading);
|
||||
@@ -83,12 +107,14 @@
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
setTimeout(() => chatFormRef?.focus(), 10);
|
||||
if (!isMobile.current) {
|
||||
setTimeout(() => chatFormRef?.focus(), 100);
|
||||
}
|
||||
});
|
||||
|
||||
afterNavigate((navigation) => {
|
||||
if (navigation?.from != null) {
|
||||
setTimeout(() => chatFormRef?.focus(), 10);
|
||||
if (navigation?.from != null && !isMobile.current) {
|
||||
setTimeout(() => chatFormRef?.focus(), 100);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -108,12 +134,12 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="relative mx-auto max-w-[48rem]">
|
||||
<div class="chat-screen-form-wrapper" bind:this={formWrapperEl}>
|
||||
<ChatForm
|
||||
class="mx-auto max-w-3xl {className}"
|
||||
bind:this={chatFormRef}
|
||||
bind:value={message}
|
||||
bind:uploadedFiles
|
||||
class={className}
|
||||
{disabled}
|
||||
{isLoading}
|
||||
showMcpPromptButton
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
<script lang="ts">
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -11,10 +10,9 @@
|
||||
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none mb-4 hidden px-4 text-center',
|
||||
isEmpty && 'pointer-events-auto block!'
|
||||
'pointer-events-none mb-4 hidden px-4 text-center text-balance',
|
||||
isEmpty && 'mb-[calc(50dvh-8rem)] md:mb-6 pointer-events-auto block!'
|
||||
]}
|
||||
use:fadeInView={{ duration: 300 }}
|
||||
>
|
||||
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
|
||||
|
||||
|
||||
@@ -5,13 +5,8 @@
|
||||
import { chatStore, isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
|
||||
import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { getProcessingInfoContext } from '$lib/contexts';
|
||||
import { page } from '$app/state';
|
||||
|
||||
const processingState = useProcessingState();
|
||||
const processingInfoCtx = getProcessingInfoContext();
|
||||
|
||||
let showProcessingInfo = $derived(processingInfoCtx.showProcessingInfo);
|
||||
|
||||
let isCurrentConversationLoading = $derived(isLoading());
|
||||
let isStreaming = $derived(isChatStreaming());
|
||||
@@ -70,8 +65,8 @@
|
||||
|
||||
<div
|
||||
class={[
|
||||
'chat-processing-info-container pointer-events-none relative',
|
||||
page.params.id && showProcessingInfo && 'visible'
|
||||
'chat-processing-info-container pointer-events-none relative w-full hidden md:block',
|
||||
processingVisible && 'visible'
|
||||
]}
|
||||
>
|
||||
<div class="chat-processing-info-content absolute bottom-4 left-1/2 -translate-x-1/2">
|
||||
|
||||
@@ -677,13 +677,6 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
|
||||
*/
|
||||
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
|
||||
|
||||
/**
|
||||
* Scroll-to-bottom action button. Displays a floating button when the user
|
||||
* has scrolled up more than half a viewport height from the bottom.
|
||||
* Takes the chat container element as a prop to manage scroll state internally.
|
||||
*/
|
||||
export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte';
|
||||
|
||||
/**
|
||||
* Server error alert displayed when the server is unreachable.
|
||||
* Shows the error message with a retry button.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { Search, X } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
autofocus?: boolean;
|
||||
value?: string;
|
||||
placeholder?: string;
|
||||
onInput?: (value: string) => void;
|
||||
@@ -15,6 +16,7 @@
|
||||
}
|
||||
|
||||
let {
|
||||
autofocus,
|
||||
value = $bindable(''),
|
||||
placeholder = 'Search...',
|
||||
onInput,
|
||||
@@ -39,7 +41,7 @@
|
||||
if (value) {
|
||||
value = '';
|
||||
onInput?.('');
|
||||
ref?.focus();
|
||||
ref?.focus({ preventScroll: true });
|
||||
} else {
|
||||
onClose?.();
|
||||
}
|
||||
@@ -52,6 +54,7 @@
|
||||
/>
|
||||
|
||||
<Input
|
||||
{autofocus}
|
||||
{id}
|
||||
bind:value
|
||||
bind:ref
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
<script>
|
||||
import logoMark from '$lib/assets/logo.svg?raw';
|
||||
let { class: className = '', style = '' } = $props();
|
||||
</script>
|
||||
|
||||
<div class={className} {style}>
|
||||
{@html logoMark}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
div :global(svg) {
|
||||
width: var(--size, 1rem);
|
||||
height: var(--size, 1rem);
|
||||
}
|
||||
</style>
|
||||
@@ -51,3 +51,11 @@ export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';
|
||||
* Preview button is shown only for HTML code blocks.
|
||||
*/
|
||||
export { default as CodeBlockActions } from './CodeBlockActions.svelte';
|
||||
|
||||
/**
|
||||
* **Logo** - Application brand mark
|
||||
*
|
||||
* Inline SVG of the application logo. Accepts styling via the standard
|
||||
* `class` and `style` props and inherits color via `currentColor`.
|
||||
*/
|
||||
export { default as Logo } from './Logo.svelte';
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
<script lang="ts">
|
||||
let { percent }: { percent: number } = $props();
|
||||
</script>
|
||||
|
||||
<!-- thin determinate load bar pinned to the bottom edge, pulsing while it fills -->
|
||||
<div class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm">
|
||||
<div
|
||||
class="h-full animate-pulse bg-primary transition-[width] duration-200 ease-out"
|
||||
style="width: {percent}%"
|
||||
></div>
|
||||
</div>
|
||||
@@ -2,8 +2,10 @@
|
||||
import { ChevronDown, Loader2, Package } from '@lucide/svelte';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { KeyboardKey } from '$lib/enums';
|
||||
import { KeyboardKey, ServerModelStatus } from '$lib/enums';
|
||||
import { useModelsSelector } from '$lib/hooks/use-models-selector.svelte';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
import { modelLoadFraction } from '$lib/utils';
|
||||
import {
|
||||
DialogModelInformation,
|
||||
DropdownMenuSearchable,
|
||||
@@ -11,6 +13,7 @@
|
||||
ModelsSelectorList,
|
||||
ModelsSelectorOption
|
||||
} from '$lib/components/app';
|
||||
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
|
||||
import type { ModelItem } from './utils';
|
||||
|
||||
interface Props {
|
||||
@@ -113,6 +116,17 @@
|
||||
{/if}
|
||||
{:else}
|
||||
{@const selectedOption = ms.getDisplayOption()}
|
||||
{@const triggerModel = selectedOption?.model}
|
||||
{@const triggerStatus = triggerModel
|
||||
? routerModels().find((m) => m.id === triggerModel)?.status?.value
|
||||
: undefined}
|
||||
{@const triggerLoading =
|
||||
!!triggerModel &&
|
||||
(triggerStatus === ServerModelStatus.LOADING ||
|
||||
modelsStore.isModelOperationInProgress(triggerModel))}
|
||||
{@const triggerLoadPercent = triggerLoading
|
||||
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
|
||||
: 0}
|
||||
|
||||
{#if ms.isRouter}
|
||||
<DropdownMenu.Root bind:open={isOpen} onOpenChange={ms.handleOpenChange}>
|
||||
@@ -123,7 +137,7 @@
|
||||
<DropdownMenu.Trigger
|
||||
{...props}
|
||||
class={[
|
||||
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
`relative inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
!ms.isCurrentModelInCache
|
||||
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
|
||||
: forceForegroundText
|
||||
@@ -154,6 +168,10 @@
|
||||
{:else}
|
||||
<ChevronDown class="h-3 w-3.5 shrink-0" />
|
||||
{/if}
|
||||
|
||||
{#if triggerLoading}
|
||||
<ModelLoadHighlight percent={triggerLoadPercent} />
|
||||
{/if}
|
||||
</DropdownMenu.Trigger>
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
RotateCw
|
||||
} from '@lucide/svelte';
|
||||
import { ActionIcon, ModelId } from '$lib/components/app';
|
||||
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
|
||||
import type { ModelOption } from '$lib/types/models';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
@@ -119,11 +120,11 @@
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-5 items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-5">
|
||||
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
{:else if isFailed}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<CircleAlert
|
||||
class="h-3.5 w-3.5 text-red-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
/>
|
||||
@@ -140,7 +141,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else if isSleeping}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
@@ -159,7 +160,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else if isLoaded}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
@@ -176,7 +177,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
@@ -196,13 +197,6 @@
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<div
|
||||
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-primary transition-[width] duration-200 ease-out"
|
||||
style="width: {loadPercent}%"
|
||||
></div>
|
||||
</div>
|
||||
<ModelLoadHighlight percent={loadPercent} />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -8,6 +8,10 @@
|
||||
ModelsSelectorList,
|
||||
SearchInput
|
||||
} from '$lib/components/app';
|
||||
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
import { modelLoadFraction } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -61,12 +65,23 @@
|
||||
<p class="text-xs text-muted-foreground">No models available.</p>
|
||||
{:else}
|
||||
{@const selectedOption = ms.getDisplayOption()}
|
||||
{@const triggerModel = selectedOption?.model}
|
||||
{@const triggerStatus = triggerModel
|
||||
? routerModels().find((m) => m.id === triggerModel)?.status?.value
|
||||
: undefined}
|
||||
{@const triggerLoading =
|
||||
!!triggerModel &&
|
||||
(triggerStatus === ServerModelStatus.LOADING ||
|
||||
modelsStore.isModelOperationInProgress(triggerModel))}
|
||||
{@const triggerLoadPercent = triggerLoading
|
||||
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
|
||||
: 0}
|
||||
|
||||
{#if ms.isRouter}
|
||||
<button
|
||||
type="button"
|
||||
class={[
|
||||
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 max-sm:px-3 max-sm:py-2 text-xs max-sm:text-sm shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
`relative inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 max-sm:px-3 max-sm:py-2 max-sm:text-sm dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
!ms.isCurrentModelInCache
|
||||
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
|
||||
: forceForegroundText
|
||||
@@ -99,6 +114,10 @@
|
||||
{:else}
|
||||
<ChevronDown class="h-3 w-3.5 shrink-0" />
|
||||
{/if}
|
||||
|
||||
{#if triggerLoading}
|
||||
<ModelLoadHighlight percent={triggerLoadPercent} />
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
<Sheet.Root bind:open={sheetOpen} onOpenChange={handleSheetOpenChange}>
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { ActionIcon } from '$lib/components/app';
|
||||
import {
|
||||
ICON_STRIP_TRANSITION_DURATION,
|
||||
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
|
||||
SIDEBAR_ACTIONS_ITEMS
|
||||
} from '$lib/constants';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { circIn } from 'svelte/easing';
|
||||
import { onMount } from 'svelte';
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
|
||||
interface Props {
|
||||
sidebarOpen: boolean;
|
||||
onSearchClick: () => void;
|
||||
}
|
||||
|
||||
let { sidebarOpen = false, onSearchClick }: Props = $props();
|
||||
|
||||
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
|
||||
|
||||
let initialized = $state(false);
|
||||
let showIcons = $derived(!sidebarOpen);
|
||||
|
||||
showIcons = false;
|
||||
|
||||
onMount(() => {
|
||||
showIcons = !sidebarOpen;
|
||||
|
||||
setTimeout(() => {
|
||||
initialized = true;
|
||||
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
|
||||
});
|
||||
</script>
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} />
|
||||
|
||||
<div
|
||||
class="hidden shrink-0 transition-[width] duration-200 ease-linear md:block {sidebarOpen
|
||||
? 'w-0'
|
||||
: 'w-[calc(var(--sidebar-width-icon)+1.5rem)]'}"
|
||||
></div>
|
||||
<aside
|
||||
class="fixed top-0 bottom-0 left-0 z-10 hidden w-[calc(var(--sidebar-width-icon)+1.5rem)] flex-col items-center justify-between py-3 transition-opacity duration-200 ease-linear md:flex {sidebarOpen
|
||||
? 'pointer-events-none opacity-0'
|
||||
: 'opacity-100'}"
|
||||
>
|
||||
<div class="mt-12 flex flex-col items-center gap-1">
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
|
||||
{@const onclick = item.route ? () => goto(item.route!) : onSearchClick}
|
||||
{@const isActive = item.activeRouteId
|
||||
? page.route.id === item.activeRouteId
|
||||
: item.activeRoutePrefix
|
||||
? !!page.route.id?.startsWith(item.activeRoutePrefix)
|
||||
: false}
|
||||
{#if showIcons}
|
||||
<div
|
||||
in:fade={{
|
||||
duration: ICON_STRIP_TRANSITION_DURATION,
|
||||
delay: !initialized
|
||||
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
|
||||
: 0,
|
||||
easing: circIn
|
||||
}}
|
||||
>
|
||||
<ActionIcon
|
||||
icon={item.icon}
|
||||
tooltip={item.tooltip}
|
||||
tooltipSide={TooltipSide.RIGHT}
|
||||
size="lg"
|
||||
iconSize="h-4 w-4"
|
||||
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
{onclick}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</aside>
|
||||
+234
-295
@@ -1,40 +1,67 @@
|
||||
<script lang="ts">
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { Trash2, Pencil, Pin, X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { DialogConfirmation } from '$lib/components/app';
|
||||
import SidebarNavigationActions from './SidebarNavigationActions.svelte';
|
||||
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
|
||||
import * as Sidebar from '$lib/components/ui/sidebar';
|
||||
import Input from '$lib/components/ui/input/input.svelte';
|
||||
import { ROUTES } from '$lib/constants/routes';
|
||||
import { RouterService } from '$lib/services/router.service';
|
||||
import { PanelLeftClose, PanelLeftOpen, X } from '@lucide/svelte';
|
||||
import {
|
||||
conversationsStore,
|
||||
conversations,
|
||||
buildConversationTree
|
||||
} from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { getPreviewText } from '$lib/utils';
|
||||
import { APP_NAME } from '$lib/constants';
|
||||
ActionIcon,
|
||||
Logo,
|
||||
SidebarNavigationConversationList,
|
||||
SidebarNavigationActions
|
||||
} from '$lib/components/app';
|
||||
import { ROUTES } from '$lib/constants';
|
||||
import { fade } from 'svelte/transition';
|
||||
|
||||
const sidebar = Sidebar.useSidebar();
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { RouterService } from '$lib/services/router.service';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
import { device } from '$lib/stores/device.svelte';
|
||||
import { circIn } from 'svelte/easing';
|
||||
|
||||
interface Props {
|
||||
onSearchClick?: () => void;
|
||||
}
|
||||
|
||||
let { onSearchClick = () => {} }: Props = $props();
|
||||
|
||||
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
|
||||
|
||||
let isExpandedMode = $state(false);
|
||||
let hoveredTooltip = $state<string | null>(null);
|
||||
let logoHovered = $state(false);
|
||||
|
||||
const isStripExpanded = $derived(isExpandedMode || hoveredTooltip !== null);
|
||||
const isOnMobile = $derived(isMobile.current);
|
||||
|
||||
function toggleExpandedMode() {
|
||||
isExpandedMode = !isExpandedMode;
|
||||
if (!isExpandedMode) {
|
||||
hoveredTooltip = null;
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!isExpandedMode) {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
cancelMobileCollapse();
|
||||
}
|
||||
});
|
||||
|
||||
// On mobile the dedicated /search route hides the sidebar (see the aside
|
||||
// render guard below). Collapse it as we enter /search so it doesn't
|
||||
// reappear expanded when the user navigates back via the back button.
|
||||
$effect(() => {
|
||||
if (isMobile.current && page.url.hash.includes(ROUTES.SEARCH)) {
|
||||
isExpandedMode = false;
|
||||
}
|
||||
});
|
||||
|
||||
let currentChatId = $derived(page.params.id);
|
||||
let isSearchModeActive = $state(false);
|
||||
let searchQuery = $state('');
|
||||
let showDeleteDialog = $state(false);
|
||||
let deleteWithForks = $state(false);
|
||||
let showEditDialog = $state(false);
|
||||
let selectedConversation = $state<DatabaseConversation | null>(null);
|
||||
let editedName = $state('');
|
||||
let selectedConversationNamePreview = $derived.by(() =>
|
||||
selectedConversation ? getPreviewText(selectedConversation.name) : ''
|
||||
);
|
||||
|
||||
let filteredConversations = $derived.by(() => {
|
||||
if (isSearchModeActive) {
|
||||
@@ -50,294 +77,206 @@
|
||||
return conversations();
|
||||
});
|
||||
|
||||
let conversationTree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
let pinnedConversations = $derived.by(() => {
|
||||
return conversationTree.filter(({ conversation }) => conversation.pinned);
|
||||
});
|
||||
|
||||
let unpinnedConversations = $derived.by(() => {
|
||||
return conversationTree.filter(({ conversation }) => !conversation.pinned);
|
||||
});
|
||||
|
||||
let selectedConversationHasDescendants = $derived.by(() => {
|
||||
if (!selectedConversation) return false;
|
||||
|
||||
const allConvs = conversations();
|
||||
const queue = [selectedConversation.id];
|
||||
|
||||
while (queue.length > 0) {
|
||||
const parentId = queue.pop()!;
|
||||
|
||||
for (const c of allConvs) {
|
||||
if (c.forkedFromConversationId === parentId) return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
async function handleDeleteConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (conversation) {
|
||||
selectedConversation = conversation;
|
||||
deleteWithForks = false;
|
||||
showDeleteDialog = true;
|
||||
async function selectConversation(id: string) {
|
||||
if (isMobile.current) {
|
||||
scheduleMobileCollapse();
|
||||
}
|
||||
await goto(RouterService.chat(id));
|
||||
}
|
||||
|
||||
async function handleEditConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (conversation) {
|
||||
selectedConversation = conversation;
|
||||
editedName = conversation.name;
|
||||
showEditDialog = true;
|
||||
if (!conversation) return;
|
||||
|
||||
const newName = window.prompt('Rename conversation', conversation.name);
|
||||
if (newName && newName.trim()) {
|
||||
await conversationsStore.updateConversationName(id, newName.trim());
|
||||
}
|
||||
}
|
||||
|
||||
function handleConfirmDelete() {
|
||||
if (selectedConversation) {
|
||||
const convId = selectedConversation.id;
|
||||
const withForks = deleteWithForks;
|
||||
showDeleteDialog = false;
|
||||
async function handleDeleteConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (!conversation) return;
|
||||
|
||||
setTimeout(() => {
|
||||
conversationsStore.deleteConversation(convId, {
|
||||
deleteWithForks: withForks
|
||||
});
|
||||
}, 100); // Wait for animation to finish
|
||||
}
|
||||
}
|
||||
const confirmed = window.confirm(
|
||||
`Delete "${conversation.name}"? This action cannot be undone.`
|
||||
);
|
||||
if (!confirmed) return;
|
||||
|
||||
function handleConfirmEdit() {
|
||||
if (!editedName.trim() || !selectedConversation) return;
|
||||
|
||||
showEditDialog = false;
|
||||
|
||||
conversationsStore.updateConversationName(selectedConversation.id, editedName);
|
||||
selectedConversation = null;
|
||||
}
|
||||
|
||||
export function handleMobileSidebarItemClick() {
|
||||
if (sidebar.isMobile) {
|
||||
sidebar.toggle();
|
||||
}
|
||||
}
|
||||
|
||||
let chatSidebarActions: { activateSearch?: () => void } | undefined = $state();
|
||||
let openedForSearch = $state(false);
|
||||
|
||||
export function activateSearchMode() {
|
||||
if (!sidebar.open) {
|
||||
openedForSearch = true;
|
||||
}
|
||||
chatSidebarActions?.activateSearch?.();
|
||||
}
|
||||
|
||||
function handleSearchDeactivated() {
|
||||
if (openedForSearch) {
|
||||
openedForSearch = false;
|
||||
sidebar.toggle();
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!sidebar.open) {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
openedForSearch = false;
|
||||
}
|
||||
});
|
||||
|
||||
export function editActiveConversation() {
|
||||
if (currentChatId) {
|
||||
const activeConversation = filteredConversations.find((conv) => conv.id === currentChatId);
|
||||
|
||||
if (activeConversation) {
|
||||
const event = new CustomEvent('edit-active-conversation', {
|
||||
detail: { conversationId: currentChatId }
|
||||
});
|
||||
document.dispatchEvent(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function selectConversation(id: string) {
|
||||
if (isSearchModeActive) {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
}
|
||||
|
||||
handleMobileSidebarItemClick();
|
||||
await goto(RouterService.chat(id));
|
||||
await conversationsStore.deleteConversation(id, { deleteWithForks: false });
|
||||
}
|
||||
|
||||
function handleStopGeneration(id: string) {
|
||||
chatStore.stopGenerationForChat(id);
|
||||
}
|
||||
|
||||
let innerWidth = $state(0);
|
||||
let pendingCollapse = $state<ReturnType<typeof setTimeout> | null>(null);
|
||||
|
||||
function scheduleMobileCollapse() {
|
||||
if (pendingCollapse) {
|
||||
clearTimeout(pendingCollapse);
|
||||
}
|
||||
pendingCollapse = setTimeout(() => {
|
||||
isExpandedMode = false;
|
||||
pendingCollapse = null;
|
||||
}, 100);
|
||||
}
|
||||
|
||||
function cancelMobileCollapse() {
|
||||
if (pendingCollapse) {
|
||||
clearTimeout(pendingCollapse);
|
||||
pendingCollapse = null;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex h-full flex-col">
|
||||
<ScrollArea class="h-full flex-1">
|
||||
<Sidebar.Header class="gap-4 bg-sidebar/50 p-3 backdrop-blur-lg md:pt-4 md:pb-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<a href={ROUTES.START} onclick={handleMobileSidebarItemClick}>
|
||||
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">
|
||||
{APP_NAME}
|
||||
</h1>
|
||||
</a>
|
||||
<svelte:window onkeydown={handleKeydown} bind:innerWidth />
|
||||
|
||||
<Button
|
||||
class="rounded-full md:hidden"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onclick={() => sidebar.toggle()}
|
||||
>
|
||||
<X class="h-4 w-4" />
|
||||
<span class="sr-only">Close sidebar</span>
|
||||
</Button>
|
||||
{#if innerWidth > 768 || (!page.url.hash.includes(ROUTES.SETTINGS) && !page.url.hash.includes(ROUTES.MCP_SERVERS) && !page.url.hash.includes(ROUTES.SEARCH))}
|
||||
<aside
|
||||
class={[
|
||||
// Layout & positioning
|
||||
'fixed md:sticky top-2 left-2 md:left-0 md:ml-2 md:mt-2 pt-2 z-10 w-[calc(100dvw-1rem)]',
|
||||
// Dimensions & overflow
|
||||
'md:h-[calc(100dvh-1.125rem)]',
|
||||
isExpandedMode &&
|
||||
(device.isStandalone
|
||||
? 'h-[calc(100dvh-2rem)]'
|
||||
: device.isIOSDevice
|
||||
? 'h-[calc(100dvh-0.5rem)]'
|
||||
: 'h-[calc(100dvh-1rem)]'),
|
||||
// Shape & depth
|
||||
'rounded-3xl md:rounded-2xl',
|
||||
// Flex layout
|
||||
'flex flex-col justify-between',
|
||||
// Transition
|
||||
'md:transition-[width,padding] duration-200 ease-out',
|
||||
// Expanded state: width, surface, depth
|
||||
isStripExpanded && 'md:w-72 md:bg-muted/60 md:backdrop-blur-xl border-border shadow-md',
|
||||
// Collapsed state
|
||||
!isStripExpanded && 'md:w-12',
|
||||
// Expanded mode flag (for mobile ::before overlay)
|
||||
isExpandedMode && 'is-expanded'
|
||||
]}
|
||||
>
|
||||
<div class="px-2 flex items-center justify-between">
|
||||
<div
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="relative"
|
||||
onmouseenter={() => (logoHovered = true)}
|
||||
onmouseleave={() => (logoHovered = false)}
|
||||
>
|
||||
<ActionIcon
|
||||
icon={!isExpandedMode && logoHovered && innerWidth > 768 ? PanelLeftOpen : Logo}
|
||||
size="lg"
|
||||
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
|
||||
class="{isExpandedMode
|
||||
? 'bg-muted! md:bg-foreground/5!'
|
||||
: 'bg-transparent!'} md:h-9 md:w-9 h-10 w-10 rounded-full md:hover:bg-foreground/10! pointer-events-auto"
|
||||
href={isExpandedMode ? ROUTES.START : undefined}
|
||||
onclick={isExpandedMode ? undefined : toggleExpandedMode}
|
||||
tooltip={isExpandedMode ? undefined : 'Open Sidebar'}
|
||||
tooltipSide={TooltipSide.RIGHT}
|
||||
ariaLabel={isExpandedMode ? 'Go to start' : 'Expand navigation'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<SidebarNavigationActions
|
||||
bind:this={chatSidebarActions}
|
||||
{handleMobileSidebarItemClick}
|
||||
bind:isSearchModeActive
|
||||
bind:searchQuery
|
||||
onSearchDeactivated={handleSearchDeactivated}
|
||||
/>
|
||||
</Sidebar.Header>
|
||||
|
||||
{#if !isSearchModeActive && pinnedConversations.length > 0}
|
||||
<Sidebar.Group class="p-0 px-4">
|
||||
<Sidebar.GroupLabel>
|
||||
<div class="flex items-center gap-1">
|
||||
<Pin class="h-3.5 w-3.5" />
|
||||
<span>Pinned</span>
|
||||
</div>
|
||||
</Sidebar.GroupLabel>
|
||||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each pinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
</Sidebar.Menu>
|
||||
</Sidebar.GroupContent>
|
||||
</Sidebar.Group>
|
||||
{/if}
|
||||
|
||||
<Sidebar.Group class="mt-2 h-[calc(100vh-21rem)] space-y-2 p-0 px-3">
|
||||
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
|
||||
<Sidebar.GroupLabel>
|
||||
{isSearchModeActive ? 'Search results' : 'Recent conversations'}
|
||||
</Sidebar.GroupLabel>
|
||||
{#if isExpandedMode || isOnMobile}
|
||||
<div
|
||||
class="flex items-center transition-all duration-150 ease-out {isMobile.current &&
|
||||
!isExpandedMode
|
||||
? 'opacity-0 h-0!'
|
||||
: ''}"
|
||||
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
|
||||
out:fade={{ duration: 100 }}
|
||||
>
|
||||
<ActionIcon
|
||||
icon={isMobile.current ? X : PanelLeftClose}
|
||||
size="lg"
|
||||
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
|
||||
class="backdrop-blur-none md:h-9 md:w-9 h-10 w-10 rounded-full mr-1 hover:bg-accent!"
|
||||
onclick={toggleExpandedMode}
|
||||
tooltip="Close Sidebar"
|
||||
tooltipSide={TooltipSide.LEFT}
|
||||
ariaLabel="Collapse navigation"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each isSearchModeActive ? conversationTree : unpinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
|
||||
{#if (isSearchModeActive ? conversationTree : unpinnedConversations).length === 0}
|
||||
<div class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{searchQuery.length > 0
|
||||
? 'No results found'
|
||||
: isSearchModeActive
|
||||
? 'Start typing to see results'
|
||||
: 'No conversations yet'}
|
||||
</p>
|
||||
</div>
|
||||
{/if}
|
||||
</Sidebar.Menu>
|
||||
</Sidebar.GroupContent>
|
||||
</Sidebar.Group>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description={selectedConversation
|
||||
? `Are you sure you want to delete "${selectedConversationNamePreview}"? This action cannot be undone and will permanently remove all messages in this conversation.`
|
||||
: ''}
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
variant="destructive"
|
||||
icon={Trash2}
|
||||
onConfirm={handleConfirmDelete}
|
||||
onCancel={() => {
|
||||
showDeleteDialog = false;
|
||||
selectedConversation = null;
|
||||
}}
|
||||
>
|
||||
{#if selectedConversationHasDescendants}
|
||||
<div class="flex items-center gap-2 py-2">
|
||||
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
|
||||
|
||||
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
|
||||
</div>
|
||||
{/if}
|
||||
</DialogConfirmation>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showEditDialog}
|
||||
title="Edit Conversation Name"
|
||||
description=""
|
||||
confirmText="Save"
|
||||
cancelText="Cancel"
|
||||
icon={Pencil}
|
||||
onConfirm={handleConfirmEdit}
|
||||
onCancel={() => {
|
||||
showEditDialog = false;
|
||||
selectedConversation = null;
|
||||
}}
|
||||
onKeydown={(event) => {
|
||||
if (event.key === 'Enter') {
|
||||
event.preventDefault();
|
||||
event.stopImmediatePropagation();
|
||||
handleConfirmEdit();
|
||||
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-1 overflow-y-auto">
|
||||
<div
|
||||
class="flex min-h-0 flex-1 flex-col gap-4 md:gap-1 {isMobile.current
|
||||
? 'transition-[opacity,height] duration-200 ease-out'
|
||||
: ''} {isMobile.current && !isExpandedMode ? 'opacity-0 !h-0' : ''}"
|
||||
in:fade={{ duration: 200 }}
|
||||
out:fade={{ duration: 200 }}
|
||||
>
|
||||
<SidebarNavigationActions
|
||||
isExpandedMode={innerWidth > 768 ? isExpandedMode : true}
|
||||
class="px-2"
|
||||
bind:isSearchModeActive
|
||||
bind:searchQuery
|
||||
onSearchDeactivated={() => {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
}}
|
||||
onSearchClick={() => {
|
||||
isExpandedMode = true;
|
||||
isSearchModeActive = true;
|
||||
}}
|
||||
onNewChat={() => {
|
||||
if (isMobile.current) {
|
||||
scheduleMobileCollapse();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
{#if isExpandedMode || isOnMobile}
|
||||
<SidebarNavigationConversationList
|
||||
class="px-2"
|
||||
{filteredConversations}
|
||||
{currentChatId}
|
||||
{isSearchModeActive}
|
||||
{searchQuery}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</aside>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
aside {
|
||||
@media (max-width: 768px) {
|
||||
--size: 1.125rem;
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Input
|
||||
class="text-foreground"
|
||||
placeholder="Enter a new name"
|
||||
type="text"
|
||||
bind:value={editedName}
|
||||
/>
|
||||
</DialogConfirmation>
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
aside {
|
||||
&:not(.is-expanded) {
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
|
||||
aside.is-expanded::before {
|
||||
content: '';
|
||||
position: fixed;
|
||||
top: -0.5rem;
|
||||
bottom: -0.25rem;
|
||||
left: -0.5rem;
|
||||
right: -0.5rem;
|
||||
z-index: -1;
|
||||
background: var(--background);
|
||||
backdrop-filter: blur(1rem);
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
+157
-57
@@ -1,39 +1,86 @@
|
||||
<script lang="ts">
|
||||
import { KeyboardShortcutInfo } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import type { Component } from 'svelte';
|
||||
import { SearchInput } from '$lib/components/app';
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { SIDEBAR_ACTIONS_ITEMS } from '$lib/constants/ui';
|
||||
import { Search } from '@lucide/svelte';
|
||||
import { ActionIcon, KeyboardShortcutInfo, SearchInput } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import {
|
||||
ICON_STRIP_TRANSITION_DURATION,
|
||||
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
|
||||
ROUTES,
|
||||
SIDEBAR_ACTIONS_ITEMS
|
||||
} from '$lib/constants';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { circIn } from 'svelte/easing';
|
||||
import { onMount } from 'svelte';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
handleMobileSidebarItemClick: () => void;
|
||||
class: string;
|
||||
isExpandedMode: boolean;
|
||||
isSearchModeActive: boolean;
|
||||
searchQuery: string;
|
||||
isCancelAlwaysVisible?: boolean;
|
||||
onSearchDeactivated?: () => void;
|
||||
onSearchClick?: () => void;
|
||||
onNewChat?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
handleMobileSidebarItemClick,
|
||||
isSearchModeActive = $bindable(),
|
||||
searchQuery = $bindable(),
|
||||
isCancelAlwaysVisible = false,
|
||||
onSearchDeactivated
|
||||
class: className,
|
||||
isExpandedMode = false,
|
||||
isSearchModeActive = $bindable(false),
|
||||
searchQuery = $bindable(''),
|
||||
onSearchDeactivated,
|
||||
onSearchClick,
|
||||
onNewChat
|
||||
}: Props = $props();
|
||||
|
||||
let initialized = $state(false);
|
||||
let showIcons = $state(false);
|
||||
let searchInputRef = $state<HTMLInputElement | null>(null);
|
||||
|
||||
const isOnMobile = $derived(isMobile.current);
|
||||
|
||||
$effect(() => {
|
||||
if (isSearchModeActive && searchInputRef) {
|
||||
searchInputRef.focus();
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
showIcons = true;
|
||||
|
||||
setTimeout(() => {
|
||||
initialized = true;
|
||||
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
|
||||
});
|
||||
|
||||
function handleSearchModeDeactivate() {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
onSearchDeactivated?.();
|
||||
}
|
||||
|
||||
export function activateSearch() {
|
||||
isSearchModeActive = true;
|
||||
// Focus after Svelte renders the input
|
||||
queueMicrotask(() => searchInputRef?.focus());
|
||||
function isItemActive(item: {
|
||||
activeRouteId?: string;
|
||||
activeRoutePrefix?: string;
|
||||
activeUrlIncludes?: string;
|
||||
}): boolean {
|
||||
if (item.activeRouteId) {
|
||||
return page.route.id === item.activeRouteId;
|
||||
}
|
||||
|
||||
if (item.activeRoutePrefix) {
|
||||
return !!page.route.id?.startsWith(item.activeRoutePrefix);
|
||||
}
|
||||
|
||||
if (item.activeUrlIncludes) {
|
||||
return page.url?.hash?.includes(item.activeUrlIncludes) ?? false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -41,56 +88,109 @@
|
||||
<IconComponent class="h-4 w-4" />
|
||||
{/snippet}
|
||||
|
||||
<div class="my-1 space-y-1">
|
||||
{#if isSearchModeActive}
|
||||
{#if isSearchModeActive}
|
||||
<div class="px-4 my-2">
|
||||
<SearchInput
|
||||
bind:value={searchQuery}
|
||||
bind:ref={searchInputRef}
|
||||
onClose={handleSearchModeDeactivate}
|
||||
onKeyDown={(e) => e.key === 'Escape' && handleSearchModeDeactivate()}
|
||||
placeholder="Search conversations..."
|
||||
{isCancelAlwaysVisible}
|
||||
/>
|
||||
{:else}
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item (item.route)}
|
||||
{#if !item.route}
|
||||
<Button
|
||||
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100"
|
||||
onclick={activateSearch}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
{@render itemIcon(item.icon)}
|
||||
</div>
|
||||
{:else if isExpandedMode || isOnMobile}
|
||||
<div
|
||||
class="{className} flex flex-col gap-5 md:gap-1 mt-2 md:mt-0 {!isExpandedMode && isOnMobile
|
||||
? 'hidden pointer-events-none'
|
||||
: ''}"
|
||||
>
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
|
||||
{@const isActive = isItemActive(item)}
|
||||
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
|
||||
{@const itemHref = isSearchOnMobile ? ROUTES.SEARCH : item.route}
|
||||
{@const itemOnClick = item.route
|
||||
? () => {
|
||||
onNewChat?.();
|
||||
goto(item.route!);
|
||||
}
|
||||
: isSearchOnMobile
|
||||
? undefined
|
||||
: onSearchClick}
|
||||
{@const itemTransition = {
|
||||
duration: ICON_STRIP_TRANSITION_DURATION,
|
||||
delay: !initialized
|
||||
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
|
||||
: 0,
|
||||
easing: circIn
|
||||
}}
|
||||
|
||||
{item.tooltip}
|
||||
</div>
|
||||
{#if showIcons}
|
||||
<div transition:fade={itemTransition}>
|
||||
<Button
|
||||
class="w-full min-w-9 justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {isActive
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
href={itemHref}
|
||||
onclick={itemOnClick}
|
||||
variant="ghost"
|
||||
size="default"
|
||||
>
|
||||
<span class="flex min-w-0 items-center px-0.5 gap-2">
|
||||
{@render itemIcon(item.icon)}
|
||||
|
||||
{#if item.keys}
|
||||
<KeyboardShortcutInfo keys={item.keys} />
|
||||
{/if}
|
||||
</Button>
|
||||
{:else}
|
||||
<Button
|
||||
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {(item.activeRouteId &&
|
||||
page.route.id === item.activeRouteId) ||
|
||||
(item.activeRoutePrefix && page.route.id?.startsWith(item.activeRoutePrefix))
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
href={item.route}
|
||||
onclick={handleMobileSidebarItemClick}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
{@render itemIcon(item.icon)}
|
||||
{#if showIcons}
|
||||
<span
|
||||
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
|
||||
out:fade={{ duration: 100 }}
|
||||
class="min-w-0 truncate">{item.tooltip}</span
|
||||
>
|
||||
{/if}
|
||||
</span>
|
||||
|
||||
{item.tooltip}
|
||||
</div>
|
||||
|
||||
{#if item.keys}
|
||||
<KeyboardShortcutInfo keys={item.keys} />
|
||||
{/if}
|
||||
</Button>
|
||||
{#if item.keys}
|
||||
<KeyboardShortcutInfo keys={item.keys} />
|
||||
{/if}
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="{className} flex-col gap-1 hidden md:flex">
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
|
||||
{@const isActive = isItemActive(item)}
|
||||
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
|
||||
{@const itemOnClick = item.route
|
||||
? () => {
|
||||
onNewChat?.();
|
||||
goto(item.route!);
|
||||
}
|
||||
: isSearchOnMobile
|
||||
? undefined
|
||||
: onSearchClick}
|
||||
{@const itemTransition = {
|
||||
duration: ICON_STRIP_TRANSITION_DURATION,
|
||||
delay: !initialized
|
||||
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
|
||||
: 0,
|
||||
easing: circIn
|
||||
}}
|
||||
|
||||
{#if showIcons}
|
||||
<div transition:fade={itemTransition}>
|
||||
<ActionIcon
|
||||
icon={item.icon}
|
||||
tooltip={item.tooltip}
|
||||
tooltipSide={TooltipSide.RIGHT}
|
||||
size="lg"
|
||||
iconSize="h-4 w-4"
|
||||
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
onclick={itemOnClick}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
+135
@@ -0,0 +1,135 @@
|
||||
<script lang="ts">
|
||||
import { Pin } from '@lucide/svelte';
|
||||
import { buildConversationTree } from '$lib/stores/conversations.svelte';
|
||||
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
|
||||
import SidebarNavigationSearchResults from './SidebarNavigationSearchResults.svelte';
|
||||
|
||||
interface Props {
|
||||
class: string;
|
||||
filteredConversations: DatabaseConversation[];
|
||||
currentChatId: string | undefined;
|
||||
isSearchModeActive: boolean;
|
||||
searchQuery: string;
|
||||
onSelect: (id: string) => void;
|
||||
onEdit: (id: string) => void;
|
||||
onDelete: (id: string) => void;
|
||||
onStop: (id: string) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className,
|
||||
filteredConversations,
|
||||
currentChatId,
|
||||
isSearchModeActive,
|
||||
searchQuery,
|
||||
onSelect,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onStop
|
||||
}: Props = $props();
|
||||
|
||||
let conversationTree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
let pinnedConversations = $derived(
|
||||
conversationTree.filter(({ conversation }) => conversation.pinned)
|
||||
);
|
||||
|
||||
let unpinnedConversations = $derived(
|
||||
conversationTree.filter(({ conversation }) => !conversation.pinned)
|
||||
);
|
||||
|
||||
const recentEmptyMessage = $derived(
|
||||
searchQuery.length > 0 ? 'No results found' : 'No conversations yet'
|
||||
);
|
||||
</script>
|
||||
|
||||
{#if isSearchModeActive}
|
||||
<SidebarNavigationSearchResults
|
||||
class={className}
|
||||
{searchQuery}
|
||||
{filteredConversations}
|
||||
{currentChatId}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
{:else}
|
||||
{#if pinnedConversations.length > 0}
|
||||
<div class="py-2 flex whitespace-nowrap {className}">
|
||||
<div
|
||||
class="text-muted-foreground inline-flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium gap-1"
|
||||
>
|
||||
<Pin class="h-3.5 w-3.5" />
|
||||
|
||||
<span>Pinned</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1 {className}">
|
||||
{#each pinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<li class="group/item relative mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
</li>
|
||||
{/each}
|
||||
</ul>
|
||||
{/if}
|
||||
|
||||
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-2 whitespace-nowrap {className}">
|
||||
{#if filteredConversations.length > 0}
|
||||
<div
|
||||
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
|
||||
>
|
||||
Recent conversations
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="min-h-0 flex-1 md:overflow-y-auto">
|
||||
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1">
|
||||
{#each unpinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<li class="group/item relative mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
</li>
|
||||
{/each}
|
||||
|
||||
{#if unpinnedConversations.length === 0}
|
||||
<li class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{recentEmptyMessage}
|
||||
</p>
|
||||
</li>
|
||||
{/if}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
+3
-1
@@ -16,4 +16,6 @@
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<SearchInput bind:value {placeholder} {onInput} class="mb-4 {className}" />
|
||||
<div class="mb-4 px-2 {className}">
|
||||
<SearchInput bind:value {placeholder} {onInput} />
|
||||
</div>
|
||||
|
||||
+76
@@ -0,0 +1,76 @@
|
||||
<script lang="ts">
|
||||
import { buildConversationTree } from '$lib/stores/conversations.svelte';
|
||||
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
searchQuery: string;
|
||||
filteredConversations: DatabaseConversation[];
|
||||
currentChatId: string | undefined;
|
||||
onSelect: (id: string) => void;
|
||||
onEdit: (id: string) => void;
|
||||
onDelete: (id: string) => void;
|
||||
onStop: (id: string) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
searchQuery,
|
||||
filteredConversations,
|
||||
currentChatId,
|
||||
onSelect,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onStop
|
||||
}: Props = $props();
|
||||
|
||||
let tree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
const hasQuery = $derived(searchQuery.trim().length > 0);
|
||||
const showHeader = $derived(hasQuery && filteredConversations.length > 0);
|
||||
|
||||
const emptyMessage = $derived(hasQuery ? 'No results found' : 'Start typing to see results');
|
||||
</script>
|
||||
|
||||
<div class="flex min-h-0 flex-1 flex-col gap-2 whitespace-nowrap {className}">
|
||||
{#if showHeader}
|
||||
<div
|
||||
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
|
||||
>
|
||||
Search results
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="min-h-0 flex-1 overflow-y-auto">
|
||||
<ul class="flex w-full min-w-0 flex-col gap-1">
|
||||
{#each tree as { conversation, depth } (conversation.id)}
|
||||
<li class="group/item relative mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
</li>
|
||||
{/each}
|
||||
|
||||
{#if tree.length === 0}
|
||||
<li class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{emptyMessage}
|
||||
</p>
|
||||
</li>
|
||||
{/if}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user