mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-29 17:17:40 +02:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7b95ea5d11 | |||
| bdc9c743a5 | |||
| 739393beeb | |||
| fc2b0053ff | |||
| 7b8443ac78 | |||
| 5d56effdee | |||
| 52e5f0a5c1 | |||
| f9f33654a6 | |||
| 98bb57916a | |||
| f42e29fdf1 | |||
| 19821178be | |||
| 698d19b93c | |||
| 50494a2800 | |||
| d530d6e7a2 | |||
| c3e08f4700 | |||
| 14e733e36f | |||
| 516e8d7a8a | |||
| 434b2a1ff6 | |||
| 983ca8992e | |||
| 665abc6097 | |||
| 4414c04b9a | |||
| ceaf47c4b1 | |||
| 42401c72b8 | |||
| e940b3d468 | |||
| 0f1bb602dd | |||
| d13540becd | |||
| f84270ea10 |
+377
-217
File diff suppressed because it is too large
Load Diff
+4
-2
@@ -25,7 +25,8 @@ struct common_arg {
|
||||
const char * value_hint_2 = nullptr; // for second arg value
|
||||
const char * env = nullptr;
|
||||
std::string help;
|
||||
bool is_sparam = false; // is current arg a sampling param?
|
||||
bool is_sampling = false; // is current arg a sampling param?
|
||||
bool is_spec = false; // is current arg a speculative decoding param?
|
||||
bool is_preset_only = false; // is current arg preset-only (not treated as CLI arg)
|
||||
void (*handler_void) (common_params & params) = nullptr;
|
||||
void (*handler_string) (common_params & params, const std::string &) = nullptr;
|
||||
@@ -74,7 +75,8 @@ struct common_arg {
|
||||
common_arg & set_examples(std::initializer_list<enum llama_example> examples);
|
||||
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
|
||||
common_arg & set_env(const char * env);
|
||||
common_arg & set_sparam();
|
||||
common_arg & set_sampling();
|
||||
common_arg & set_spec();
|
||||
common_arg & set_preset_only();
|
||||
bool in_example(enum llama_example ex);
|
||||
bool is_exclude(enum llama_example ex);
|
||||
|
||||
+7
-7
@@ -70,7 +70,7 @@ common_time_meas::~common_time_meas() {
|
||||
// CPU utils
|
||||
//
|
||||
|
||||
int32_t cpu_get_num_physical_cores() {
|
||||
int32_t common_cpu_get_num_physical_cores() {
|
||||
#ifdef __linux__
|
||||
// enumerate the set of thread siblings, num entries is num cores
|
||||
std::unordered_set<std::string> siblings;
|
||||
@@ -185,11 +185,11 @@ static int cpu_count_math_cpus(int n_cpu) {
|
||||
/**
|
||||
* Returns number of CPUs on system that are useful for math.
|
||||
*/
|
||||
int32_t cpu_get_num_math() {
|
||||
int32_t common_cpu_get_num_math() {
|
||||
#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
|
||||
int n_cpu = sysconf(_SC_NPROCESSORS_ONLN);
|
||||
if (n_cpu < 1) {
|
||||
return cpu_get_num_physical_cores();
|
||||
return common_cpu_get_num_physical_cores();
|
||||
}
|
||||
if (is_hybrid_cpu()) {
|
||||
cpu_set_t affinity;
|
||||
@@ -202,7 +202,7 @@ int32_t cpu_get_num_math() {
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return cpu_get_num_physical_cores();
|
||||
return common_cpu_get_num_physical_cores();
|
||||
}
|
||||
|
||||
// Helper for setting process priority
|
||||
@@ -263,7 +263,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
//
|
||||
|
||||
|
||||
void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) {
|
||||
void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_params * role_model) {
|
||||
int32_t n_set = 0;
|
||||
|
||||
if (cpuparams.n_threads < 0) {
|
||||
@@ -271,7 +271,7 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model)
|
||||
if (role_model != nullptr) {
|
||||
cpuparams = *role_model;
|
||||
} else {
|
||||
cpuparams.n_threads = cpu_get_num_math();
|
||||
cpuparams.n_threads = common_cpu_get_num_math();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1521,7 +1521,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
return cparams;
|
||||
}
|
||||
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) {
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const common_cpu_params & params) {
|
||||
struct ggml_threadpool_params tpp;
|
||||
|
||||
ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults
|
||||
|
||||
+56
-36
@@ -54,7 +54,7 @@ struct common_control_vector_load_info;
|
||||
// CPU utils
|
||||
//
|
||||
|
||||
struct cpu_params {
|
||||
struct common_cpu_params {
|
||||
int n_threads = -1;
|
||||
bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
|
||||
bool mask_valid = false; // Default: any CPU
|
||||
@@ -63,8 +63,8 @@ struct cpu_params {
|
||||
uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
|
||||
};
|
||||
|
||||
int32_t cpu_get_num_physical_cores();
|
||||
int32_t cpu_get_num_math();
|
||||
int32_t common_cpu_get_num_physical_cores();
|
||||
int32_t common_cpu_get_num_math();
|
||||
|
||||
//
|
||||
// Common params
|
||||
@@ -297,34 +297,19 @@ struct common_params_model {
|
||||
|
||||
struct common_ngram_mod;
|
||||
|
||||
struct common_params_speculative {
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
|
||||
// draft-model-based speculative decoding parameters
|
||||
struct common_params_speculative_draft {
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
|
||||
// general-purpose speculative decoding parameters
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
common_params_model mparams;
|
||||
|
||||
// ngram-based speculative decoding
|
||||
llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
|
||||
std::shared_ptr<common_ngram_mod> ngram_mod;
|
||||
|
||||
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
|
||||
// draft-model speculative decoding
|
||||
|
||||
struct common_params_model mparams_dft;
|
||||
|
||||
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
|
||||
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
|
||||
llama_context_params cparams; // these are the parameters for the draft llama_context
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
@@ -332,25 +317,60 @@ struct common_params_speculative {
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
common_cpu_params cpuparams;
|
||||
common_cpu_params cpuparams_batch;
|
||||
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
};
|
||||
|
||||
struct common_params_speculative_ngram_mod {
|
||||
int32_t n_match = 24;
|
||||
|
||||
int32_t n_max = 64;
|
||||
int32_t n_min = 48;
|
||||
|
||||
// shared instance of the ngram container for all speculative decoding contexts
|
||||
std::shared_ptr<common_ngram_mod> obj;
|
||||
};
|
||||
|
||||
struct common_params_speculative_ngram_map {
|
||||
uint16_t size_n = 12; // ngram size for lookup
|
||||
uint16_t size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
};
|
||||
|
||||
struct common_params_speculative_ngram_cache {
|
||||
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding
|
||||
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
|
||||
};
|
||||
|
||||
struct common_params_speculative {
|
||||
// TODO: become a vector in order to support "chains of speculators"
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
|
||||
common_params_speculative_draft draft;
|
||||
|
||||
common_params_speculative_ngram_mod ngram_mod;
|
||||
common_params_speculative_ngram_map ngram_simple;
|
||||
common_params_speculative_ngram_map ngram_map_k;
|
||||
common_params_speculative_ngram_map ngram_map_k4v;
|
||||
|
||||
common_params_speculative_ngram_cache ngram_cache;
|
||||
|
||||
bool has_dft() const {
|
||||
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
|
||||
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
struct common_params_model model;
|
||||
|
||||
std::string speaker_file = ""; // speaker file path // NOLINT
|
||||
std::string speaker_file; // speaker file path
|
||||
|
||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy
|
||||
};
|
||||
|
||||
struct common_params_diffusion {
|
||||
@@ -433,8 +453,8 @@ struct common_params {
|
||||
|
||||
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
||||
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
common_cpu_params cpuparams;
|
||||
common_cpu_params cpuparams_batch;
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
||||
void * cb_eval_user_data = nullptr;
|
||||
@@ -678,7 +698,7 @@ std::string common_params_get_system_info(const common_params & params);
|
||||
|
||||
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
||||
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
||||
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
|
||||
void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_params * role_model = nullptr);
|
||||
bool set_process_priority(enum ggml_sched_priority prio);
|
||||
|
||||
//
|
||||
@@ -846,7 +866,7 @@ common_init_result_ptr common_init_from_params(common_params & params);
|
||||
|
||||
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const common_cpu_params & params);
|
||||
|
||||
// clear LoRA adapters from context, then apply new list of adapters
|
||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
|
||||
|
||||
+1
-1
@@ -627,7 +627,7 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
if (!tag.empty()) {
|
||||
tags.push_back(tag);
|
||||
} else {
|
||||
tags = {"Q4_K_M", "Q4_0"};
|
||||
tags = {"Q4_K_M", "Q8_0"};
|
||||
}
|
||||
|
||||
for (const auto & t : tags) {
|
||||
|
||||
+2
-2
@@ -856,7 +856,7 @@ void common_memory_breakdown_print(const struct llama_context * ctx) {
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
|
||||
const size_t self = mb.model + mb.context + mb.compute;
|
||||
const size_t unaccounted = total - self - free;
|
||||
const int64_t unaccounted = static_cast<int64_t>(total) - static_cast<int64_t>(free) - static_cast<int64_t>(self);
|
||||
|
||||
table_data.push_back({
|
||||
template_gpu,
|
||||
@@ -867,7 +867,7 @@ void common_memory_breakdown_print(const struct llama_context * ctx) {
|
||||
std::to_string(mb.model / MiB),
|
||||
std::to_string(mb.context / MiB),
|
||||
std::to_string(mb.compute / MiB),
|
||||
std::to_string(unaccounted / MiB)});
|
||||
std::to_string(unaccounted / static_cast<int64_t>(MiB))});
|
||||
}
|
||||
|
||||
// print memory breakdown for host:
|
||||
|
||||
+11
-8
@@ -49,7 +49,7 @@ enum common_log_col : int {
|
||||
};
|
||||
|
||||
// disable colors by default
|
||||
static std::vector<const char *> g_col = {
|
||||
static const char* g_col[] = {
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
@@ -247,7 +247,6 @@ public:
|
||||
|
||||
entries = std::move(new_entries);
|
||||
}
|
||||
|
||||
cv.notify_one();
|
||||
}
|
||||
|
||||
@@ -265,7 +264,6 @@ public:
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
cv.wait(lock, [this]() { return head != tail; });
|
||||
|
||||
cur = entries[head];
|
||||
|
||||
head = (head + 1) % entries.size();
|
||||
@@ -301,7 +299,6 @@ public:
|
||||
|
||||
tail = (tail + 1) % entries.size();
|
||||
}
|
||||
|
||||
cv.notify_one();
|
||||
}
|
||||
|
||||
@@ -338,7 +335,7 @@ public:
|
||||
g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN;
|
||||
g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE;
|
||||
} else {
|
||||
for (size_t i = 0; i < g_col.size(); i++) {
|
||||
for (size_t i = 0; i < std::size(g_col); i++) {
|
||||
g_col[i] = "";
|
||||
}
|
||||
}
|
||||
@@ -368,14 +365,20 @@ struct common_log * common_log_init() {
|
||||
}
|
||||
|
||||
struct common_log * common_log_main() {
|
||||
static struct common_log log;
|
||||
// We intentionally leak (i.e. do not delete) the logger singleton because
|
||||
// common_log destructor called at DLL teardown phase will cause hanging on Windows.
|
||||
// OS will release resources anyway so it should not be a significant issue,
|
||||
// though this design may cause logs to be lost if not flushed before the program exits.
|
||||
// Refer to https://github.com/ggml-org/llama.cpp/issues/22142 for details.
|
||||
static struct common_log * log;
|
||||
static std::once_flag init_flag;
|
||||
std::call_once(init_flag, [&]() {
|
||||
log = new common_log;
|
||||
// Set default to auto-detect colors
|
||||
log.set_colors(tty_can_use_colors());
|
||||
log->set_colors(tty_can_use_colors());
|
||||
});
|
||||
|
||||
return &log;
|
||||
return log;
|
||||
}
|
||||
|
||||
void common_log_pause(struct common_log * log) {
|
||||
|
||||
+5
-1
@@ -49,7 +49,11 @@ void common_log_default_callback(enum ggml_log_level level, const char * text, v
|
||||
struct common_log;
|
||||
|
||||
struct common_log * common_log_init();
|
||||
struct common_log * common_log_main(); // singleton, automatically destroys itself on exit
|
||||
|
||||
// Singleton, intentionally leaked to avoid Windows teardown hangs.
|
||||
// Call common_log_flush() before exit if you want to ensure all logs are flushed.
|
||||
struct common_log * common_log_main();
|
||||
|
||||
void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe
|
||||
void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe
|
||||
void common_log_free (struct common_log * log);
|
||||
|
||||
+1
-1
@@ -43,7 +43,7 @@ static std::set<std::string> get_remote_preset_whitelist(const std::map<std::str
|
||||
for (const auto & it : key_to_opt) {
|
||||
const std::string & key = it.first;
|
||||
const common_arg & opt = it.second;
|
||||
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
|
||||
if (allowed_options.find(key) != allowed_options.end() || opt.is_sampling) {
|
||||
allowed_keys.insert(key);
|
||||
// also add variant keys (args without leading dashes and env vars)
|
||||
for (const auto & arg : opt.get_args()) {
|
||||
|
||||
@@ -122,6 +122,20 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
}
|
||||
break;
|
||||
case REASONING_BUDGET_DONE:
|
||||
// Re-arm on a new start tag: some models emit multiple <think> blocks
|
||||
// per response, and each should get a fresh budget window.
|
||||
if (ctx->start_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
+140
-46
@@ -151,6 +151,9 @@ struct common_speculative_state {
|
||||
llama_tokens & result) = 0;
|
||||
|
||||
virtual void accept(uint16_t n_accepted) = 0;
|
||||
|
||||
virtual int32_t n_max(const common_params_speculative & params) const = 0;
|
||||
virtual int32_t n_min(const common_params_speculative & params) const = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_checkpoint {
|
||||
@@ -296,6 +299,8 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
const auto & sparams = params.draft;
|
||||
|
||||
auto * spec = this;
|
||||
|
||||
auto & batch = spec->batch;
|
||||
@@ -309,7 +314,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
int reuse_i = 0; // index of part to be reused in prompt_dft
|
||||
int reuse_n = 0; // length of part to be reused in prompt_dft
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
|
||||
const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max;
|
||||
|
||||
llama_tokens prompt_cnv;
|
||||
if (!spec->vocab_cmpt) {
|
||||
@@ -367,7 +372,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
}
|
||||
|
||||
result.clear();
|
||||
result.reserve(params.n_max);
|
||||
result.reserve(sparams.n_max);
|
||||
|
||||
bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
|
||||
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
|
||||
@@ -380,7 +385,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
|
||||
result.push_back(prompt_dft[i]);
|
||||
|
||||
if (params.n_max <= (int) result.size()) {
|
||||
if (sparams.n_max <= (int) result.size()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -473,7 +478,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
common_sampler_reset(smpl);
|
||||
|
||||
// sample n_draft tokens from the draft model
|
||||
for (int i = 0; i < params.n_max; ++i) {
|
||||
for (int i = 0; i < sparams.n_max; ++i) {
|
||||
common_batch_clear(batch);
|
||||
|
||||
common_sampler_sample(smpl, ctx_dft, 0, true);
|
||||
@@ -492,12 +497,12 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
|
||||
result.push_back(id);
|
||||
|
||||
if (params.n_max <= (int) result.size()) {
|
||||
if (sparams.n_max <= (int) result.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < params.p_min) {
|
||||
if (cur_p->data[0].p < sparams.p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -518,10 +523,14 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
detokenized = replace_to_tgt(detokenized);
|
||||
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
|
||||
result = common_tokenize(ctx_tgt, detokenized, false, true);
|
||||
if (result.size() > (size_t)params.n_max) {
|
||||
result.resize(params.n_max);
|
||||
if (result.size() > (size_t) sparams.n_max) {
|
||||
result.resize(sparams.n_max);
|
||||
}
|
||||
}
|
||||
|
||||
if (result.size() < (size_t) sparams.n_min) {
|
||||
result.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
@@ -529,6 +538,14 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
|
||||
int32_t n_max(const common_params_speculative & params) const override {
|
||||
return params.draft.n_max;
|
||||
}
|
||||
|
||||
int32_t n_min(const common_params_speculative & params) const override {
|
||||
return params.draft.n_min;
|
||||
}
|
||||
|
||||
std::string replace_to_dft(const std::string & input) const {
|
||||
std::string result = input;
|
||||
|
||||
@@ -581,6 +598,14 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
|
||||
// noop
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
|
||||
int32_t n_max(const common_params_speculative & params) const override {
|
||||
return params.draft.n_max;
|
||||
}
|
||||
|
||||
int32_t n_min(const common_params_speculative & params) const override {
|
||||
return params.draft.n_min;
|
||||
}
|
||||
};
|
||||
|
||||
// state of self-speculation (simple implementation, not ngram-map)
|
||||
@@ -610,19 +635,27 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
|
||||
// noop
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
|
||||
int32_t n_max(const common_params_speculative & /*params*/) const override {
|
||||
return config.size_mgram;
|
||||
}
|
||||
|
||||
int32_t n_min(const common_params_speculative & /*params*/) const override {
|
||||
return config.size_mgram;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_ngram_map_k : public common_speculative_state {
|
||||
// draft ngram map for speculative decoding without draft model
|
||||
common_ngram_map map;
|
||||
common_ngram_map config;
|
||||
|
||||
common_speculative_state_ngram_map_k(
|
||||
enum common_speculative_type type,
|
||||
common_ngram_map map)
|
||||
: common_speculative_state(type), map(std::move(map)) {}
|
||||
common_ngram_map config)
|
||||
: common_speculative_state(type), config(std::move(config)) {}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
common_ngram_map_begin(map, prompt);
|
||||
common_ngram_map_begin(config, prompt);
|
||||
}
|
||||
|
||||
void draft(
|
||||
@@ -630,12 +663,20 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
common_ngram_map_draft(map, prompt_tgt, id_last, result);
|
||||
common_ngram_map_draft(config, prompt_tgt, id_last, result);
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
|
||||
void accept(uint16_t n_accepted) override {
|
||||
common_ngram_map_accept(map, n_accepted);
|
||||
common_ngram_map_accept(config, n_accepted);
|
||||
}
|
||||
|
||||
int32_t n_max(const common_params_speculative & /*params*/) const override {
|
||||
return config.size_value;
|
||||
}
|
||||
|
||||
int32_t n_min(const common_params_speculative & /*params*/) const override {
|
||||
return config.size_value;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -692,7 +733,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_tokens & result) override {
|
||||
GGML_UNUSED(params);
|
||||
const auto & sparams = params.ngram_mod;
|
||||
|
||||
n_draft_last = 0;
|
||||
|
||||
@@ -712,16 +753,16 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
i_last = cur_len - n;
|
||||
}
|
||||
|
||||
result.resize(n + params.n_max);
|
||||
result.resize(n + sparams.n_max);
|
||||
for (size_t i = 0; i < n - 1; ++i) {
|
||||
result[i] = prompt_tgt[cur_len - n + 1 + i];
|
||||
}
|
||||
result[n - 1] = id_last;
|
||||
|
||||
for (int i = 0; i < params.n_max; ++i) {
|
||||
for (int i = 0; i < sparams.n_max; ++i) {
|
||||
const llama_token token = mod.get(result.data() + i);
|
||||
if (token == common_ngram_mod::EMPTY) {
|
||||
if (i < params.n_min) {
|
||||
if (i < sparams.n_min) {
|
||||
result.clear();
|
||||
return;
|
||||
}
|
||||
@@ -764,6 +805,14 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int32_t n_max(const common_params_speculative & params) const override {
|
||||
return params.ngram_mod.n_max;
|
||||
}
|
||||
|
||||
int32_t n_min(const common_params_speculative & params) const override {
|
||||
return params.ngram_mod.n_min;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_ngram_cache : public common_speculative_state {
|
||||
@@ -857,6 +906,14 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
|
||||
// TODO: noop
|
||||
GGML_UNUSED(n_accepted);
|
||||
}
|
||||
|
||||
int32_t n_max(const common_params_speculative & /*params*/) const override {
|
||||
return n_draft;
|
||||
}
|
||||
|
||||
int32_t n_min(const common_params_speculative & /*params*/) const override {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative {
|
||||
@@ -865,11 +922,13 @@ struct common_speculative {
|
||||
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
||||
};
|
||||
|
||||
static common_ngram_map get_common_ngram_map(const common_speculative_config & config) {
|
||||
uint16_t size_key = config.params.ngram_size_n;
|
||||
uint16_t size_value = config.params.ngram_size_m;
|
||||
bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
|
||||
uint16_t min_hits = config.params.ngram_min_hits;
|
||||
static common_ngram_map get_common_ngram_map(
|
||||
common_speculative_type type,
|
||||
const common_params_speculative_ngram_map & config) {
|
||||
uint16_t size_key = config.size_n;
|
||||
uint16_t size_value = config.size_m;
|
||||
bool key_only = type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
|
||||
uint16_t min_hits = config.min_hits;
|
||||
|
||||
return common_ngram_map(size_key, size_value, key_only, min_hits);
|
||||
}
|
||||
@@ -927,8 +986,8 @@ common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt) {
|
||||
llama_context * ctx_dft = nullptr;
|
||||
if (params.model_dft) {
|
||||
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
|
||||
if (params.draft.model) {
|
||||
ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams);
|
||||
if (ctx_dft == nullptr) {
|
||||
LOG_ERR("%s", "failed to create draft context\n");
|
||||
return nullptr;
|
||||
@@ -938,7 +997,7 @@ common_speculative * common_speculative_init(
|
||||
// Compute the implementations to use based on the config and their order of preference
|
||||
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
|
||||
{
|
||||
bool has_draft = !params.mparams_dft.path.empty();
|
||||
bool has_draft = !params.draft.mparams.path.empty();
|
||||
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
|
||||
|
||||
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
|
||||
@@ -961,16 +1020,17 @@ common_speculative * common_speculative_init(
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
|
||||
}
|
||||
if (has_ngram_mod) {
|
||||
// shared instance for all speculative decoding contexts
|
||||
if (!params.ngram_mod) {
|
||||
params.ngram_mod = std::make_shared<common_ngram_mod>(params.ngram_size_n, 4*1024*1024);
|
||||
auto & sparams = params.ngram_mod;
|
||||
|
||||
LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__,
|
||||
params.ngram_size_n, params.ngram_mod->size(),
|
||||
(float)(params.ngram_mod->size_bytes())/1024/1024);
|
||||
if (!sparams.obj) {
|
||||
sparams.obj = std::make_shared<common_ngram_mod>(sparams.n_match, 4*1024*1024);
|
||||
|
||||
if (params.ngram_size_n < 16) {
|
||||
LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, params.ngram_size_n);
|
||||
LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__,
|
||||
sparams.n_match, sparams.obj->size(), (float)(sparams.obj->size_bytes())/1024/1024);
|
||||
|
||||
if (sparams.n_match < 16) {
|
||||
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1000,7 +1060,7 @@ common_speculative * common_speculative_init(
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .replacements = */ params.replacements,
|
||||
/* .replacements = */ params.draft.replacements,
|
||||
/* .use_ckpt = */ use_ckpt
|
||||
));
|
||||
break;
|
||||
@@ -1010,18 +1070,18 @@ common_speculative * common_speculative_init(
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
|
||||
common_ngram_map ngram_map = get_common_ngram_map(config);
|
||||
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
|
||||
|
||||
uint16_t ngram_size_key = ngram_map.size_key;
|
||||
uint16_t mgram_size_value = ngram_map.size_value;
|
||||
|
||||
auto config_simple = common_ngram_simple_config {
|
||||
/* .size_ngram = */ ngram_size_key,
|
||||
/* .size_mgram = */ mgram_size_value
|
||||
/* .size_ngram = */ ngram_size_key,
|
||||
/* .size_mgram = */ mgram_size_value
|
||||
};
|
||||
auto state = std::make_unique<common_speculative_state_ngram_simple>(
|
||||
/* .type = */ config.type,
|
||||
/* .state = */ config_simple
|
||||
/* .type = */ config.type,
|
||||
/* .state = */ config_simple
|
||||
);
|
||||
impls.push_back(std::move(state));
|
||||
break;
|
||||
@@ -1030,18 +1090,17 @@ common_speculative * common_speculative_init(
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_ngram_map_k>(
|
||||
(config.type),
|
||||
get_common_ngram_map(config)
|
||||
get_common_ngram_map(config.type, config.params.ngram_map_k)
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
|
||||
GGML_ASSERT(config.params.ngram_mod);
|
||||
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod));
|
||||
GGML_ASSERT(config.params.ngram_mod.obj);
|
||||
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod.obj));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
|
||||
auto state = create_state_ngram_cache(
|
||||
params.lookup_cache_static, params.lookup_cache_dynamic, config);
|
||||
auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config);
|
||||
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
|
||||
break;
|
||||
}
|
||||
@@ -1099,6 +1158,15 @@ llama_tokens common_speculative_draft(
|
||||
impl->n_call_draft++;
|
||||
}
|
||||
|
||||
{
|
||||
const int n_min = impl->n_min(params);
|
||||
|
||||
if (!result.empty() && (int) result.size() < n_min) {
|
||||
LOG_DBG("%s: ignoring small draft: %d < %d\n", __func__, (int) result.size(), n_min);
|
||||
result.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (!result.empty()) {
|
||||
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
|
||||
common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
|
||||
@@ -1108,7 +1176,7 @@ llama_tokens common_speculative_draft(
|
||||
impl->n_gen_drafts++;
|
||||
impl->n_gen_tokens += result.size();
|
||||
|
||||
break; // We have a draft, so break out of the loop and return it.
|
||||
break; // we have a draft, so break out of the loop and return it.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1136,6 +1204,32 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
|
||||
}
|
||||
}
|
||||
|
||||
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) {
|
||||
if (spec == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t n_max = 0;
|
||||
for (const auto & impl : spec->impls) {
|
||||
n_max = std::max(n_max, impl->n_max(params));
|
||||
}
|
||||
|
||||
return n_max;
|
||||
}
|
||||
|
||||
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params) {
|
||||
if (spec == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t n_min = 0;
|
||||
for (const auto & impl : spec->impls) {
|
||||
n_min = std::max(n_min, impl->n_min(params));
|
||||
}
|
||||
|
||||
return n_min;
|
||||
}
|
||||
|
||||
void common_speculative_print_stats(const common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
|
||||
@@ -33,6 +33,9 @@ llama_tokens common_speculative_draft(
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
|
||||
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
|
||||
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
|
||||
+49
-32
@@ -272,6 +272,22 @@ class ModelBase:
|
||||
|
||||
return tensors
|
||||
|
||||
@staticmethod
|
||||
def _scale_is_trivial(scale: Tensor) -> bool:
|
||||
return scale.numel() <= 1 and abs(float(scale.float().sum()) - 1.0) < 1e-6
|
||||
|
||||
def _write_scale_tensor(self, scale_name: str, scale: Tensor):
|
||||
if not self._scale_is_trivial(scale):
|
||||
scale_f32 = scale.float().numpy().flatten()
|
||||
logger.info(f" + {scale_name} (per-tensor scale, shape [{scale_f32.size}])")
|
||||
self.gguf_writer.add_tensor(scale_name, scale_f32)
|
||||
|
||||
def _write_scales_tensor(self, scale_name: str, scales: list[float]):
|
||||
if not np.allclose(scales, 1.0, atol=1e-6):
|
||||
scale_vals = np.array(scales, dtype=np.float32)
|
||||
logger.info(f" + {scale_name} (per-expert scale, shape [{len(scales)}])")
|
||||
self.gguf_writer.add_tensor(scale_name, scale_vals)
|
||||
|
||||
def dequant_model(self):
|
||||
# If all quantized tensors were already handled (e.g. pure NVFP4), skip
|
||||
if self._is_nvfp4 and not any(k.endswith((".weight_scale", ".weight_scale_inv")) for k in self.model_tensors):
|
||||
@@ -494,7 +510,7 @@ class ModelBase:
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), None)
|
||||
tensors_to_remove.append(name)
|
||||
if name.endswith((".k_scale", ".v_scale")):
|
||||
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method is not None:
|
||||
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
|
||||
@@ -602,10 +618,6 @@ class ModelBase:
|
||||
raw = np.concatenate([d_grouped, qs_grouped], axis=-1).reshape(out_features, n_super * 36)
|
||||
return raw, [out_features, n_super * 64]
|
||||
|
||||
@staticmethod
|
||||
def _nvfp4_scale2_is_trivial(scale2: Tensor) -> bool:
|
||||
return scale2.numel() <= 1 and abs(float(scale2.float().sum()) - 1.0) < 1e-6
|
||||
|
||||
def _repack_nvfp4(self, name: str, weight: Tensor, scale: Tensor, scale2: Tensor, input_scale: Tensor):
|
||||
if "language_model." in name:
|
||||
name = name.replace("language_model.", "")
|
||||
@@ -616,19 +628,8 @@ class ModelBase:
|
||||
logger.info(f"Repacked {new_name} with shape {shape} and quantization NVFP4")
|
||||
self.gguf_writer.add_tensor(new_name, raw, raw_dtype=gguf.GGMLQuantizationType.NVFP4)
|
||||
|
||||
# Emit per-tensor scale2 as a separate F32 tensor when non-trivial
|
||||
if not self._nvfp4_scale2_is_trivial(scale2):
|
||||
scale2_f32 = scale2.float().numpy().flatten()
|
||||
scale_name = new_name.replace(".weight", ".scale")
|
||||
logger.info(f" + {scale_name} (per-tensor NVFP4 scale2, shape [{scale2_f32.size}])")
|
||||
self.gguf_writer.add_tensor(scale_name, scale2_f32)
|
||||
|
||||
# Emit per-tensor input_scale as a separate F32 tensor when non-trivial
|
||||
if not self._nvfp4_scale2_is_trivial(input_scale):
|
||||
input_scale_f32 = input_scale.float().numpy().flatten()
|
||||
input_scale_name = new_name.replace(".weight", ".input_scale")
|
||||
logger.info(f" + {input_scale_name} (per-tensor NVFP4 input_scale, shape [{input_scale_f32.size}])")
|
||||
self.gguf_writer.add_tensor(input_scale_name, input_scale_f32)
|
||||
self._write_scale_tensor(new_name.replace(".weight", ".scale"), scale2)
|
||||
self._write_scale_tensor(new_name.replace(".weight", ".input_scale"), input_scale)
|
||||
|
||||
def _generate_nvfp4_tensors(self):
|
||||
# Per-layer expert merging to avoid holding all experts in memory
|
||||
@@ -719,24 +720,17 @@ class ModelBase:
|
||||
logger.info(f"Repacked {new_name} with shape [{len(experts)}, {shape[0]}, {shape[1]}] and quantization NVFP4")
|
||||
self.gguf_writer.add_tensor(new_name, merged, raw_dtype=gguf.GGMLQuantizationType.NVFP4)
|
||||
|
||||
# Emit per-expert scale2 tensor if any expert has non-trivial scale2
|
||||
scales.sort(key=lambda x: x[0])
|
||||
scale_vals = np.array([s[1] for s in scales], dtype=np.float32)
|
||||
if not np.allclose(scale_vals, 1.0, atol=1e-6):
|
||||
scale_name = new_name.replace(".weight", ".scale")
|
||||
logger.info(f" + {scale_name} (per-expert NVFP4 scale2, shape [{len(scales)}])")
|
||||
self.gguf_writer.add_tensor(scale_name, scale_vals)
|
||||
self._write_scales_tensor(new_name.replace(".weight", ".scale"), [s[1] for s in scales])
|
||||
|
||||
# Emit per-expert input_scale tensor if any expert has non-trivial input_scale
|
||||
input_scales.sort(key=lambda x: x[0])
|
||||
input_scale_vals = np.array([s[1] for s in input_scales], dtype=np.float32)
|
||||
if not np.allclose(input_scale_vals, 1.0, atol=1e-6):
|
||||
input_scale_name = new_name.replace(".weight", ".input_scale")
|
||||
logger.info(f" + {input_scale_name} (per-expert NVFP4 input_scale, shape [{len(input_scales)}])")
|
||||
self.gguf_writer.add_tensor(input_scale_name, input_scale_vals)
|
||||
self._write_scales_tensor(new_name.replace(".weight", ".input_scale"), [s[1] for s in input_scales])
|
||||
|
||||
del experts, merged
|
||||
|
||||
def _needs_nvfp4_processing(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_tensors(self):
|
||||
# detect NVFP4 quantization (ModelOpt format)
|
||||
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
|
||||
@@ -767,7 +761,7 @@ class ModelBase:
|
||||
# NVFP4 weights are repacked and written directly to gguf_writer.
|
||||
# This must run before dequant_model so NVFP4 tensors are removed
|
||||
# from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
|
||||
if self._is_nvfp4:
|
||||
if self._is_nvfp4 and self._needs_nvfp4_processing():
|
||||
self._generate_nvfp4_tensors()
|
||||
|
||||
self.dequant_model()
|
||||
@@ -2199,6 +2193,10 @@ class MmprojModel(ModelBase):
|
||||
# merge configs
|
||||
self.preprocessor_config = {**self.preprocessor_config, **cfg}
|
||||
|
||||
def _needs_nvfp4_processing(self) -> bool:
|
||||
# nvfp4 quantization applies to the text model only.
|
||||
return False
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
|
||||
return self.global_config.get(config_name)
|
||||
@@ -4459,6 +4457,12 @@ class NemotronNanoV2VLModel(MmprojModel):
|
||||
}
|
||||
return vision_config
|
||||
|
||||
def dequant_model(self):
|
||||
if self._is_nvfp4:
|
||||
# Skip nvfp4 quantization for vision/audio model.
|
||||
return
|
||||
super().dequant_model()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if "image_mean" not in self.preprocessor_config:
|
||||
self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406]
|
||||
@@ -4482,6 +4486,10 @@ class NemotronNanoV2VLModel(MmprojModel):
|
||||
if "input_conditioner" in name:
|
||||
return
|
||||
|
||||
# mtmd does not support video yet so skip tensors related to video.
|
||||
if "radio_model.model.patch_generator.video_embedder" in name:
|
||||
return
|
||||
|
||||
# RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it
|
||||
if "patch_generator.pos_embed" in name:
|
||||
if not name.endswith(".weight"):
|
||||
@@ -10829,7 +10837,11 @@ class NemotronHModel(GraniteHybridModel):
|
||||
# uses self.model_arch to build the tensor name map, and all MoE-specific
|
||||
# mappings would be missed if it were called with the default non-MoE arch.
|
||||
hparams = ModelBase.load_hparams(args[0], self.is_mistral_format)
|
||||
if "num_experts_per_tok" in hparams:
|
||||
has_moe_params = (
|
||||
"num_experts_per_tok" in hparams
|
||||
or (isinstance(hparams.get("llm_config"), dict) and "num_experts_per_tok" in hparams["llm_config"])
|
||||
)
|
||||
if has_moe_params:
|
||||
self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
|
||||
self.is_moe = True
|
||||
|
||||
@@ -10976,6 +10988,11 @@ class NemotronHModel(GraniteHybridModel):
|
||||
if name.startswith(("vision_model.", "mlp1.")):
|
||||
return
|
||||
|
||||
if name.startswith(("sound_encoder.")):
|
||||
return
|
||||
if name.startswith(("sound_projection.")):
|
||||
return
|
||||
|
||||
# Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL)
|
||||
if name.startswith("language_model."):
|
||||
name = name[len("language_model."):]
|
||||
|
||||
@@ -73,12 +73,12 @@ static void write_help(std::ostringstream & ss, const md_file & md) {
|
||||
auto ctx_arg = common_params_parser_init(params, md.ex);
|
||||
|
||||
std::vector<common_arg *> common_options;
|
||||
std::vector<common_arg *> sparam_options;
|
||||
std::vector<common_arg *> sampling_options;
|
||||
std::vector<common_arg *> specific_options;
|
||||
for (auto & opt : ctx_arg.options) {
|
||||
// in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example
|
||||
if (opt.is_sparam) {
|
||||
sparam_options.push_back(&opt);
|
||||
if (opt.is_sampling) {
|
||||
sampling_options.push_back(&opt);
|
||||
} else if (opt.in_example(ctx_arg.ex)) {
|
||||
specific_options.push_back(&opt);
|
||||
} else {
|
||||
@@ -93,7 +93,7 @@ static void write_help(std::ostringstream & ss, const md_file & md) {
|
||||
ss << "### Common params\n\n";
|
||||
write_table(ss, common_options);
|
||||
ss << "\n\n### Sampling params\n\n";
|
||||
write_table(ss, sparam_options);
|
||||
write_table(ss, sampling_options);
|
||||
ss << "\n\n### " << md.specific_section_header << "\n\n";
|
||||
write_table(ss, specific_options);
|
||||
|
||||
|
||||
@@ -37,9 +37,9 @@ int main(int argc, char ** argv){
|
||||
|
||||
common_ngram_cache ngram_cache;
|
||||
common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str());
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.ngram_cache.lookup_cache_static.c_str());
|
||||
|
||||
common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static);
|
||||
common_ngram_cache_save(ngram_cache, params.speculative.ngram_cache.lookup_cache_static);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ int main(int argc, char ** argv){
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int n_draft = params.speculative.n_max;
|
||||
const int n_draft = params.speculative.draft.n_max;
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
@@ -49,18 +49,18 @@ int main(int argc, char ** argv){
|
||||
{
|
||||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.ngram_cache.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.ngram_cache.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.ngram_cache.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.ngram_cache.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.ngram_cache.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ int main(int argc, char ** argv){
|
||||
}
|
||||
|
||||
// max. number of additional tokens to draft if match is found
|
||||
const int n_draft = params.speculative.n_max;
|
||||
const int n_draft = params.speculative.draft.n_max;
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
@@ -54,18 +54,18 @@ int main(int argc, char ** argv){
|
||||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
|
||||
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.ngram_cache.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.ngram_cache.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.ngram_cache.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.ngram_cache.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.ngram_cache.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ int main(int argc, char ** argv){
|
||||
|
||||
// Update dynamic ngram cache with context ngram cache and save it to disk:
|
||||
common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.ngram_cache.lookup_cache_dynamic);
|
||||
|
||||
LOG("\n\n");
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
if (params.speculative.draft.mparams.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// TODO: simplify this logic
|
||||
{
|
||||
const auto & params_spec = params.speculative;
|
||||
const auto & params_spec = params.speculative.draft;
|
||||
|
||||
auto params_dft = params;
|
||||
|
||||
@@ -85,15 +85,15 @@ int main(int argc, char ** argv) {
|
||||
params_dft.n_ctx = params_spec.n_ctx;
|
||||
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
|
||||
params_dft.devices = params_spec.devices;
|
||||
params_dft.model = params_spec.mparams_dft;
|
||||
params_dft.model = params_spec.mparams;
|
||||
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
|
||||
|
||||
if (params_spec.cpuparams.n_threads > 0) {
|
||||
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
params_dft.cpuparams.n_threads = params.speculative.draft.cpuparams.n_threads;
|
||||
params_dft.cpuparams_batch.n_threads = params.speculative.draft.cpuparams_batch.n_threads;
|
||||
}
|
||||
|
||||
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
params_dft.tensor_buft_overrides = params.speculative.draft.tensor_buft_overrides;
|
||||
|
||||
auto mparams_dft = common_model_params_to_llama(params_dft);
|
||||
|
||||
@@ -103,8 +103,8 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.speculative.model_dft = model_dft.get();
|
||||
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
|
||||
params.speculative.draft.model = model_dft.get();
|
||||
params.speculative.draft.cparams = common_context_params_to_llama(params_dft);
|
||||
}
|
||||
|
||||
// Tokenize the prompt
|
||||
@@ -187,16 +187,6 @@ int main(int argc, char ** argv) {
|
||||
// generate a new draft
|
||||
draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
|
||||
if ((int) draft.size() > params_spec.n_max) {
|
||||
LOG_WRN("draft size %zu exceeds max %d, truncating\n", draft.size(), params_spec.n_max);
|
||||
draft.resize(params_spec.n_max);
|
||||
}
|
||||
|
||||
if ((int) draft.size() < params_spec.n_min) {
|
||||
LOG_DBG("ignoring small draft: %zu < %d\n", draft.size(), params_spec.n_min);
|
||||
draft.clear();
|
||||
}
|
||||
|
||||
// save the original draft size
|
||||
n_draft = draft.size();
|
||||
|
||||
@@ -220,19 +210,12 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_draft > 0);
|
||||
|
||||
// always have a token to evaluate from before - id_last
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
|
||||
|
||||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||
{
|
||||
// do not waste time on small drafts
|
||||
if (draft.size() < (size_t) params_spec.n_min) {
|
||||
draft.clear();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < draft.size(); ++i) {
|
||||
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
|
||||
}
|
||||
@@ -340,7 +323,7 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
||||
|
||||
LOG_INF("\n");
|
||||
LOG_INF("n_draft = %d\n", params_spec.n_max);
|
||||
LOG_INF("n_draft = %d\n", params_spec.draft.n_max);
|
||||
LOG_INF("n_predict = %d\n", n_predict);
|
||||
LOG_INF("n_drafted = %d\n", n_drafted);
|
||||
LOG_INF("n_accept = %d\n", n_accept);
|
||||
|
||||
@@ -49,7 +49,7 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
if (params.speculative.draft.mparams.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -58,7 +58,7 @@ int main(int argc, char ** argv) {
|
||||
const int n_seq_dft = params.n_parallel;
|
||||
|
||||
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
||||
const float p_draft_split = params.speculative.p_split;
|
||||
const float p_draft_split = params.speculative.draft.p_split;
|
||||
|
||||
std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed);
|
||||
std::uniform_real_distribution<> u_dist;
|
||||
@@ -80,15 +80,15 @@ int main(int argc, char ** argv) {
|
||||
ctx_tgt = llama_init_tgt->context();
|
||||
|
||||
// load the draft model
|
||||
params.devices = params.speculative.devices;
|
||||
params.model = params.speculative.mparams_dft;
|
||||
params.n_gpu_layers = params.speculative.n_gpu_layers;
|
||||
if (params.speculative.cpuparams.n_threads > 0) {
|
||||
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
params.devices = params.speculative.draft.devices;
|
||||
params.model = params.speculative.draft.mparams;
|
||||
params.n_gpu_layers = params.speculative.draft.n_gpu_layers;
|
||||
if (params.speculative.draft.cpuparams.n_threads > 0) {
|
||||
params.cpuparams.n_threads = params.speculative.draft.cpuparams.n_threads;
|
||||
}
|
||||
|
||||
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
params.cpuparams_batch.n_threads = params.speculative.draft.cpuparams_batch.n_threads;
|
||||
params.tensor_buft_overrides = params.speculative.draft.tensor_buft_overrides;
|
||||
|
||||
auto llama_init_dft = common_init_from_params(params);
|
||||
|
||||
@@ -183,7 +183,7 @@ int main(int argc, char ** argv) {
|
||||
//GGML_ASSERT(n_vocab == llama_vocab_n_tokens(model_dft));
|
||||
|
||||
// how many tokens to draft each time
|
||||
int n_draft = params.speculative.n_max;
|
||||
int n_draft = params.speculative.draft.n_max;
|
||||
|
||||
int n_predict = 0;
|
||||
int n_drafted = 0;
|
||||
|
||||
@@ -470,11 +470,10 @@ endforeach()
|
||||
|
||||
target_link_libraries(ggml-base PRIVATE Threads::Threads)
|
||||
|
||||
find_library(MATH_LIBRARY m)
|
||||
if (MATH_LIBRARY)
|
||||
if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT})
|
||||
target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY})
|
||||
endif()
|
||||
if (DEFINED MATH_LIBRARY)
|
||||
target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY})
|
||||
elseif (NOT WIN32 AND NOT DEFINED ENV{ONEAPI_ROOT})
|
||||
target_link_libraries(ggml-base PRIVATE m)
|
||||
endif()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Android")
|
||||
|
||||
@@ -1826,7 +1826,24 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
|
||||
continue;
|
||||
}
|
||||
|
||||
i = get_i_delayed(i);
|
||||
const int i_delayed = get_i_delayed(i);
|
||||
|
||||
// If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices.
|
||||
// A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has
|
||||
// its compute flag disabled and thus gets its data zeroed out.
|
||||
// If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled.
|
||||
if (i_delayed > i) {
|
||||
for (size_t j = 0; j < n_backends; j++) {
|
||||
auto & bcj = backend_ctx->backend_configs[j];
|
||||
if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
for (int ii = i + 1; ii <= i_delayed; ii++) {
|
||||
bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i = i_delayed;
|
||||
|
||||
for (size_t j = 0; j < n_backends; j++) {
|
||||
auto & bcj = backend_ctx->backend_configs[j];
|
||||
|
||||
@@ -181,6 +181,12 @@ struct ggml_backend_registry {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto & entry : backends) {
|
||||
if (entry.reg == reg) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
|
||||
__func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
|
||||
@@ -192,6 +198,12 @@ struct ggml_backend_registry {
|
||||
}
|
||||
|
||||
void register_device(ggml_backend_dev_t device) {
|
||||
for (auto & dev : devices) {
|
||||
if (dev == device) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
|
||||
#endif
|
||||
|
||||
+512
-248
@@ -25,6 +25,7 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml.h"
|
||||
|
||||
|
||||
#include <aclnnop/aclnn_add.h>
|
||||
#include <aclnnop/aclnn_add_rms_norm.h>
|
||||
#include <aclnnop/aclnn_addcdiv.h>
|
||||
@@ -45,7 +46,9 @@
|
||||
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
|
||||
#include <aclnnop/aclnn_ger.h>
|
||||
#include <aclnnop/aclnn_group_norm.h>
|
||||
#include <aclnnop/aclnn_gather_v2.h>
|
||||
#include <aclnnop/aclnn_grouped_matmul_v3.h>
|
||||
#include <aclnnop/aclnn_scatter.h>
|
||||
#include <aclnnop/aclnn_gt_scalar.h>
|
||||
#include <aclnnop/aclnn_im2col.h>
|
||||
#include <aclnnop/aclnn_index_copy.h>
|
||||
@@ -62,6 +65,7 @@
|
||||
#include <aclnnop/aclnn_permute.h>
|
||||
#include <aclnnop/aclnn_pow.h>
|
||||
#include <aclnnop/aclnn_pow_tensor_tensor.h>
|
||||
#include <aclnnop/aclnn_recurrent_gated_delta_rule.h>
|
||||
#include <aclnnop/aclnn_reduce_sum.h>
|
||||
#include <aclnnop/aclnn_reflection_pad1d.h>
|
||||
#include <aclnnop/aclnn_repeat.h>
|
||||
@@ -69,11 +73,15 @@
|
||||
#include <aclnnop/aclnn_rms_norm.h>
|
||||
#include <aclnnop/aclnn_roll.h>
|
||||
#include <aclnnop/aclnn_softmax.h>
|
||||
#include <aclnnop/aclnn_softmax_cross_entropy_with_logits.h>
|
||||
#include <aclnnop/aclnn_sub.h>
|
||||
#include <aclnnop/aclnn_sum.h>
|
||||
#include <aclnnop/aclnn_threshold.h>
|
||||
#include <aclnnop/aclnn_tril.h>
|
||||
#include <aclnnop/aclnn_triangular_solve.h>
|
||||
#include <aclnnop/aclnn_triu.h>
|
||||
#include <aclnnop/aclnn_logical_not.h>
|
||||
#include <aclnnop/aclnn_masked_fill_scalar.h>
|
||||
#include <aclnnop/aclnn_upsample_nearest_2d.h>
|
||||
#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
|
||||
#include <aclnnop/aclnn_zero.h>
|
||||
@@ -151,6 +159,107 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst.get(), acl_src1.get());
|
||||
}
|
||||
|
||||
// Fused SwiGLU using aclnnSwiGlu: splits input along innermost dim, applies
|
||||
// SiLU to left half, multiplies by right half.
|
||||
//
|
||||
// Falls back to the generic two-kernel path when src[1] != nullptr (two
|
||||
// independent halves) or swapped != 0 (reversed activation order), as
|
||||
// aclnnSwiGlu only handles the single interleaved tensor in standard order.
|
||||
//
|
||||
// CANN tiling for SwiGlu requires (storageShapeDim + viewDims) to be even.
|
||||
// aclCreateTensor always uses storageShapeDim=1, so viewDims must be odd.
|
||||
// We use a 3D view (1+3=4, even) to satisfy this constraint while preserving
|
||||
// correct split semantics along the innermost (ne[0]) dimension.
|
||||
void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
auto silu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Silu, acl_src, acl_dst);
|
||||
};
|
||||
|
||||
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||
if (dst->src[1] != nullptr || swapped != 0) {
|
||||
ggml_cann_op_unary_gated(silu_fn, ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
// aclnnSwiGlu requires the split dim (src->ne[0]) to be even; fall back otherwise.
|
||||
if (dst->src[0]->ne[0] % 2 != 0) {
|
||||
ggml_cann_op_unary_gated(silu_fn, ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
size_t elem_size = ggml_element_size(src0);
|
||||
|
||||
// src0 GGML: [2*ne0, ne1, ne2, ne3] → 3D view [2*ne0, ne1, ne2*ne3]
|
||||
// CANN reversed: [ne2*ne3, ne1, 2*ne0], split along CANN dim 2 (last).
|
||||
int64_t ne0_x2 = src0->ne[0];
|
||||
int64_t ne1 = src0->ne[1];
|
||||
int64_t ne23 = src0->ne[2] * src0->ne[3];
|
||||
int64_t src3d_ne[] = { ne0_x2, ne1, ne23 };
|
||||
size_t src3d_nb[] = { (size_t)src0->nb[0], (size_t)src0->nb[1], (size_t)src0->nb[2] };
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
|
||||
elem_size, src3d_ne, src3d_nb, 3);
|
||||
|
||||
// dst GGML: [ne0, ne1, ne2, ne3] → 3D view [ne0, ne1, ne2*ne3]
|
||||
int64_t ne0 = dst->ne[0];
|
||||
int64_t dst3d_ne[] = { ne0, ne1, ne23 };
|
||||
size_t dst3d_nb[] = { (size_t)dst->nb[0], (size_t)dst->nb[1], (size_t)dst->nb[2] };
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
|
||||
elem_size, dst3d_ne, dst3d_nb, 3);
|
||||
|
||||
// CANN tensor [ne23, ne1, 2*ne0]: split along CANN dim 2 (last) = 2*ne0.
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, SwiGlu, acl_src.get(), (int64_t)2, acl_dst.get());
|
||||
}
|
||||
|
||||
// Fused GeGLU using aclnnGeGluV3: splits input along ne[0] (CANN last dim),
|
||||
// activates the LEFT half with GELU, multiplies by right half.
|
||||
// approximate: 0=tanh, 1=none(erf). activateLeft=true matches GGML convention.
|
||||
// outGelu is a required-but-discard output buffer.
|
||||
//
|
||||
// Falls back to the generic two-kernel path when src[1] != nullptr (two
|
||||
// independent halves) or swapped != 0 (reversed activation order), as
|
||||
// aclnnGeGluV3 only handles the single interleaved tensor in standard order.
|
||||
void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate) {
|
||||
auto gelu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Gelu, acl_src, acl_dst);
|
||||
};
|
||||
|
||||
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||
if (dst->src[1] != nullptr || swapped != 0) {
|
||||
ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
// aclnnGeGluV3 requires the split dim (src->ne[0]) to be even; fall back otherwise.
|
||||
if (dst->src[0]->ne[0] % 2 != 0) {
|
||||
ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
// Allocate a temporary buffer for the required outGelu output (same shape as dst).
|
||||
// Build contiguous strides since the pool allocation is a fresh buffer.
|
||||
size_t elem_size = ggml_element_size(dst);
|
||||
int64_t ne[GGML_MAX_DIMS] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] };
|
||||
size_t nb[GGML_MAX_DIMS];
|
||||
nb[0] = elem_size;
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
nb[i] = nb[i - 1] * ne[i - 1];
|
||||
}
|
||||
size_t gelu_out_size = nb[GGML_MAX_DIMS - 1] * ne[GGML_MAX_DIMS - 1];
|
||||
ggml_cann_pool_alloc gelu_out_alloc(ctx.pool(), gelu_out_size);
|
||||
|
||||
acl_tensor_ptr acl_gelu_out = ggml_cann_create_tensor(
|
||||
gelu_out_alloc.get(), ggml_cann_type_mapping(dst->type), elem_size, ne, nb, GGML_MAX_DIMS);
|
||||
// V3 adds activateLeft param; true → Gelu(left)*right, matching GGML convention.
|
||||
// GGML dim 0 → CANN last dim (index GGML_MAX_DIMS-1 = 3 for 4D tensor).
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, GeGluV3, acl_src.get(), (int64_t)(GGML_MAX_DIMS - 1), approximate, true,
|
||||
acl_dst.get(), acl_gelu_out.get());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Repeats elements of a tensor along each dimension according to the
|
||||
* specified repeat array.
|
||||
@@ -445,28 +554,33 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes);
|
||||
void * buffer = temp_buffer_allocator.get();
|
||||
|
||||
int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
|
||||
size_t div_nb[GGML_MAX_DIMS];
|
||||
div_nb[0] = sizeof(float);
|
||||
int64_t norm_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
|
||||
size_t norm_nb[GGML_MAX_DIMS];
|
||||
norm_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
div_nb[i] = div_nb[i - 1] * div_ne[i - 1];
|
||||
norm_nb[i] = norm_nb[i - 1] * norm_ne[i - 1];
|
||||
}
|
||||
acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS);
|
||||
acl_tensor_ptr acl_norm = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
|
||||
|
||||
std::vector<int64_t> norm_dims = { 3 };
|
||||
acl_int_array_ptr dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size());
|
||||
|
||||
float p_value = 2.0f;
|
||||
acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get());
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_norm.get());
|
||||
|
||||
// Clamp norm to at least eps: scale = 1/fmaxf(norm, eps)
|
||||
acl_scalar_ptr acl_min = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT);
|
||||
float flt_max = FLT_MAX;
|
||||
acl_scalar_ptr acl_max = ggml_cann_create_scalar(&flt_max, aclDataType::ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_div.get(), acl_min.get(), acl_max.get(), acl_div.get());
|
||||
ggml_cann_pool_alloc clamp_buffer_allocator(ctx.pool());
|
||||
acl_tensor_ptr acl_clamped;
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get());
|
||||
if (eps > 0.0f) {
|
||||
void * clamp_buf = clamp_buffer_allocator.alloc(n_bytes);
|
||||
acl_clamped = ggml_cann_create_tensor(clamp_buf, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
|
||||
acl_scalar_ptr eps_scalar = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ClampMin, acl_norm.get(), eps_scalar.get(), acl_clamped.get());
|
||||
}
|
||||
|
||||
aclTensor * acl_div_input = acl_clamped ? acl_clamped.get() : acl_norm.get();
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div_input, acl_dst.get());
|
||||
}
|
||||
|
||||
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
@@ -482,56 +596,30 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
|
||||
logits_nb[1] = logits_nb[0] * logits_ne[0];
|
||||
acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
|
||||
|
||||
size_t log_softmax_type_size = sizeof(float);
|
||||
int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size;
|
||||
ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes);
|
||||
void * log_softmax_buffer = log_softmax_allocator.get();
|
||||
|
||||
int64_t log_softmax_ne[] = { nc, nr };
|
||||
size_t log_softmax_nb[2];
|
||||
log_softmax_nb[0] = log_softmax_type_size;
|
||||
log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0];
|
||||
acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size,
|
||||
log_softmax_ne, log_softmax_nb, 2);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get());
|
||||
|
||||
int64_t labels_ne[] = { nc, nr };
|
||||
size_t labels_nb[2];
|
||||
labels_nb[0] = ggml_type_size(src1->type);
|
||||
labels_nb[1] = labels_nb[0] * labels_ne[0];
|
||||
acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2);
|
||||
|
||||
size_t mul_type_size = sizeof(float);
|
||||
int64_t mul_n_bytes = nr * nc * mul_type_size;
|
||||
ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes);
|
||||
void * mul_buffer = mul_allocator.get();
|
||||
size_t loss_per_sample_type_size = sizeof(float);
|
||||
int64_t loss_per_sample_n_bytes = nr * loss_per_sample_type_size;
|
||||
ggml_cann_pool_alloc loss_per_sample_allocator(ctx.pool(), loss_per_sample_n_bytes);
|
||||
void * loss_per_sample_buffer = loss_per_sample_allocator.get();
|
||||
|
||||
int64_t mul_ne[] = { nc, nr };
|
||||
size_t mul_nb[2];
|
||||
mul_nb[0] = mul_type_size;
|
||||
mul_nb[1] = mul_nb[0] * mul_ne[0];
|
||||
acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2);
|
||||
int64_t loss_per_sample_ne[] = { nr };
|
||||
size_t loss_per_sample_nb[1];
|
||||
loss_per_sample_nb[0] = loss_per_sample_type_size;
|
||||
acl_tensor_ptr acl_loss_per_sample = ggml_cann_create_tensor(
|
||||
loss_per_sample_buffer, ACL_FLOAT, loss_per_sample_type_size, loss_per_sample_ne, loss_per_sample_nb, 1);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get());
|
||||
size_t backprop_n_bytes = nr * nc * sizeof(float);
|
||||
ggml_cann_pool_alloc backprop_allocator(ctx.pool(), backprop_n_bytes);
|
||||
void * backprop_buffer = backprop_allocator.get();
|
||||
acl_tensor_ptr acl_backprop = ggml_cann_create_tensor(backprop_buffer, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
|
||||
|
||||
size_t sum_per_sample_type_size = sizeof(float);
|
||||
int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size;
|
||||
ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes);
|
||||
void * sum_per_sample_buffer = sum_per_sample_allocator.get();
|
||||
|
||||
int64_t sum_per_sample_ne[] = { nr };
|
||||
size_t sum_per_sample_nb[1];
|
||||
sum_per_sample_nb[0] = sum_per_sample_type_size;
|
||||
acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor(
|
||||
sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1);
|
||||
|
||||
std::vector<int64_t> sum_dims = { 1 };
|
||||
acl_int_array_ptr dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size());
|
||||
bool keep_dims = false;
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT,
|
||||
acl_sum_per_sample.get());
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, SoftmaxCrossEntropyWithLogits, acl_logits.get(), acl_labels.get(),
|
||||
acl_loss_per_sample.get(), acl_backprop.get());
|
||||
|
||||
size_t total_sum_type_size = sizeof(float);
|
||||
int64_t total_sum_n_bytes = 1 * total_sum_type_size;
|
||||
@@ -547,11 +635,12 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
|
||||
|
||||
std::vector<int64_t> total_sum_dims = { 0 };
|
||||
acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size());
|
||||
bool keep_dims = false;
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_loss_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
|
||||
acl_total_sum.get());
|
||||
|
||||
float value = -1.0f / static_cast<float>(nr);
|
||||
float value = 1.0f / static_cast<float>(nr);
|
||||
acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);
|
||||
acl_tensor_ptr acl_dst =
|
||||
ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1);
|
||||
@@ -589,6 +678,33 @@ void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
acl_mean_out.get(), acl_rstd_out.get());
|
||||
}
|
||||
|
||||
void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
size_t nb1 = ((int32_t *) dst->op_params)[0];
|
||||
size_t nb2 = ((int32_t *) dst->op_params)[1];
|
||||
size_t nb3 = ((int32_t *) dst->op_params)[2];
|
||||
size_t offset = ((int32_t *) dst->op_params)[3];
|
||||
bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
||||
|
||||
size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 };
|
||||
|
||||
// Create a view of dst at the target offset with src1's dimensions
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
|
||||
acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1);
|
||||
|
||||
if (!inplace) {
|
||||
// First copy src0 to dst entirely
|
||||
size_t cpy_size = ggml_nbytes(dst);
|
||||
ACL_CHECK(
|
||||
aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
|
||||
}
|
||||
|
||||
// Copy src1 into the target region of dst
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src1.get());
|
||||
}
|
||||
|
||||
void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
@@ -652,6 +768,113 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
aclnn_reduce_sum(ctx, dst, reduce_dims, 4);
|
||||
}
|
||||
|
||||
void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
// GGML cumsum operates along dim 0 (innermost / ne[0]).
|
||||
// ggml_cann_create_tensor reverses dimensions to [ne3,ne2,ne1,ne0],
|
||||
// so GGML dim 0 maps to CANN dim 3 (the last dim of the 4-D tensor).
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, acl_src.get(), (int64_t)3,
|
||||
ggml_cann_type_mapping(dst->type), acl_dst.get());
|
||||
}
|
||||
|
||||
void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular
|
||||
ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3]
|
||||
|
||||
acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0);
|
||||
acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1);
|
||||
acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst);
|
||||
|
||||
// mOut: triangular copy of A (required output), same shape as A.
|
||||
const size_t a_bytes = ggml_nbytes(src0);
|
||||
ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes);
|
||||
acl_tensor_ptr acl_m = ggml_cann_create_tensor(
|
||||
m_alloc.get(), ggml_cann_type_mapping(src0->type),
|
||||
ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
|
||||
|
||||
// Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false.
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve,
|
||||
acl_b.get(), acl_a.get(), false, false, false,
|
||||
acl_x.get(), acl_m.get());
|
||||
}
|
||||
|
||||
void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src->ne[1] == 1);
|
||||
|
||||
const int64_t N = src->ne[0];
|
||||
const int64_t n_batch = src->ne[2] * src->ne[3];
|
||||
const size_t nb_f32 = sizeof(float);
|
||||
|
||||
// Fill dst with zeros.
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
{
|
||||
float zero = 0.0f;
|
||||
acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get());
|
||||
}
|
||||
|
||||
// Copy src vector onto the diagonal of dst via strided views.
|
||||
// src viewed as [N, n_batch], contiguous strides.
|
||||
int64_t ne_vec[2] = { N, n_batch };
|
||||
size_t nb_src_vec[2] = { nb_f32, N * nb_f32 };
|
||||
// dst diagonal view: stride (N+1)*4 steps along the diagonal.
|
||||
size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 };
|
||||
|
||||
acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2);
|
||||
acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get());
|
||||
}
|
||||
|
||||
void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
float c = ggml_get_op_params_f32(dst, 0);
|
||||
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get());
|
||||
}
|
||||
|
||||
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
const int64_t S = src->ne[0];
|
||||
const int64_t n_batch = src->ne[2] * src->ne[3];
|
||||
const size_t nb_f32 = sizeof(float);
|
||||
|
||||
int64_t ne3d[3] = { S, S, n_batch };
|
||||
size_t nb3d[3] = { nb_f32, S * nb_f32, S * S * nb_f32 };
|
||||
|
||||
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
|
||||
|
||||
switch (ttype) {
|
||||
case GGML_TRI_TYPE_LOWER:
|
||||
// Tril(-1): preserve row > col (strict lower), zero upper + diagonal.
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)-1, acl_dst.get());
|
||||
break;
|
||||
case GGML_TRI_TYPE_UPPER_DIAG:
|
||||
// Triu(0): preserve row <= col (upper + diagonal), zero strict lower.
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)0, acl_dst.get());
|
||||
break;
|
||||
case GGML_TRI_TYPE_UPPER:
|
||||
// Triu(1): preserve row < col (strict upper), zero lower + diagonal.
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)1, acl_dst.get());
|
||||
break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG:
|
||||
// Tril(0): preserve row >= col (lower + diagonal), zero strict upper.
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)0, acl_dst.get());
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("unsupported tri type");
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
|
||||
@@ -1695,152 +1918,90 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs index select operation on a 4D tensor using the CANN backend.
|
||||
*
|
||||
* This function applies the `IndexSelect` operation along a specific dimension
|
||||
* of the source tensor (`src_buffer`) using the indices from the index tensor (`index`).
|
||||
* It iterates over the last two dimensions of the source tensor, creates the corresponding
|
||||
* CANN tensors for the source, index, and output slices, and executes the `IndexSelect`
|
||||
* operation for each slice.
|
||||
*
|
||||
* @param ctx The context for CANN backend operations.
|
||||
* @param src_buffer The source buffer containing the 4D input tensor data.
|
||||
* @param src_ne The dimensions of the source tensor.
|
||||
* @param src_nb The strides (byte offsets) of the source tensor.
|
||||
* @param dst_buffer The destination buffer where the output tensor data will be written.
|
||||
* @param dst_ne The dimensions of the destination tensor.
|
||||
* @param dst_nb The strides (byte offsets) of the destination tensor.
|
||||
* @param index The index tensor specifying the indices to select from the source tensor.
|
||||
* @param type The data type of the source and destination tensors.
|
||||
*/
|
||||
static void aclnn_index_select_4d(ggml_backend_cann_context & ctx,
|
||||
void * src_buffer,
|
||||
int64_t * src_ne,
|
||||
size_t * src_nb,
|
||||
void * dst_buffer,
|
||||
int64_t * dst_ne,
|
||||
size_t * dst_nb,
|
||||
ggml_tensor * index,
|
||||
ggml_type type) {
|
||||
for (int64_t i = 0; i < src_ne[3]; i++) {
|
||||
for (int64_t j = 0; j < src_ne[2]; j++) {
|
||||
// src
|
||||
acl_tensor_ptr acl_src_tensor =
|
||||
ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
|
||||
ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
|
||||
|
||||
// index
|
||||
acl_tensor_ptr acl_index = ggml_cann_create_tensor(
|
||||
(char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
|
||||
ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
|
||||
|
||||
// out
|
||||
acl_tensor_ptr acl_out =
|
||||
ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
|
||||
ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs inplace index copy operation on a 4D tensor using the CANN backend.
|
||||
*
|
||||
* This function applies the `IndexCopy` operation along a specific dimension of the
|
||||
* destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`)
|
||||
* to positions specified by the index tensor (`index`).
|
||||
* It iterates over the last two dimensions of the tensors, creates the corresponding
|
||||
* CANN tensors for source, index, and destination slices, and performs the index copy
|
||||
* operation for each slice.
|
||||
*
|
||||
* @param ctx The context for CANN backend operations.
|
||||
* @param src_buffer The source buffer containing the 4D input tensor data to be copied.
|
||||
* @param src_ne The dimensions of the source tensor.
|
||||
* @param src_nb The strides (byte offsets) of the source tensor.
|
||||
* @param dst_buffer The destination buffer where values will be copied to.
|
||||
* @param dst_ne The dimensions of the destination tensor.
|
||||
* @param dst_nb The strides (byte offsets) of the destination tensor.
|
||||
* @param index The index tensor specifying target positions in the destination tensor.
|
||||
* @param type The data type of the source and destination tensors.
|
||||
*/
|
||||
static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx,
|
||||
void * src_buffer,
|
||||
int64_t * src_ne,
|
||||
size_t * src_nb,
|
||||
void * dst_buffer,
|
||||
int64_t * dst_ne,
|
||||
size_t * dst_nb,
|
||||
ggml_tensor * index,
|
||||
ggml_type type) {
|
||||
for (int64_t i = 0; i < src_ne[3]; i++) {
|
||||
for (int64_t j = 0; j < src_ne[2]; j++) {
|
||||
// src
|
||||
acl_tensor_ptr acl_src_tensor =
|
||||
ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
|
||||
ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
|
||||
|
||||
// index
|
||||
acl_tensor_ptr acl_index = ggml_cann_create_tensor(
|
||||
(char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
|
||||
ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
|
||||
|
||||
// out
|
||||
acl_tensor_ptr acl_out =
|
||||
ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
|
||||
ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0]; // src
|
||||
ggml_tensor * src0 = dst->src[0]; // weight
|
||||
ggml_tensor * src1 = dst->src[1]; // index
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
|
||||
|| dst->type == GGML_TYPE_BF16);
|
||||
|
||||
// n_idx: number of row indices per (i2, i3) batch slice.
|
||||
// ggml guarantees: src0->ne[2] == src1->ne[1], src0->ne[3] == src1->ne[2], src1->ne[3] == 1.
|
||||
const int64_t n_idx = src1->ne[0];
|
||||
|
||||
// Gather all (i2, i3) batch slices from src into dst.
|
||||
// ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0].
|
||||
// GatherV2 with dim=0 gathers along ACL dim-0 == ggml ne[1] (the vocabulary / row axis).
|
||||
// nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
|
||||
// nb[2..3] for computing per-batch-slice base pointer offsets).
|
||||
auto gather_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
|
||||
const size_t * nb) {
|
||||
int64_t src_ne[2] = { src0->ne[0], src0->ne[1] };
|
||||
size_t src_nb_2d[2] = { nb[0], nb[1] };
|
||||
int64_t dst_ne[2] = { src0->ne[0], n_idx };
|
||||
size_t dst_nb_2d[2] = { dst->nb[0], dst->nb[1] };
|
||||
int64_t idx_ne[1] = { n_idx };
|
||||
size_t idx_nb[1] = { (size_t)ggml_element_size(src1) };
|
||||
|
||||
for (int64_t i3 = 0; i3 < src0->ne[3]; i3++) {
|
||||
for (int64_t i2 = 0; i2 < src0->ne[2]; i2++) {
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(
|
||||
(char *)src_base + i3 * nb[3] + i2 * nb[2],
|
||||
acl_type, type_size, src_ne, src_nb_2d, 2);
|
||||
acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
|
||||
(char *)src1->data + i3 * src1->nb[2] + i2 * src1->nb[1],
|
||||
ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
|
||||
idx_ne, idx_nb, 1);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
|
||||
(char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
|
||||
acl_type, type_size, dst_ne, dst_nb_2d, 2);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, GatherV2, acl_src.get(), 0, acl_idx.get(), acl_dst.get());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
if (src0->type == dst->type) {
|
||||
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1,
|
||||
dst->type);
|
||||
gather_batched(src0->data,
|
||||
ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),
|
||||
src0->nb);
|
||||
} else {
|
||||
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
|
||||
void * src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = dst->nb[0];
|
||||
// Cast src0 to dst type, then gather.
|
||||
ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
|
||||
ggml_nelements(src0) * ggml_element_size(dst));
|
||||
size_t src_cast_nb[GGML_MAX_DIMS];
|
||||
src_cast_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
acl_tensor_ptr src_trans_tensor =
|
||||
ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type),
|
||||
ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
|
||||
dst->type);
|
||||
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
|
||||
acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
|
||||
src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
src0->ne, src_cast_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
|
||||
|
||||
gather_batched(src_cast_allocator.get(),
|
||||
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
src_cast_nb);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
// add 1 dim for bcast mul.
|
||||
// Dequantize Q8_0 to dst type, then gather.
|
||||
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1];
|
||||
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne;
|
||||
int64_t scale_offset = 0;
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
weight_ne[0] = QK8_0;
|
||||
weight_ne[1] = src0->ne[0] / QK8_0;
|
||||
weight_nb[0] = sizeof(int8_t);
|
||||
weight_nb[1] = weight_nb[0] * weight_ne[0];
|
||||
weight_ne[0] = QK8_0;
|
||||
weight_ne[1] = src0->ne[0] / QK8_0;
|
||||
weight_nb[0] = sizeof(int8_t);
|
||||
weight_nb[1] = weight_nb[0] * weight_ne[0];
|
||||
for (int i = 2; i < GGML_MAX_DIMS + 1; i++) {
|
||||
weight_ne[i] = src0->ne[i - 1];
|
||||
weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
|
||||
}
|
||||
// [3,4,5,64] -> [3,4,5,2,1]
|
||||
scale_ne[0] = 1;
|
||||
scale_ne[1] = src0->ne[0] / QK8_0;
|
||||
scale_nb[0] = sizeof(uint16_t);
|
||||
@@ -1849,31 +2010,33 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
scale_ne[i] = src0->ne[i - 1];
|
||||
scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
|
||||
}
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
dequant_ne = weight_ne;
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
|
||||
}
|
||||
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
|
||||
ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(),
|
||||
ggml_nelements(src0) * ggml_type_size(dst->type));
|
||||
acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
|
||||
weight_ne, weight_nb, GGML_MAX_DIMS + 1);
|
||||
acl_tensor_ptr acl_scale_tensor =
|
||||
ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
|
||||
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
|
||||
acl_tensor_ptr dequant_tensor =
|
||||
ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type),
|
||||
ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
|
||||
aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get());
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
const int64_t scale_offset = ggml_nelements(src0) * sizeof(int8_t);
|
||||
ggml_cann_pool_alloc dequant_allocator(ctx.pool(),
|
||||
ggml_nelements(src0) * ggml_type_size(dst->type));
|
||||
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
|
||||
weight_ne, weight_nb, GGML_MAX_DIMS + 1);
|
||||
acl_tensor_ptr acl_scale = ggml_cann_create_tensor(
|
||||
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
|
||||
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
|
||||
acl_tensor_ptr acl_dequant = ggml_cann_create_tensor(
|
||||
dequant_allocator.get(), ggml_cann_type_mapping(dst->type),
|
||||
ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
|
||||
aclnn_mul(ctx, acl_weight.get(), acl_scale.get(), acl_dequant.get());
|
||||
|
||||
// Reinterpret dequant buffer as 4D [src0->ne] with contiguous strides.
|
||||
dequant_ne = src0->ne;
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne,
|
||||
dst->nb, src1, dst->type);
|
||||
gather_batched(dequant_allocator.get(),
|
||||
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
dequant_nb);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -1883,31 +2046,70 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
}
|
||||
|
||||
void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0]; // src
|
||||
ggml_tensor * src1 = dst->src[1]; // index
|
||||
ggml_tensor * src0 = dst->src[0]; // source values
|
||||
ggml_tensor * src1 = dst->src[1]; // row indices
|
||||
|
||||
// n_idx: number of source rows to scatter per batch slice.
|
||||
// ggml guarantees: src0->ne[1] == src1->ne[0].
|
||||
const int64_t n_idx = src1->ne[0];
|
||||
|
||||
// Copy n_idx rows of src [ne0, n_idx] into dst [ne0, ne1] at positions given by a 1D index.
|
||||
// ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0] for dst.
|
||||
// InplaceIndexCopy with dim=0 copies along ACL dim-0 == ggml ne[1] (the row axis).
|
||||
// src_nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
|
||||
// nb[2..3] for computing per-batch-slice base pointer offsets).
|
||||
auto scatter_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
|
||||
const size_t * src_nb) {
|
||||
int64_t d_ne[2] = { dst->ne[0], dst->ne[1] };
|
||||
size_t d_nb[2] = { dst->nb[0], dst->nb[1] };
|
||||
int64_t s_ne[2] = { dst->ne[0], n_idx };
|
||||
size_t s_nb_2d[2] = { src_nb[0], src_nb[1] };
|
||||
int64_t i_ne[1] = { n_idx };
|
||||
size_t i_nb[1] = { (size_t)ggml_element_size(src1) };
|
||||
|
||||
for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) {
|
||||
for (int64_t i2 = 0; i2 < dst->ne[2]; i2++) {
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
|
||||
(char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
|
||||
acl_type, type_size, d_ne, d_nb, 2);
|
||||
acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
|
||||
(char *)src1->data + (i3 % src1->ne[2]) * src1->nb[2] + (i2 % src1->ne[1]) * src1->nb[1],
|
||||
ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
|
||||
i_ne, i_nb, 1);
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(
|
||||
(char *)src_base + i3 * src_nb[3] + i2 * src_nb[2],
|
||||
acl_type, type_size, s_ne, s_nb_2d, 2);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_dst.get(), 0, acl_idx.get(), acl_src.get());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type);
|
||||
break;
|
||||
}
|
||||
scatter_batched(src0->data,
|
||||
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
src0->nb);
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
|
||||
void * src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = sizeof(uint16_t);
|
||||
// Cast src0 (F32) to dst type first.
|
||||
ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
|
||||
ggml_nelements(src0) * ggml_type_size(dst->type));
|
||||
size_t src_cast_nb[GGML_MAX_DIMS];
|
||||
src_cast_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
|
||||
dst->type);
|
||||
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
|
||||
acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
|
||||
src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
src0->ne, src_cast_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
|
||||
|
||||
scatter_batched(src_cast_allocator.get(),
|
||||
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
src_cast_nb);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -3268,29 +3470,50 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst
|
||||
int64_t paddingsArray[2] = { opts[0], opts[1] };
|
||||
acl_int_array_ptr paddings = ggml_cann_create_int_array(paddingsArray, 2);
|
||||
|
||||
for (int64_t i = 0; i < src0->ne[3]; i++) {
|
||||
acl_tensor_ptr acl_src =
|
||||
ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type),
|
||||
ggml_element_size(src0), src0->ne, src0->nb, 3);
|
||||
// Collapsing ne[2]*ne[3] into a single batch dimension requires that dim3
|
||||
// is contiguous with respect to dim2 in both src and dst.
|
||||
GGML_ASSERT(src0->nb[3] == src0->nb[2] * src0->ne[2]);
|
||||
GGML_ASSERT(dst->nb[3] == dst->nb[2] * dst->ne[2]);
|
||||
|
||||
acl_tensor_ptr acl_dst =
|
||||
ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type),
|
||||
ggml_element_size(dst), dst->ne, dst->nb, 3);
|
||||
int64_t src_ne_3d[3] = { src0->ne[0], src0->ne[1], src0->ne[2] * src0->ne[3] };
|
||||
int64_t dst_ne_3d[3] = { dst->ne[0], dst->ne[1], dst->ne[2] * dst->ne[3] };
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
|
||||
}
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
|
||||
ggml_element_size(src0), src_ne_3d, src0->nb, 3);
|
||||
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
|
||||
ggml_element_size(dst), dst_ne_3d, dst->nb, 3);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
|
||||
}
|
||||
|
||||
void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
// Write element-wise equality (0 or 1) into a temporary buffer to avoid
|
||||
// modifying src0 in-place. Use the same type as src0 so ReduceSum can
|
||||
// consume it directly without a type cast.
|
||||
ggml_cann_pool_alloc eq_alloc(ctx.pool(), ggml_nelements(src0) * ggml_element_size(src0));
|
||||
size_t eq_nb[GGML_MAX_DIMS];
|
||||
eq_nb[0] = ggml_element_size(src0);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
eq_nb[i] = eq_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
acl_tensor_ptr acl_eq = ggml_cann_create_tensor(
|
||||
eq_alloc.get(), ggml_cann_type_mapping(src0->type), ggml_element_size(src0),
|
||||
src0->ne, eq_nb, GGML_MAX_DIMS);
|
||||
|
||||
acl_tensor_ptr acl_self = ggml_cann_create_tensor(src0);
|
||||
acl_tensor_ptr acl_other = ggml_cann_create_tensor(src1);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, EqTensor, acl_self.get(), acl_other.get(), acl_eq.get());
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get());
|
||||
|
||||
ggml_cann_sum(ctx, dst);
|
||||
// Sum the 0/1 values into dst.
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
int64_t dims[4] = { 0, 1, 2, 3 };
|
||||
acl_int_array_ptr dims_arr = ggml_cann_create_int_array(dims, 4);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_eq.get(), dims_arr.get(), true,
|
||||
ggml_cann_type_mapping(dst->type), acl_dst.get());
|
||||
}
|
||||
|
||||
void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
@@ -3306,6 +3529,27 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get());
|
||||
}
|
||||
|
||||
void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
float beta_val = 1.0f;
|
||||
float threshold_val = 20.0f;
|
||||
acl_scalar_ptr beta = ggml_cann_create_scalar(&beta_val, ACL_FLOAT);
|
||||
acl_scalar_ptr threshold = ggml_cann_create_scalar(&threshold_val, ACL_FLOAT);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Softplus, acl_src.get(), beta.get(), threshold.get(), acl_dst.get());
|
||||
}
|
||||
|
||||
void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
auto gelu_quick_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
|
||||
};
|
||||
ggml_cann_op_unary_gated(gelu_quick_fn, ctx, dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs expert-specific matrix multiplication (MoE) with
|
||||
* floating-point precision using the CANN backend.
|
||||
@@ -3892,46 +4136,65 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
|
||||
static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0]; // weight
|
||||
ggml_tensor * src1 = dst->src[1]; // input
|
||||
ggml_tensor * src0 = dst->src[0]; // weight [ne00=m, ne01=K, ne02, ne03]
|
||||
ggml_tensor * src1 = dst->src[1]; // input [ne10=n, ne11=K, ne12, ne13]
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
|
||||
// dst[i,j] = sum_k src0[i,k] * src1[j,k] i.e. dst = src0 @ src1^T.
|
||||
//
|
||||
// ggml_cann_create_tensor reverses dimension order, so ACL sees:
|
||||
// acl_src0 slice: ggml[m,K] -> ACL[K,m]
|
||||
// acl_src1 slice: ggml[n,K] -> ACL[K,n]
|
||||
// acl_dst slice: ggml[m,n] -> ACL[n,m]
|
||||
//
|
||||
// Build a transposed view of src1 by swapping ne[0]/ne[1]:
|
||||
// src1_t: ggml[K,n] (swapped strides) -> ACL[n,K]
|
||||
//
|
||||
// Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst ✓
|
||||
//
|
||||
// The outer batch loop is kept because src0 may have fewer batch slices than
|
||||
// dst (ne02 <= ne2, ne03 <= ne3): this is a strided-broadcast not supported
|
||||
// by standard CANN Matmul broadcasting.
|
||||
|
||||
const aclDataType src0_acl_type = ggml_cann_type_mapping(src0->type);
|
||||
const aclDataType src1_acl_type = ggml_cann_type_mapping(src1->type);
|
||||
const aclDataType dst_acl_type = ggml_cann_type_mapping(dst->type);
|
||||
const size_t src0_type_sz = ggml_type_size(src0->type);
|
||||
const size_t src1_type_sz = ggml_type_size(src1->type);
|
||||
const size_t dst_type_sz = ggml_type_size(dst->type);
|
||||
|
||||
const int64_t dps2 = ne2 / ne02;
|
||||
const int64_t dps3 = ne3 / ne03;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t i02 = i2 / dps2;
|
||||
const int64_t i03 = i3 / dps3;
|
||||
|
||||
const int64_t i12 = i2;
|
||||
const int64_t i13 = i3;
|
||||
acl_tensor_ptr accumulator =
|
||||
ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
|
||||
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
|
||||
// src0 2D slice at [i02, i03]: ggml [m, K] -> ACL [K, m]
|
||||
int64_t src0_ne[2] = { ne00, ne01 };
|
||||
size_t src0_nb[2] = { nb00, nb01 };
|
||||
acl_tensor_ptr acl_src0_s = ggml_cann_create_tensor(
|
||||
(char *) src0->data + i02 * nb02 + i03 * nb03,
|
||||
src0_acl_type, src0_type_sz, src0_ne, src0_nb, 2);
|
||||
|
||||
// The outer product needs to be accumulated in this dimension.
|
||||
for (int64_t i1 = 0; i1 < ne11; i1++) {
|
||||
acl_tensor_ptr acl_input = ggml_cann_create_tensor(
|
||||
(char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),
|
||||
ggml_type_size(src0->type), src1->ne, src1->nb, 1);
|
||||
// src1 transposed 2D slice at [i2, i3]: swap ne/nb -> ggml[K,n] -> ACL[n,K]
|
||||
int64_t src1_t_ne[2] = { ne11, ne10 };
|
||||
size_t src1_t_nb[2] = { nb11, nb10 };
|
||||
acl_tensor_ptr acl_src1_t = ggml_cann_create_tensor(
|
||||
(char *) src1->data + i2 * nb12 + i3 * nb13,
|
||||
src1_acl_type, src1_type_sz, src1_t_ne, src1_t_nb, 2);
|
||||
|
||||
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
|
||||
(char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
|
||||
ggml_type_size(src0->type), src0->ne, src0->nb, 1);
|
||||
// dst 2D slice at [i2, i3]: ggml [m, n] -> ACL [n, m]
|
||||
int64_t dst_ne[2] = { ne0, ne1 };
|
||||
size_t dst_nb[2] = { nb0, nb1 };
|
||||
acl_tensor_ptr acl_dst_s = ggml_cann_create_tensor(
|
||||
(char *) dst->data + i2 * nb2 + i3 * nb3,
|
||||
dst_acl_type, dst_type_sz, dst_ne, dst_nb, 2);
|
||||
|
||||
ggml_cann_pool_alloc output_allocator(ctx.pool());
|
||||
void * output_buffer = output_allocator.alloc(ggml_nbytes(dst));
|
||||
acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
|
||||
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());
|
||||
float alpha_value = 1.0f;
|
||||
aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
|
||||
}
|
||||
// Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst_s ✓
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Matmul,
|
||||
acl_src1_t.get(), acl_src0_s.get(), acl_dst_s.get(), (int8_t) 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4170,3 +4433,4 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor *
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,9 @@
|
||||
#include <aclnnop/aclnn_cat.h>
|
||||
#include <aclnnop/aclnn_clamp.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_cumsum.h>
|
||||
#include <aclnnop/aclnn_tril.h>
|
||||
#include <aclnnop/aclnn_triu.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
#include <aclnnop/aclnn_gelu.h>
|
||||
#include <aclnnop/aclnn_gelu_v2.h>
|
||||
@@ -47,6 +50,9 @@
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_silu.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_softplus.h>
|
||||
#include <aclnnop/aclnn_swi_glu.h>
|
||||
#include <aclnnop/aclnn_geglu.h>
|
||||
#include <aclnnop/aclnn_slice.h>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_tanh.h>
|
||||
@@ -69,6 +75,9 @@
|
||||
*/
|
||||
void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate);
|
||||
|
||||
/**
|
||||
* @brief Applies the Leaky ReLU activation function to a tensor using the CANN
|
||||
* backend.
|
||||
@@ -325,6 +334,48 @@ void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the cumulative sum of a ggml tensor along dim 0 using the
|
||||
* CANN backend.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor. dst->op is `GGML_OP_CUMSUM`.
|
||||
*/
|
||||
void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes a triangular mask (tril/triu) of a square ggml tensor
|
||||
* using the CANN backend.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor. dst->op is `GGML_OP_TRI`.
|
||||
*/
|
||||
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Solves a triangular linear system AX=B using the CANN backend.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor. dst->op is `GGML_OP_SOLVE_TRI`.
|
||||
*/
|
||||
void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Creates a diagonal matrix from a vector using the CANN backend.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor. dst->op is `GGML_OP_DIAG`.
|
||||
*/
|
||||
void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Fills a tensor with a constant scalar value using the CANN backend.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor. dst->op is `GGML_OP_FILL`.
|
||||
*/
|
||||
void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Upsamples a ggml tensor using nearest neighbor interpolation using
|
||||
* the CANN backend.
|
||||
@@ -461,6 +512,9 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor *
|
||||
// @see ggml_cann_dup.
|
||||
void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
// @see ggml_cann_acc, but copies src1 into dst instead of adding.
|
||||
void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the softmax activation with optional masking.
|
||||
*
|
||||
@@ -813,6 +867,8 @@ void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
* dst->op is expected to be `GGML_OP_STEP`.
|
||||
*/
|
||||
void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Performs the Flash Attention extended operator using the CANN backend.
|
||||
|
||||
@@ -1428,6 +1428,22 @@ static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set a region of a tensor's device memory to a specified value.
|
||||
*
|
||||
* @param buffer The CANN buffer containing the tensor.
|
||||
* @param tensor Pointer to the tensor whose memory will be set.
|
||||
* @param value The value to which each byte in the region will be set.
|
||||
* @param offset Byte offset within the tensor's data to start setting.
|
||||
* @param size Number of bytes to set.
|
||||
*/
|
||||
static void ggml_backend_cann_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
|
||||
|
||||
ggml_cann_set_device(ctx->device);
|
||||
ACL_CHECK(aclrtMemset((char *) tensor->data + offset, size, value, size));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Clear a CANN buffer by setting all its memory to a specified value.
|
||||
*
|
||||
@@ -1454,7 +1470,7 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
|
||||
/* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
|
||||
/* .get_base = */ ggml_backend_cann_buffer_get_base,
|
||||
/* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
|
||||
/* .memset_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_cann_buffer_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
@@ -1835,6 +1851,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_STEP:
|
||||
ggml_cann_step(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
ggml_cann_softplus(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -1845,20 +1864,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
||||
GGML_CANN_CALL_OP_UNARY_GATED(Relu);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
ggml_cann_geglu(ctx, dst, 0); // approximate=0 → tanh
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
// aclnnGelu internally uses the erf-based approximation.
|
||||
GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
|
||||
ggml_cann_geglu(ctx, dst, 1); // approximate=1 → erf
|
||||
break;
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
GGML_CANN_CALL_OP_UNARY_GATED(Silu);
|
||||
ggml_cann_swiglu(ctx, dst);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
{
|
||||
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
|
||||
};
|
||||
ggml_cann_op_unary_gated(lambda, ctx, dst);
|
||||
}
|
||||
ggml_cann_geglu_quick(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
@@ -1920,6 +1935,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
||||
case GGML_OP_CPY:
|
||||
ggml_cann_cpy(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SET:
|
||||
ggml_cann_set(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONT:
|
||||
ggml_cann_dup(ctx, dst);
|
||||
break;
|
||||
@@ -1989,6 +2007,21 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_cann_ssm_conv(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CUMSUM:
|
||||
ggml_cann_cumsum(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
ggml_cann_tri(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_FILL:
|
||||
ggml_cann_fill(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DIAG:
|
||||
ggml_cann_diag(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
ggml_cann_solve_tri(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -2324,6 +2357,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
||||
if (use_cann_graph) {
|
||||
// If no matching graph is found, the graph needs to be recaptured.
|
||||
graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
|
||||
|
||||
if (graph_capture_required) {
|
||||
// If no matching graph is found, add a new ACL graph.
|
||||
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
|
||||
@@ -2382,6 +2416,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
case GGML_UNARY_OP_SGN:
|
||||
case GGML_UNARY_OP_STEP:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -2572,6 +2607,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return true;
|
||||
case GGML_OP_PAD:
|
||||
@@ -2649,6 +2685,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
}
|
||||
case GGML_OP_SSM_CONV:
|
||||
return true;
|
||||
case GGML_OP_CUMSUM:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_TRI:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_FILL:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_DIAG:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2005,12 +2005,12 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
|
||||
const int lda = KB * sizeof(TA);
|
||||
//const int ldb = KB * sizeof(TB);
|
||||
|
||||
static thread_local packed_B_t Tile0[TILE_N * TILE_K];
|
||||
static thread_local packed_B_t Tile1[TILE_N * TILE_K];
|
||||
static thread_local int8_t Tile23[TILE_M * TILE_K];
|
||||
alignas(64) static thread_local packed_B_t Tile0[TILE_N * TILE_K];
|
||||
alignas(64) static thread_local packed_B_t Tile1[TILE_N * TILE_K];
|
||||
alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K];
|
||||
|
||||
static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
|
||||
static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
|
||||
alignas(64) static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
|
||||
alignas(64) static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
|
||||
|
||||
// double buffering C to interleave avx512 and amx
|
||||
int32_t * C_cur = TileC0;
|
||||
@@ -2187,21 +2187,21 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
|
||||
const int m1 = std::max(M - TILE_M, 0);
|
||||
//const int lda = KB * sizeof(TA);
|
||||
|
||||
static thread_local int8_t Tile0[TILE_N * TILE_K];
|
||||
static thread_local int8_t Tile1[TILE_N * TILE_K];
|
||||
static thread_local int8_t Tile23[TILE_M * TILE_K];
|
||||
alignas(64) static thread_local int8_t Tile0[TILE_N * TILE_K];
|
||||
alignas(64) static thread_local int8_t Tile1[TILE_N * TILE_K];
|
||||
alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K];
|
||||
|
||||
// mat mul result for each group
|
||||
static thread_local int32_t Tile4[TILE_M * TILE_N];
|
||||
static thread_local int32_t Tile5[TILE_M * TILE_N];
|
||||
static thread_local int32_t Tile6[TILE_M * TILE_N];
|
||||
static thread_local int32_t Tile7[TILE_M * TILE_N];
|
||||
alignas(64) static thread_local int32_t Tile4[TILE_M * TILE_N];
|
||||
alignas(64) static thread_local int32_t Tile5[TILE_M * TILE_N];
|
||||
alignas(64) static thread_local int32_t Tile6[TILE_M * TILE_N];
|
||||
alignas(64) static thread_local int32_t Tile7[TILE_M * TILE_N];
|
||||
|
||||
// sum of each QK_K block, contains 8 groups, int32
|
||||
static thread_local int32_t Sumi4[TILE_M * TILE_N];
|
||||
static thread_local int32_t Sumi5[TILE_M * TILE_N];
|
||||
static thread_local int32_t Sumi6[TILE_M * TILE_N];
|
||||
static thread_local int32_t Sumi7[TILE_M * TILE_N];
|
||||
alignas(64) static thread_local int32_t Sumi4[TILE_M * TILE_N];
|
||||
alignas(64) static thread_local int32_t Sumi5[TILE_M * TILE_N];
|
||||
alignas(64) static thread_local int32_t Sumi6[TILE_M * TILE_N];
|
||||
alignas(64) static thread_local int32_t Sumi7[TILE_M * TILE_N];
|
||||
|
||||
const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;
|
||||
for (int i = 0; i < KB; ++i) {
|
||||
|
||||
@@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n,
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (svcntb() * 8 == 256) {
|
||||
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
|
||||
|
||||
static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
|
||||
svuint32_t idx = svld1(svptrue_b32(), idx_arr);
|
||||
static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0};
|
||||
svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1);
|
||||
static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3};
|
||||
svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2);
|
||||
|
||||
for (int y = 0; y < nr; y += 4) {
|
||||
const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
|
||||
|
||||
for (int x = 0; x < nc; x += ncols_interleaved) {
|
||||
const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
|
||||
const block_q8_0x4 * a_ptr = a_ptr_base;
|
||||
|
||||
svfloat32_t acc_f32_01 = svdup_f32(0);
|
||||
svfloat32_t acc_f32_23 = svdup_f32(0);
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
|
||||
svint32_t acc_01 = svdup_s32(0);
|
||||
svint32_t acc_23 = svdup_s32(0);
|
||||
|
||||
// Process 4 chunks of 8 positions each
|
||||
for (int chunk = 0; chunk < 4; chunk++) {
|
||||
svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32);
|
||||
svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16);
|
||||
svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32);
|
||||
|
||||
acc_01 = svmmla_s32(acc_01, s_a01, s_b0123);
|
||||
acc_23 = svmmla_s32(acc_23, s_a23, s_b0123);
|
||||
}
|
||||
|
||||
// Reorder outputs from 2×2 tiles to row-major
|
||||
// acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3]
|
||||
// acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3]
|
||||
|
||||
svint32_t row01 = svtbl_s32(acc_01, idx);
|
||||
svint32_t row23 = svtbl_s32(acc_23, idx);
|
||||
|
||||
svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d);
|
||||
svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d);
|
||||
svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1);
|
||||
svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2);
|
||||
|
||||
acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0));
|
||||
acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2));
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
|
||||
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
|
||||
svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01);
|
||||
svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4));
|
||||
svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23);
|
||||
svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif // SVE compile-time end
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
|
||||
|
||||
|
||||
@@ -830,6 +830,18 @@ static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
|
||||
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only
|
||||
if (!(x > 0.0f)) {
|
||||
return 0;
|
||||
}
|
||||
const __nv_fp8_e4m3 xf(x);
|
||||
return xf.__x;
|
||||
#else
|
||||
NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
|
||||
const uint8_t sign_bit = (x < 0.0f) << 3;
|
||||
float ax = fabsf(x) * e;
|
||||
|
||||
@@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
@@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
@@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
@@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
float KQ_max_scale[cols_per_thread];
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
|
||||
const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col));
|
||||
const float sink = sinks_f[jc % ncols2];
|
||||
|
||||
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
||||
@@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
||||
// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build:
|
||||
extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
|
||||
extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
|
||||
|
||||
// For GLM 4.7 Flash
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
||||
@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 320: {
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst);
|
||||
} break;
|
||||
case 512: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
|
||||
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
@@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
|
||||
@@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
@@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
|
||||
@@ -1144,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
|
||||
}
|
||||
}
|
||||
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
return;
|
||||
if constexpr (ncols2 <= 16) {
|
||||
if (Q->ne[1] > 8/ncols2) {
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
@@ -1210,6 +1220,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
if constexpr (DKQ == 320) { // Mistral Small 4
|
||||
if (use_gqa_opt && gqa_ratio % 32 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
|
||||
}
|
||||
|
||||
if constexpr (DKQ == 576) {
|
||||
if (use_gqa_opt && gqa_ratio % 16 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
@@ -1221,7 +1239,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DKQ <= 512) {
|
||||
if constexpr (DKQ <= 512 && DKQ != 320) {
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
@@ -1275,5 +1293,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
|
||||
extern DECL_FATTN_TILE_CASE(112, 112);
|
||||
extern DECL_FATTN_TILE_CASE(128, 128);
|
||||
extern DECL_FATTN_TILE_CASE(256, 256);
|
||||
extern DECL_FATTN_TILE_CASE(320, 256);
|
||||
extern DECL_FATTN_TILE_CASE(512, 512);
|
||||
extern DECL_FATTN_TILE_CASE(576, 512);
|
||||
|
||||
@@ -143,6 +143,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 320:
|
||||
// For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build).
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
{
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
GGML_ASSERT(use_gqa_opt);
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 32 == 0);
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst);
|
||||
}
|
||||
break;
|
||||
case 512:
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
|
||||
@@ -352,6 +368,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 320:
|
||||
if (V->ne[0] != 256 || !gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (gqa_ratio % 32 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 512:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
|
||||
+22
-12
@@ -1015,25 +1015,35 @@ namespace ggml_cuda_mma {
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
uint32_t a_scale,
|
||||
uint32_t b_scale) {
|
||||
template <ggml_type type>
|
||||
static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
uint32_t a_scale,
|
||||
uint32_t b_scale) {
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
float * Dxi = (float *) D.x;
|
||||
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
if constexpr (type == GGML_TYPE_MXFP4) {
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
} else {
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
|
||||
+10
-11
@@ -122,7 +122,7 @@ void ggml_cuda_mul_mat_q(
|
||||
|| GGML_CUDA_CC_IS_CDNA(cc);
|
||||
|
||||
// TODO: tighter pool buffer size vs q8 path
|
||||
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
|
||||
const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4);
|
||||
|
||||
if (!ids) {
|
||||
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
|
||||
@@ -133,9 +133,9 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
if (use_native_mxfp4) {
|
||||
if (use_native_fp4) {
|
||||
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
|
||||
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
|
||||
quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
|
||||
ne11, ne12, ne13, stream);
|
||||
|
||||
} else {
|
||||
@@ -146,10 +146,8 @@ void ggml_cuda_mul_mat_q(
|
||||
}
|
||||
|
||||
// Stride depends on quantization format
|
||||
const int64_t s12 = use_native_mxfp4 ?
|
||||
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
|
||||
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
|
||||
:
|
||||
const int64_t s12 = use_native_fp4 ?
|
||||
ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
const int64_t s13 = ne12*s12;
|
||||
|
||||
@@ -198,8 +196,8 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
|
||||
if (use_native_mxfp4) {
|
||||
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
if (use_native_fp4) {
|
||||
quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
} else {
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
@@ -208,8 +206,9 @@ void ggml_cuda_mul_mat_q(
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4");
|
||||
const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) :
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
const int64_t s13 = ne12*s12;
|
||||
|
||||
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
|
||||
|
||||
+147
-83
@@ -10,9 +10,9 @@
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
|
||||
#define MMQ_ITER_K 256
|
||||
#define MMQ_ITER_K_MXFP4_FP4 512
|
||||
#define MMQ_NWARPS 8
|
||||
#define MMQ_ITER_K 256
|
||||
#define MMQ_ITER_K_FP4 512
|
||||
#define MMQ_NWARPS 8
|
||||
|
||||
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
|
||||
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
|
||||
@@ -46,9 +46,12 @@ struct block_q8_1_mmq {
|
||||
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
|
||||
};
|
||||
|
||||
// this struct is used for fp4 data types (currently only used for Blackwell)
|
||||
// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits
|
||||
// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales
|
||||
struct block_fp4_mmq {
|
||||
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
|
||||
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
|
||||
uint32_t d4[4];
|
||||
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
|
||||
@@ -143,10 +146,11 @@ static int get_mmq_y_host(const int cc) {
|
||||
|
||||
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
|
||||
#else
|
||||
return MMQ_ITER_K;
|
||||
if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {
|
||||
return MMQ_ITER_K_FP4;
|
||||
}
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
return MMQ_ITER_K;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_mmq_y_device() {
|
||||
@@ -213,8 +217,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
||||
}
|
||||
|
||||
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
|
||||
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
|
||||
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell
|
||||
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic
|
||||
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
||||
@@ -240,7 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
// tile sizes are the same for Q8_1 and FP4 for blackwell
|
||||
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4;
|
||||
#else
|
||||
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
||||
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
@@ -934,6 +942,128 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
template <int mmq_y, bool need_check>
|
||||
static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x,
|
||||
int * __restrict__ x_tile,
|
||||
const int kbx0,
|
||||
const int i_max,
|
||||
const int stride) {
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4);
|
||||
constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block
|
||||
constexpr int rows_per_warp = warp_size / threads_per_row;
|
||||
|
||||
uint32_t * x_u32 = (uint32_t *) x_tile;
|
||||
|
||||
const int txi = threadIdx.x;
|
||||
const int kbx = txi % threads_per_row;
|
||||
const int row_in_warp = txi / threads_per_row;
|
||||
|
||||
const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx;
|
||||
uint32_t * x_u32_scale = x_u32 + 64 + kbx;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
|
||||
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
|
||||
|
||||
if constexpr (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const block_nvfp4 * bxi = bxi_base + i * stride;
|
||||
const int row_base = i * MMQ_MMA_TILE_X_K_FP4;
|
||||
const int q_base = row_base + 8 * kbx;
|
||||
|
||||
const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
|
||||
|
||||
#pragma unroll
|
||||
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
|
||||
x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0];
|
||||
x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1];
|
||||
}
|
||||
|
||||
x_u32_scale[row_base] = get_int_b4(bxi->d, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell.
|
||||
// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per
|
||||
// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3)
|
||||
// and the per-type stride constant differ.
|
||||
template <int mmq_x, int mmq_y, ggml_type type>
|
||||
static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x,
|
||||
const int * __restrict__ y,
|
||||
float * __restrict__ sum,
|
||||
const int k00) {
|
||||
static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4,
|
||||
"vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4");
|
||||
|
||||
typedef tile<16, 8, int> tile_A;
|
||||
typedef tile<8, 8, int> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
|
||||
constexpr int stride = MMQ_MMA_TILE_X_K_FP4;
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp / tile_C::I;
|
||||
constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J;
|
||||
|
||||
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K);
|
||||
|
||||
const int * x_qs = (const int *) x;
|
||||
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const uint32_t * y_sc = (const uint32_t *) y;
|
||||
|
||||
// 2 threads per quad supply the packed scale register to the block_scale MMA,
|
||||
// see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
||||
const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
||||
const int tidx_B = threadIdx.x / 4;
|
||||
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
||||
|
||||
tile_A A[ntx][nfrags];
|
||||
uint32_t scaleA[ntx][nfrags];
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int frag = 0; frag < nfrags; ++frag) {
|
||||
const int k0 = k00 + frag * tile_A::J;
|
||||
load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride);
|
||||
scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
||||
tile_B B[nfrags];
|
||||
uint32_t scaleB[nfrags];
|
||||
|
||||
#pragma unroll
|
||||
for (int frag = 0; frag < nfrags; ++frag) {
|
||||
const int k0 = frag * tile_B::J;
|
||||
load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
|
||||
scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int frag = 0; frag < nfrags; ++frag) {
|
||||
tile_C C = {};
|
||||
mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
|
||||
|
||||
template <int mmq_y, bool need_check>
|
||||
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
|
||||
@@ -1163,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
|
||||
const int * __restrict__ y,
|
||||
float * __restrict__ sum,
|
||||
const int k00) {
|
||||
typedef tile<16, 8, int> tile_A;
|
||||
typedef tile<8, 8, int> tile_B;
|
||||
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
constexpr int rows_per_warp = 2 * granularity;
|
||||
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
|
||||
|
||||
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
|
||||
|
||||
// Match layout from load_tiles_mxfp4_fp4
|
||||
const int * x_qs = (const int *) x;
|
||||
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const uint32_t * y_sc = (const uint32_t *) y;
|
||||
|
||||
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
|
||||
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
||||
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
||||
|
||||
// Block scale
|
||||
// Each thread has to point to a 4 byte scale value
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
||||
|
||||
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
|
||||
MMQ_MMA_TILE_X_K_FP4);
|
||||
|
||||
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
|
||||
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
||||
scaleA[n][k01 / (2 * QI_MXFP4)] =
|
||||
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
||||
tile_B B;
|
||||
uint32_t scaleB; // 2xN scales
|
||||
|
||||
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
|
||||
|
||||
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < ntx; ++n) {
|
||||
tile_C C;
|
||||
|
||||
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
||||
@@ -3259,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
||||
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>;
|
||||
#else
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
||||
@@ -3270,8 +3329,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
|
||||
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>;
|
||||
#else
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
@@ -3406,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
||||
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
// FP4 tile stores 8 blocks
|
||||
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
|
||||
constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1;
|
||||
#else
|
||||
constexpr int ne_block = 4 * QK8_1;
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
@@ -115,6 +115,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(gg
|
||||
case GGML_TYPE_IQ4_NL: return 6;
|
||||
case GGML_TYPE_IQ4_XS: return 5;
|
||||
case GGML_TYPE_MXFP4: return 4;
|
||||
case GGML_TYPE_NVFP4: return 4;
|
||||
case GGML_TYPE_Q2_K: return 4;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_0: return 6;
|
||||
@@ -135,6 +136,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggm
|
||||
case GGML_TYPE_IQ3_S: return 6;
|
||||
case GGML_TYPE_IQ3_XXS: return 7;
|
||||
case GGML_TYPE_MXFP4: return 7;
|
||||
case GGML_TYPE_NVFP4: return 8;
|
||||
case GGML_TYPE_Q2_K: return 7;
|
||||
case GGML_TYPE_Q3_K: return 5;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
@@ -221,6 +223,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
|
||||
case GGML_TYPE_IQ4_NL: return 7;
|
||||
case GGML_TYPE_IQ4_XS: return 5;
|
||||
case GGML_TYPE_MXFP4: return 5;
|
||||
case GGML_TYPE_NVFP4: return 5;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_0: return 7;
|
||||
case GGML_TYPE_Q4_1: return 7;
|
||||
|
||||
+121
-21
@@ -70,6 +70,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
|
||||
return static_cast<uint8_t>(biased);
|
||||
}
|
||||
|
||||
|
||||
static __global__ void quantize_mmq_nvfp4(
|
||||
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
|
||||
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB;
|
||||
if (i0_base >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i1 = blockIdx.x;
|
||||
const int64_t i2 = blockIdx.z % ne2;
|
||||
const int64_t i3 = blockIdx.z / ne2;
|
||||
const int64_t i01 = ids ? ids[i1] : i1;
|
||||
const int64_t k_block = i0_base / QK_K;
|
||||
const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K;
|
||||
if (k_block >= blocks_per_col) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x;
|
||||
block_fp4_mmq * y = (block_fp4_mmq *) vy;
|
||||
block_fp4_mmq * yb = y + ib;
|
||||
|
||||
const int sub = (i0_base % QK_K) / QK_NVFP4_SUB;
|
||||
|
||||
float vals_raw[QK_NVFP4_SUB];
|
||||
float amax_raw = 0.0f;
|
||||
const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < QK_NVFP4_SUB; k++) {
|
||||
const int64_t i00 = i0_base + k;
|
||||
if (i00 < ne00) {
|
||||
const float v = x[base_idx + i00];
|
||||
vals_raw[k] = v;
|
||||
amax_raw = fmaxf(amax_raw, fabsf(v));
|
||||
} else {
|
||||
vals_raw[k] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2};
|
||||
const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f);
|
||||
|
||||
float best_err = FLT_MAX;
|
||||
uint8_t fp8_code = 0;
|
||||
float subblock_scale = 0.0f;
|
||||
|
||||
#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell.
|
||||
for (int i = 0; i < 5; i++) {
|
||||
const int test_code = first_fp8_code + test_offsets[i];
|
||||
if (test_code < 0 || test_code > 0x7e) {
|
||||
continue;
|
||||
}
|
||||
const uint8_t code = (uint8_t) test_code;
|
||||
const float test_scale = ggml_cuda_ue4m3_to_fp32(code);
|
||||
const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f;
|
||||
float cur_err = 0.0f;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < QK_NVFP4_SUB; ++k) {
|
||||
const float v = vals_raw[k];
|
||||
const uint8_t q = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale);
|
||||
const float err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale;
|
||||
cur_err = fmaf(err_diff, err_diff, cur_err);
|
||||
}
|
||||
|
||||
if (cur_err < best_err) {
|
||||
best_err = cur_err;
|
||||
fp8_code = test_code;
|
||||
subblock_scale = test_scale;
|
||||
}
|
||||
}
|
||||
|
||||
const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f;
|
||||
uint32_t q0 = 0;
|
||||
uint32_t q1 = 0;
|
||||
#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1
|
||||
for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) {
|
||||
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k);
|
||||
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4);
|
||||
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k);
|
||||
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4);
|
||||
}
|
||||
|
||||
uint32_t * yqs = reinterpret_cast<uint32_t *>(yb->qs);
|
||||
yqs[2 * sub + 0] = q0;
|
||||
yqs[2 * sub + 1] = q1;
|
||||
reinterpret_cast<uint8_t *>(yb->d4)[sub] = fp8_code;
|
||||
#else
|
||||
NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only.
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
}
|
||||
|
||||
// quantize values in the format mxfp4 is stored which is interleaved nibbles
|
||||
// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
|
||||
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
||||
@@ -316,28 +412,32 @@ void quantize_mmq_q8_1_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_mmq_mxfp4_cuda(const float * x,
|
||||
const int32_t * ids,
|
||||
void * vy,
|
||||
[[maybe_unused]] const ggml_type type_src0,
|
||||
const int64_t ne00,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
const int64_t ne0,
|
||||
const int64_t ne1,
|
||||
const int64_t ne2,
|
||||
const int64_t ne3,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
|
||||
void quantize_mmq_fp4_cuda(
|
||||
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
|
||||
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
||||
GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4);
|
||||
GGML_ASSERT(ne0 > 0);
|
||||
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int vals_per_warp = 2 * QK_MXFP4;
|
||||
constexpr int vals_per_block = nwarps * vals_per_warp;
|
||||
if (type_src0 == GGML_TYPE_NVFP4) {
|
||||
GGML_ASSERT(ne00 % QK_NVFP4 == 0);
|
||||
constexpr int nvfp4_block_size = 128;
|
||||
const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size);
|
||||
const dim3 block_size(nvfp4_block_size, 1, 1);
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
||||
quantize_mmq_nvfp4<<<num_blocks, block_size, 0, stream>>>(
|
||||
x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
} else {
|
||||
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
|
||||
|
||||
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
||||
const dim3 block_size(WARP_SIZE, nwarps, 1);
|
||||
constexpr int nwarps = 8;
|
||||
constexpr int vals_per_warp = 2 * QK_MXFP4;
|
||||
constexpr int vals_per_block = nwarps * vals_per_warp;
|
||||
|
||||
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
||||
const dim3 block_size(WARP_SIZE, nwarps, 1);
|
||||
|
||||
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda(
|
||||
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
|
||||
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
|
||||
|
||||
void quantize_mmq_mxfp4_cuda(const float * x,
|
||||
void quantize_mmq_fp4_cuda(const float * x,
|
||||
const int32_t * ids,
|
||||
void * vy,
|
||||
ggml_type type_src0,
|
||||
|
||||
@@ -2,4 +2,5 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
|
||||
|
||||
@@ -2,4 +2,5 @@
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(320, 256);
|
||||
@@ -3,7 +3,7 @@
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
|
||||
|
||||
@@ -62,7 +62,7 @@ for filename in glob("*.cu"):
|
||||
os.remove(filename)
|
||||
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
|
||||
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
@@ -84,13 +84,16 @@ for ncols in [8, 16, 32, 64]:
|
||||
continue
|
||||
if head_size_kq == 72:
|
||||
continue
|
||||
if head_size_kq == 512 and ncols2 not in (4, 8):
|
||||
# Skip compilation of unused ncols2 values for niche head sizes:
|
||||
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 in (16, 32):
|
||||
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
|
||||
continue
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
if head_size_kq not in (320, 576) and ncols2 in (16, 32):
|
||||
continue
|
||||
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
for type in TYPES_MMQ:
|
||||
|
||||
@@ -1101,7 +1101,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
||||
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
||||
std::ofstream ofs(cache_file, std::ios::binary);
|
||||
ofs.write((const char *)data, size);
|
||||
GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
|
||||
GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.string().c_str());
|
||||
}
|
||||
ggml_backend_tensor_set(tensor, data, offset, size);
|
||||
return true;
|
||||
|
||||
@@ -20,12 +20,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
||||
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
|
||||
|
||||
#include <vulkan/vulkan.hpp>
|
||||
// SPIRV-Headers: LunarG Windows SDK uses Include/spirv-headers/spirv.hpp (not spirv/unified1/). MinGW/MSYS2 and
|
||||
// Linux packages use Khronos layout spirv/unified1/spirv.hpp. See docs/build.md#vulkan.
|
||||
#if defined(_WIN32) && !defined(__MINGW32__)
|
||||
#include <spirv-headers/spirv.hpp>
|
||||
|
||||
// SPIR-V Headers: different SDK installations expose different include paths.
|
||||
// LunarG Vulkan SDK on Windows typically provides <spirv-headers/spirv.hpp>.
|
||||
// Linux packages, MSYS2 and MinGW often use the Khronos layout <spirv/unified1/spirv.hpp>.
|
||||
#if __has_include(<spirv/unified1/spirv.hpp>)
|
||||
# include <spirv/unified1/spirv.hpp>
|
||||
#elif __has_include(<spirv-headers/spirv.hpp>)
|
||||
# include <spirv-headers/spirv.hpp>
|
||||
#elif __has_include(<spirv.hpp>)
|
||||
# include <spirv.hpp>
|
||||
#else
|
||||
#include <spirv/unified1/spirv.hpp>
|
||||
// Fallback to let the compiler throw a standard "file not found" error
|
||||
# include <spirv/unified1/spirv.hpp>
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
@@ -13007,6 +13014,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
|
||||
ctx->query_node_idx[ctx->query_idx] = node_idx;
|
||||
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
||||
ggml_vk_sync_buffers(ctx, compute_ctx);
|
||||
}
|
||||
}
|
||||
// Add all fused nodes to the unsynchronized lists.
|
||||
@@ -14496,6 +14504,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
ctx->query_idx = 0;
|
||||
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
||||
ggml_vk_sync_buffers(ctx, compute_ctx);
|
||||
}
|
||||
|
||||
ctx->prealloc_y_last_pipeline_used = nullptr;
|
||||
@@ -14732,6 +14741,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
|
||||
ctx->query_fusion_names[ctx->query_idx] = fusion_string;
|
||||
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
||||
ggml_vk_sync_buffers(ctx, compute_ctx);
|
||||
} else {
|
||||
// track a fusion string and number of fused ops for the current node_idx
|
||||
ctx->query_fusion_names[i] = fusion_string;
|
||||
|
||||
@@ -296,13 +296,22 @@ vec2 get_dm_scale(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
const uint is = iqs_k / 8;
|
||||
u8vec2 scale_dm;
|
||||
if (is < 4) {
|
||||
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
|
||||
} else {
|
||||
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
|
||||
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
|
||||
}
|
||||
|
||||
const uvec3 scales = uvec3(data_a_packed32[ib_k].scales[0],
|
||||
data_a_packed32[ib_k].scales[1],
|
||||
data_a_packed32[ib_k].scales[2]);
|
||||
const uint scalesoffs = (is & 3) * 8;
|
||||
|
||||
const uint scidx0 = (is < 4) ? 0 : 2;
|
||||
const uint scidxshift0 = scalesoffs;
|
||||
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
|
||||
const uint mbidx0 = (is < 4) ? 1 : 2;
|
||||
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
|
||||
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
|
||||
|
||||
const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
|
||||
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
|
||||
u8vec2 scale_dm = u8vec2(sc, mbyte);
|
||||
|
||||
return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm);
|
||||
}
|
||||
|
||||
@@ -201,19 +201,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const vec2 loadd = vec2(data_a[ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
||||
const uint mbidx0 = is + 4;
|
||||
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
||||
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
||||
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
||||
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
||||
const uvec3 scales = uvec3(data_a_packed32[ib].scales[0],
|
||||
data_a_packed32[ib].scales[1],
|
||||
data_a_packed32[ib].scales[2]);
|
||||
const uint scalesoffs = (is & 3) * 8;
|
||||
|
||||
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
||||
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
||||
const uint scidx0 = (is < 4) ? 0 : 2;
|
||||
const uint scidxshift0 = scalesoffs;
|
||||
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
|
||||
const uint mbidx0 = (is < 4) ? 1 : 2;
|
||||
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
|
||||
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
|
||||
|
||||
const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
|
||||
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
|
||||
|
||||
const float d = loadd.x * sc;
|
||||
const float m = -loadd.y * mbyte;
|
||||
@@ -237,19 +238,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const vec2 loadd = vec2(data_a[ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
||||
const uint mbidx0 = is + 4;
|
||||
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
||||
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
||||
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
||||
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
||||
const uvec3 scales = uvec3(data_a_packed32[ib].scales[0],
|
||||
data_a_packed32[ib].scales[1],
|
||||
data_a_packed32[ib].scales[2]);
|
||||
const uint scalesoffs = (is & 3) * 8;
|
||||
|
||||
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
||||
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
||||
const uint scidx0 = (is < 4) ? 0 : 2;
|
||||
const uint scidxshift0 = scalesoffs;
|
||||
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
|
||||
const uint mbidx0 = (is < 4) ? 1 : 2;
|
||||
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
|
||||
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
|
||||
|
||||
const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
|
||||
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
|
||||
|
||||
const float d = loadd.x * sc;
|
||||
const float m = -loadd.y * mbyte;
|
||||
|
||||
@@ -26,21 +26,21 @@
|
||||
// Matrix multiplication parameters
|
||||
|
||||
// Register tiling parameters
|
||||
#define WEBGPU_MUL_MAT_TILE_M 4
|
||||
#define WEBGPU_MUL_MAT_TILE_N 4
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
|
||||
#define WEBGPU_MUL_MAT_TILE_M 4
|
||||
#define WEBGPU_MUL_MAT_TILE_N 4
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
|
||||
#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
|
||||
#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
|
||||
|
||||
// Subgroup matrix parameters
|
||||
// The number of subgroups in the M dimension
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
|
||||
// The number of subgroups in the N dimension
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_N 4
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_N 4
|
||||
// The number of subgroup matrices each subgroup accumulates over
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
|
||||
|
||||
@@ -59,19 +59,32 @@ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const
|
||||
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
}
|
||||
|
||||
// Calculates base address of a tensor ignoring the fake base pointer
|
||||
inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) {
|
||||
const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
|
||||
return (uintptr_t) base_tensor->data + tensor->view_offs;
|
||||
}
|
||||
|
||||
inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) {
|
||||
return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b);
|
||||
}
|
||||
|
||||
inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) {
|
||||
return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) &&
|
||||
ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a);
|
||||
}
|
||||
|
||||
struct ggml_webgpu_shader_lib_context {
|
||||
ggml_tensor * src0;
|
||||
ggml_tensor * src1;
|
||||
ggml_tensor * src2;
|
||||
ggml_tensor * src3;
|
||||
ggml_tensor * src4;
|
||||
ggml_tensor * src5;
|
||||
ggml_tensor * dst;
|
||||
|
||||
uint32_t max_wg_size;
|
||||
size_t wg_mem_limit_bytes = 0;
|
||||
bool inplace = false;
|
||||
bool overlap = false;
|
||||
bool src_overlap = false;
|
||||
bool supports_subgroups = false;
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t sg_mat_m = 0;
|
||||
@@ -88,6 +101,14 @@ struct webgpu_pipeline {
|
||||
|
||||
struct ggml_webgpu_generic_shader_decisions {
|
||||
uint32_t wg_size = 0;
|
||||
bool inplace = false;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_binary_shader_decisions {
|
||||
uint32_t wg_size = 0;
|
||||
bool inplace = false;
|
||||
bool overlap = false;
|
||||
bool src_overlap = false;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_processed_shader {
|
||||
@@ -102,11 +123,12 @@ struct ggml_webgpu_ssm_conv_shader_decisions {
|
||||
};
|
||||
|
||||
struct ggml_webgpu_ssm_scan_pipeline_key {
|
||||
int type;
|
||||
int d_state;
|
||||
int type;
|
||||
int d_state;
|
||||
bool xbc_overlap;
|
||||
|
||||
bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
|
||||
return type == other.type && d_state == other.d_state;
|
||||
return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -115,6 +137,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.d_state);
|
||||
ggml_webgpu_hash_combine(seed, key.xbc_overlap);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
@@ -122,6 +145,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
|
||||
struct ggml_webgpu_ssm_scan_shader_decisions {
|
||||
uint32_t wg_size;
|
||||
uint32_t tokens_per_tile;
|
||||
bool xbc_overlap = false;
|
||||
};
|
||||
|
||||
/** Argsort **/
|
||||
@@ -242,6 +266,13 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_rms_norm_mul_shader_decisions {
|
||||
uint32_t wg_size = 0;
|
||||
bool inplace = false;
|
||||
bool overlap = false;
|
||||
bool src_overlap = false;
|
||||
};
|
||||
|
||||
/** Pad **/
|
||||
struct ggml_webgpu_pad_pipeline_key {
|
||||
bool circular;
|
||||
@@ -503,11 +534,12 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_decisions {
|
||||
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
|
||||
uint32_t q_tile = 0;
|
||||
uint32_t kv_tile = 0;
|
||||
uint32_t wg_size = 0;
|
||||
bool kv_direct = false;
|
||||
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
|
||||
uint32_t q_tile = 0;
|
||||
uint32_t kv_tile = 0;
|
||||
uint32_t wg_size = 0;
|
||||
bool kv_direct = false;
|
||||
bool kv_overlap = false;
|
||||
};
|
||||
|
||||
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
|
||||
@@ -552,7 +584,7 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
|
||||
key.head_dim_qk = (uint32_t) context.src0->ne[0];
|
||||
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
||||
key.kv_direct = kv_direct;
|
||||
key.kv_overlap = context.src_overlap;
|
||||
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
|
||||
key.has_mask = has_mask;
|
||||
key.has_sinks = has_sinks;
|
||||
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
|
||||
@@ -1021,7 +1053,7 @@ class ggml_webgpu_shader_lib {
|
||||
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_row_norm_pipeline_key key = {};
|
||||
key.op = context.dst->op;
|
||||
key.inplace = context.inplace;
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
|
||||
auto it = row_norm_pipelines.find(key);
|
||||
if (it != row_norm_pipelines.end()) {
|
||||
@@ -1051,8 +1083,12 @@ class ggml_webgpu_shader_lib {
|
||||
const uint32_t row_norm_wg_size = 128u;
|
||||
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
|
||||
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = wg_size;
|
||||
decisions->inplace = key.inplace;
|
||||
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
row_norm_pipelines[key].context = decisions;
|
||||
return row_norm_pipelines[key];
|
||||
}
|
||||
|
||||
@@ -1127,7 +1163,7 @@ class ggml_webgpu_shader_lib {
|
||||
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_set_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.inplace = context.inplace;
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
|
||||
auto it = set_pipelines.find(key);
|
||||
if (it != set_pipelines.end()) {
|
||||
@@ -1160,6 +1196,7 @@ class ggml_webgpu_shader_lib {
|
||||
auto processed = preprocessor.preprocess(wgsl_set, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
decisions->inplace = key.inplace;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
set_pipelines[key] = pipeline;
|
||||
@@ -1287,6 +1324,7 @@ class ggml_webgpu_shader_lib {
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
switch (key.src_type) {
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
@@ -1323,7 +1361,9 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
defines.push_back("DST_TYPE=f32");
|
||||
|
||||
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||
if (key.src_type == GGML_TYPE_Q1_0) {
|
||||
defines.push_back("BLOCK_SIZE=128u");
|
||||
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||
key.src_type == GGML_TYPE_IQ4_NL) {
|
||||
defines.push_back("BLOCK_SIZE=32u");
|
||||
} else if (key.src_type >= GGML_TYPE_Q2_K) {
|
||||
@@ -1352,7 +1392,7 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_scale_pipeline_key key = {};
|
||||
key.inplace = context.inplace;
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
|
||||
auto it = scale_pipelines.find(key);
|
||||
if (it != scale_pipelines.end()) {
|
||||
@@ -1372,6 +1412,7 @@ class ggml_webgpu_shader_lib {
|
||||
auto processed = preprocessor.preprocess(wgsl_scale, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
decisions->inplace = key.inplace;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
scale_pipelines[key] = pipeline;
|
||||
@@ -1465,6 +1506,8 @@ class ggml_webgpu_shader_lib {
|
||||
ggml_webgpu_ssm_scan_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.d_state = (int) context.src0->ne[0];
|
||||
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
|
||||
auto it = ssm_scan_pipelines.find(key);
|
||||
if (it != ssm_scan_pipelines.end()) {
|
||||
@@ -1496,12 +1539,17 @@ class ggml_webgpu_shader_lib {
|
||||
variant += "_wg_reduce";
|
||||
}
|
||||
|
||||
if (key.xbc_overlap) {
|
||||
defines.push_back("XBC_OVERLAP");
|
||||
}
|
||||
|
||||
variant += "_d" + std::to_string(key.d_state);
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
|
||||
decisions->wg_size = wg_size;
|
||||
decisions->tokens_per_tile = tokens_per_tile;
|
||||
decisions->xbc_overlap = key.xbc_overlap;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
ssm_scan_pipelines[key] = pipeline;
|
||||
@@ -1615,6 +1663,24 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
defines.push_back("SRC0_INNER_TYPE=u32");
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -1639,7 +1705,9 @@ class ggml_webgpu_shader_lib {
|
||||
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
|
||||
|
||||
if (key.src0_type >= GGML_TYPE_Q2_K) {
|
||||
if (key.src0_type == GGML_TYPE_Q1_0) {
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
} else if (key.src0_type >= GGML_TYPE_Q2_K) {
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
|
||||
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
@@ -1741,11 +1809,9 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
uint32_t tile_k;
|
||||
if (key.use_subgroup_matrix) {
|
||||
tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT
|
||||
: WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
|
||||
tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
|
||||
} else {
|
||||
tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
|
||||
: WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
|
||||
tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
|
||||
}
|
||||
|
||||
// Tiles
|
||||
@@ -1978,9 +2044,8 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back("SCALAR");
|
||||
|
||||
// mul_mat_id is register-tile only.
|
||||
const uint32_t tile_k = ggml_is_quantized(context.src0->type)
|
||||
? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
|
||||
: WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
|
||||
const uint32_t tile_k =
|
||||
ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
|
||||
|
||||
// Tiles
|
||||
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
||||
@@ -2016,8 +2081,8 @@ class ggml_webgpu_shader_lib {
|
||||
key.type = context.dst->type;
|
||||
key.op = op;
|
||||
key.is_unary = is_unary;
|
||||
key.inplace = context.inplace;
|
||||
key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL;
|
||||
key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
|
||||
|
||||
auto it = unary_pipelines.find(key);
|
||||
if (it != unary_pipelines.end()) {
|
||||
@@ -2075,6 +2140,7 @@ class ggml_webgpu_shader_lib {
|
||||
auto processed = preprocessor.preprocess(wgsl_unary, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
decisions->inplace = key.inplace;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
unary_pipelines[key] = pipeline;
|
||||
@@ -2083,9 +2149,9 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_rms_norm_mul_pipeline_key key = {};
|
||||
key.inplace = context.inplace;
|
||||
key.overlap = context.overlap;
|
||||
key.src_overlap = context.src_overlap;
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
|
||||
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
|
||||
|
||||
auto it = rms_norm_mul_pipelines.find(key);
|
||||
if (it != rms_norm_mul_pipelines.end()) {
|
||||
@@ -2109,12 +2175,15 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
rms_norm_mul_pipelines[key] = pipeline;
|
||||
auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
|
||||
auto pipeline_decisions = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>();
|
||||
pipeline_decisions->wg_size = context.max_wg_size;
|
||||
pipeline_decisions->inplace = key.inplace;
|
||||
pipeline_decisions->overlap = key.overlap;
|
||||
pipeline_decisions->src_overlap = key.src_overlap;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = pipeline_decisions;
|
||||
rms_norm_mul_pipelines[key] = pipeline;
|
||||
return rms_norm_mul_pipelines[key];
|
||||
}
|
||||
|
||||
@@ -2122,9 +2191,9 @@ class ggml_webgpu_shader_lib {
|
||||
ggml_webgpu_binary_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.op = context.dst->op;
|
||||
key.inplace = context.inplace;
|
||||
key.overlap = context.overlap;
|
||||
key.src_overlap = context.src_overlap;
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
|
||||
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
|
||||
|
||||
auto it = binary_pipelines.find(key);
|
||||
if (it != binary_pipelines.end()) {
|
||||
@@ -2163,11 +2232,15 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_binary, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
auto processed = preprocessor.preprocess(wgsl_binary, defines);
|
||||
auto pipeline_decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
|
||||
pipeline_decisions->wg_size = context.max_wg_size;
|
||||
pipeline_decisions->inplace = key.inplace;
|
||||
pipeline_decisions->overlap = key.overlap;
|
||||
pipeline_decisions->src_overlap = key.src_overlap;
|
||||
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
pipeline.context = pipeline_decisions;
|
||||
binary_pipelines[key] = pipeline;
|
||||
return binary_pipelines[key];
|
||||
}
|
||||
@@ -2328,7 +2401,8 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
||||
}
|
||||
|
||||
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
|
||||
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
|
||||
pipeline_decisions->kv_overlap = key.kv_overlap;
|
||||
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
|
||||
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));
|
||||
@@ -2520,7 +2594,7 @@ class ggml_webgpu_shader_lib {
|
||||
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_rope_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.inplace = context.inplace;
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
key.has_ff = (context.src2 != nullptr);
|
||||
|
||||
auto it = rope_pipelines.find(key);
|
||||
@@ -2559,6 +2633,7 @@ class ggml_webgpu_shader_lib {
|
||||
auto processed = preprocessor.preprocess(wgsl_rope, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
decisions->inplace = key.inplace;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
rope_pipelines[key] = pipeline;
|
||||
@@ -2570,7 +2645,7 @@ class ggml_webgpu_shader_lib {
|
||||
key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
|
||||
key.has_mask = (context.src1 != nullptr);
|
||||
key.has_sink = (context.src2 != nullptr);
|
||||
key.inplace = context.inplace;
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
|
||||
auto it = soft_max_pipelines.find(key);
|
||||
if (it != soft_max_pipelines.end()) {
|
||||
@@ -2611,6 +2686,7 @@ class ggml_webgpu_shader_lib {
|
||||
auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
decisions->inplace = key.inplace;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
soft_max_pipelines[key] = pipeline;
|
||||
|
||||
@@ -108,12 +108,9 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) {
|
||||
// their locations.
|
||||
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
||||
|
||||
// Always returns the base offset of a tensor, regardless of views.
|
||||
static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
|
||||
if (tensor->view_src) {
|
||||
return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
|
||||
}
|
||||
return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
|
||||
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
|
||||
const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
|
||||
return (size_t) ((uintptr_t) base_tensor->data - (uintptr_t) webgpu_ptr_base) + tensor->view_offs;
|
||||
}
|
||||
|
||||
/* Struct definitions */
|
||||
@@ -375,10 +372,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
||||
buffer = device.CreateBuffer(&buffer_desc);
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
|
||||
return webgpu_tensor_offset(tensor) + tensor->view_offs;
|
||||
}
|
||||
|
||||
static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
|
||||
ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
|
||||
return ctx->buffer;
|
||||
@@ -398,34 +391,31 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
|
||||
return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
}
|
||||
|
||||
// Used to determine if two tensors are the same for in-place operations
|
||||
static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
|
||||
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
|
||||
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
|
||||
}
|
||||
|
||||
// Used to determine if two tensors share the same buffer and their byte ranges overlap,
|
||||
static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
|
||||
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
|
||||
ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
|
||||
ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
|
||||
}
|
||||
|
||||
struct binary_overlap_flags {
|
||||
bool inplace; // src0 == dst
|
||||
bool overlap; // src1 == dst
|
||||
bool src_overlap;
|
||||
struct ggml_webgpu_merged_binding_range {
|
||||
size_t offset;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
binary_overlap_flags flags = {};
|
||||
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
|
||||
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
|
||||
static ggml_webgpu_merged_binding_range ggml_webgpu_tensor_merged_binding_range(
|
||||
webgpu_context & ctx,
|
||||
std::initializer_list<ggml_tensor *> tensors) {
|
||||
size_t merged_offset = SIZE_MAX;
|
||||
size_t merged_end = 0;
|
||||
|
||||
return flags;
|
||||
for (ggml_tensor * tensor : tensors) {
|
||||
const size_t bind_offset = ggml_webgpu_tensor_align_offset(ctx, tensor);
|
||||
const size_t bind_end = bind_offset + ggml_webgpu_tensor_binding_size(ctx, tensor);
|
||||
|
||||
merged_offset = std::min(merged_offset, bind_offset);
|
||||
merged_end = std::max(merged_end, bind_end);
|
||||
}
|
||||
|
||||
return { merged_offset, merged_end - merged_offset };
|
||||
}
|
||||
|
||||
static uint32_t ggml_webgpu_tensor_merged_element_offset(const ggml_tensor * tensor,
|
||||
const ggml_webgpu_merged_binding_range & merged_range) {
|
||||
return (uint32_t) ((ggml_webgpu_tensor_offset(tensor) - merged_range.offset) / ggml_type_size(tensor->type));
|
||||
}
|
||||
|
||||
static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding,
|
||||
@@ -753,18 +743,16 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = inplace;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
const bool inplace = decisions->inplace;
|
||||
|
||||
const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
|
||||
const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
|
||||
@@ -1126,19 +1114,39 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.src4 = src4;
|
||||
shader_lib_ctx.src5 = src5;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_ssm_scan_shader_decisions *>(pipeline.context.get());
|
||||
const bool xbc_overlap = decisions->xbc_overlap;
|
||||
|
||||
uint32_t offset_x = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
|
||||
uint32_t offset_B = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type));
|
||||
uint32_t offset_C = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type));
|
||||
size_t xbc_bind_offset = 0;
|
||||
size_t xbc_bind_size = 0;
|
||||
if (xbc_overlap) {
|
||||
const ggml_webgpu_merged_binding_range merged_range =
|
||||
ggml_webgpu_tensor_merged_binding_range(ctx, { src1, src4, src5 });
|
||||
xbc_bind_offset = merged_range.offset;
|
||||
xbc_bind_size = merged_range.size;
|
||||
offset_x = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
|
||||
offset_B = ggml_webgpu_tensor_merged_element_offset(src4, merged_range);
|
||||
offset_C = ggml_webgpu_tensor_merged_element_offset(src5, merged_range);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
offset_x,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)),
|
||||
offset_B,
|
||||
offset_C,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
||||
@@ -1174,11 +1182,24 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
|
||||
};
|
||||
if (xbc_overlap) {
|
||||
entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), xbc_bind_offset, xbc_bind_size));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src6));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, dst));
|
||||
} else {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst));
|
||||
}
|
||||
|
||||
const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]);
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
@@ -1389,8 +1410,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q1_0:
|
||||
use_fast = true;
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
use_fast = is_vec;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -1641,23 +1674,38 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
||||
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = Q;
|
||||
shader_lib_ctx.src1 = K;
|
||||
shader_lib_ctx.src2 = V;
|
||||
shader_lib_ctx.src3 = mask;
|
||||
shader_lib_ctx.src4 = sinks;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
|
||||
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
|
||||
const int has_mask = (mask != nullptr);
|
||||
const int has_sinks = (sinks != nullptr);
|
||||
const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;
|
||||
const bool kv_overlap = decisions->kv_overlap;
|
||||
|
||||
uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
|
||||
uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
|
||||
size_t kv_bind_offset = 0;
|
||||
size_t kv_bind_size = 0;
|
||||
if (kv_overlap) {
|
||||
const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K);
|
||||
const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V);
|
||||
const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K);
|
||||
const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V);
|
||||
kv_bind_offset = std::min(k_bind_offset, v_bind_offset);
|
||||
kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset;
|
||||
offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type));
|
||||
offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type));
|
||||
const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V });
|
||||
kv_bind_offset = merged_range.offset;
|
||||
kv_bind_size = merged_range.size;
|
||||
offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range);
|
||||
offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
@@ -1708,26 +1756,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
}
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst));
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = Q;
|
||||
shader_lib_ctx.src1 = K;
|
||||
shader_lib_ctx.src2 = V;
|
||||
shader_lib_ctx.src3 = mask;
|
||||
shader_lib_ctx.src4 = sinks;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.src_overlap = kv_overlap;
|
||||
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
|
||||
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
|
||||
|
||||
if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
|
||||
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
|
||||
@@ -1909,18 +1937,17 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool is_unary = dst->op == GGML_OP_UNARY;
|
||||
bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src;
|
||||
shader_lib_ctx.src1 = nullptr;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = inplace;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
const bool inplace = decisions->inplace;
|
||||
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
@@ -1982,41 +2009,38 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = flags.inplace;
|
||||
shader_lib_ctx.overlap = flags.overlap;
|
||||
shader_lib_ctx.src_overlap = flags.src_overlap;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
|
||||
size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
|
||||
|
||||
uint32_t offset_merged_src0 = 0;
|
||||
uint32_t offset_merged_src1 = 0;
|
||||
if (flags.src_overlap) {
|
||||
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
|
||||
offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
|
||||
offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
|
||||
uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
|
||||
uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
|
||||
size_t merged_offset = 0;
|
||||
size_t merged_size = 0;
|
||||
if (decisions->src_overlap) {
|
||||
const ggml_webgpu_merged_binding_range merged_range =
|
||||
ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
|
||||
merged_offset = merged_range.offset;
|
||||
merged_size = merged_range.size;
|
||||
offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
|
||||
offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
offset_src0,
|
||||
offset_src1,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
offset_merged_src0,
|
||||
offset_merged_src1,
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
@@ -2036,12 +2060,9 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
|
||||
if (flags.src_overlap) {
|
||||
size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
|
||||
size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
|
||||
src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
|
||||
entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset,
|
||||
merged_end - merged_offset));
|
||||
if (decisions->src_overlap) {
|
||||
entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
|
||||
} else {
|
||||
entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0),
|
||||
@@ -2050,7 +2071,7 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1),
|
||||
src1_webgpu_tensor_align_offset,
|
||||
ggml_webgpu_tensor_binding_size(ctx, src1)));
|
||||
if (!flags.inplace && !flags.overlap) {
|
||||
if (!decisions->inplace && !decisions->overlap) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
|
||||
}
|
||||
}
|
||||
@@ -2156,29 +2177,15 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
|
||||
}
|
||||
|
||||
bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
|
||||
bool inplace = ggml_webgpu_tensor_equal(rn_src, dst);
|
||||
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
|
||||
|
||||
uint32_t offset_merged_rn_src = 0;
|
||||
uint32_t offset_merged_mul_src = 0;
|
||||
size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src);
|
||||
size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src);
|
||||
|
||||
if (src_overlap) {
|
||||
size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
|
||||
offset_merged_rn_src =
|
||||
(uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type));
|
||||
offset_merged_mul_src =
|
||||
(uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type));
|
||||
}
|
||||
uint32_t offset_rn_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type));
|
||||
uint32_t offset_mul_src =
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type));
|
||||
size_t merged_offset = 0;
|
||||
size_t merged_size = 0;
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)),
|
||||
offset_merged_rn_src,
|
||||
offset_merged_mul_src,
|
||||
offset_rn_src,
|
||||
offset_mul_src,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)),
|
||||
(uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)),
|
||||
@@ -2202,16 +2209,32 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
|
||||
if (inplace || overlap) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = rn_src;
|
||||
shader_lib_ctx.src1 = mul_src;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_rms_norm_mul_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
if (decisions->src_overlap) {
|
||||
const ggml_webgpu_merged_binding_range merged_range =
|
||||
ggml_webgpu_tensor_merged_binding_range(ctx, { rn_src, mul_src });
|
||||
merged_offset = merged_range.offset;
|
||||
merged_size = merged_range.size;
|
||||
offset_rn_src = ggml_webgpu_tensor_merged_element_offset(rn_src, merged_range);
|
||||
offset_mul_src = ggml_webgpu_tensor_merged_element_offset(mul_src, merged_range);
|
||||
params[0] = offset_rn_src;
|
||||
params[1] = offset_mul_src;
|
||||
}
|
||||
|
||||
if (decisions->inplace || decisions->overlap) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
|
||||
} else if (src_overlap) {
|
||||
size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
|
||||
size_t merged_end =
|
||||
std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
|
||||
mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
|
||||
entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
|
||||
merged_end - merged_offset));
|
||||
} else if (decisions->src_overlap) {
|
||||
entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, merged_size));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
|
||||
} else {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
|
||||
@@ -2219,20 +2242,10 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
|
||||
}
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = inplace;
|
||||
shader_lib_ctx.overlap = overlap;
|
||||
shader_lib_ctx.src_overlap = src_overlap;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
@@ -2249,18 +2262,18 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor
|
||||
ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
|
||||
if (!inplace) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
|
||||
}
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = inplace;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
|
||||
if (!decisions->inplace) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
|
||||
}
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src));
|
||||
}
|
||||
|
||||
@@ -2275,14 +2288,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx,
|
||||
shader_lib_ctx.src2 = src2;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
const int has_freq_factor = (src2 != nullptr);
|
||||
const bool inplace = decisions->inplace;
|
||||
const int has_freq_factor = (src2 != nullptr);
|
||||
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
@@ -2409,14 +2421,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src;
|
||||
shader_lib_ctx.src1 = nullptr;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = inplace;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
@@ -2442,7 +2451,7 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s
|
||||
// bindgroups unchanged
|
||||
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
|
||||
|
||||
if (!inplace) {
|
||||
if (!decisions->inplace) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
|
||||
}
|
||||
|
||||
@@ -2461,17 +2470,17 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx,
|
||||
shader_lib_ctx.src2 = src2;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
const int has_mask = (src1 != nullptr);
|
||||
const int has_sink = (src2 != nullptr);
|
||||
float max_bias = ggml_get_op_params_f32(dst, 1);
|
||||
float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
|
||||
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
||||
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
const bool inplace = decisions->inplace;
|
||||
const int has_mask = (src1 != nullptr);
|
||||
const int has_sink = (src2 != nullptr);
|
||||
float max_bias = ggml_get_op_params_f32(dst, 1);
|
||||
float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
|
||||
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
||||
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
@@ -3067,7 +3076,7 @@ static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend,
|
||||
size_t size) {
|
||||
GGML_UNUSED(backend);
|
||||
auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
|
||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
|
||||
|
||||
// Write aligned portion
|
||||
buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
|
||||
@@ -3149,7 +3158,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
|
||||
<< ", " << offset << ", " << size << ")");
|
||||
|
||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
|
||||
|
||||
// This is a trick to set all bytes of a u32 to the same 1 byte value.
|
||||
uint32_t val32 = (uint32_t) value * 0x01010101;
|
||||
@@ -3168,7 +3177,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
|
||||
<< ", " << offset << ", " << size << ")");
|
||||
|
||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
|
||||
|
||||
buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
|
||||
|
||||
@@ -3200,7 +3209,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
<< ", " << offset << ", " << size << ")");
|
||||
wgpu::Device device = buf_ctx->global_ctx->device;
|
||||
|
||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
|
||||
|
||||
size_t final_size = size;
|
||||
if (size % 4 != 0) {
|
||||
@@ -3725,6 +3734,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
|
||||
|
||||
static bool ggml_webgpu_supported_qtype(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -3819,6 +3829,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -3857,6 +3868,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
||||
@@ -7,8 +7,6 @@ struct Params {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
offset_merged_src0: u32,
|
||||
offset_merged_src1: u32,
|
||||
|
||||
stride_src0_0: u32,
|
||||
stride_src0_1: u32,
|
||||
@@ -134,8 +132,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
|
||||
let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
|
||||
let src0_i = params.offset_src0 + src0_index(gid.x);
|
||||
let src1_i = params.offset_src1 + src1_index(gid.x);
|
||||
update(params.offset_dst + gid.x, src0_i, src1_i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q1_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_byte_base = (src_base + offset) * 18;
|
||||
let d = load_f16_as_f32_at_src(block_byte_base);
|
||||
for (var j: u32 = 0u; j < 4u; j++) {
|
||||
let q_packed = load_u32_at_src(block_byte_base + 2u + j * 4u);
|
||||
let dst_base128 = dst_base + offset * 128u + j * 32u;
|
||||
for (var k: u32 = 0; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
for (var bit: u32 = 0; bit < 8u; bit++) {
|
||||
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
|
||||
dst[dst_base128 + k * 8u + bit] = w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
|
||||
|
||||
@@ -61,6 +61,39 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
|
||||
#endif // INIT_SRC1_SHMEM_FLOAT
|
||||
#endif
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q1_0
|
||||
const BLOCK_SIZE = 128u;
|
||||
const BLOCK_SIZE_BYTES = 18u;
|
||||
const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let tile_m = i / TILE_K;
|
||||
let tile_k_start = i % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k_start = k_outer + tile_k_start;
|
||||
|
||||
if (global_m >= params.m) {
|
||||
break;
|
||||
}
|
||||
|
||||
let block_k = global_k_start / BLOCK_SIZE;
|
||||
let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u;
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu;
|
||||
|
||||
for (var bit = 0u; bit < NQ; bit++) {
|
||||
let global_k = global_k_start + bit;
|
||||
if (global_k < params.k) {
|
||||
shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q1_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 18u;
|
||||
|
||||
@@ -128,6 +128,38 @@ fn main(
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q1_0
|
||||
#define BLOCK_SIZE 128
|
||||
#define BLOCK_SIZE_BYTES 18
|
||||
#define THREADS_PER_BLOCK 16
|
||||
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
let thread_within_block = thread_id % THREADS_PER_BLOCK;
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
|
||||
var x_block: array<f32, ELEMS_PER_THREAD>;
|
||||
for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu;
|
||||
var row_sum = 0.0;
|
||||
for (var bit = 0u; bit < 8u; bit++) {
|
||||
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
|
||||
row_sum += w * x_block[bit];
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_0
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCK_SIZE_BYTES 18
|
||||
@@ -812,6 +844,520 @@ fn main(
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ1_S
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 50
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu;
|
||||
let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
|
||||
let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_byte = get_byte(qs_w, l);
|
||||
let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
|
||||
let gw = iq1_grid[ig / 16u];
|
||||
let bit_base = (ig % 16u) * 2u;
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let g = (gw >> (bit_base + j * 2u)) & 3u;
|
||||
let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
|
||||
row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ1_M
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 56
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
|
||||
let sc_lo = load_u32_at_src0(block_byte_base + 48u);
|
||||
let sc_hi = load_u32_at_src0(block_byte_base + 52u);
|
||||
let sc0 = sc_lo & 0xFFFFu;
|
||||
let sc1 = (sc_lo >> 16u) & 0xFFFFu;
|
||||
let sc2 = sc_hi & 0xFFFFu;
|
||||
let sc3 = (sc_hi >> 16u) & 0xFFFFu;
|
||||
let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u);
|
||||
let d = f32(bitcast<vec2<f16>>(d_bits)[0]);
|
||||
|
||||
let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u),
|
||||
select(sc0, sc1, sub_blk >= 2u),
|
||||
sub_blk < 4u);
|
||||
|
||||
let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u);
|
||||
let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu;
|
||||
let qh_lo = qh & 0xFFu;
|
||||
let qh_hi = (qh >> 8u) & 0xFFu;
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u);
|
||||
let sub_scale = (sc_u16 >> bit_off) & 0x7u;
|
||||
let dl = d * f32(2u * sub_scale + 1u);
|
||||
let qh_byte = select(qh_lo, qh_hi, l >= 2u);
|
||||
let ll2 = l % 2u;
|
||||
let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u);
|
||||
let ig = grid_idx * 8u;
|
||||
let gw = iq1_grid[ig / 16u];
|
||||
let bit_base = (ig % 16u) * 2u;
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let g = (gw >> (bit_base + j * 2u)) & 3u;
|
||||
let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
|
||||
row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ2_XXS
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 66
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
|
||||
let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
|
||||
let ls = aux_hi >> 28u;
|
||||
let db = d * (0.5 + f32(ls)) * 0.25;
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let grid_idx = (aux_lo >> (8u * l)) & 0xFFu;
|
||||
let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu;
|
||||
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
|
||||
let gw_lo = iq2xxs_grid[grid_idx * 2u];
|
||||
let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u];
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let gw = select(gw_hi, gw_lo, j < 4u);
|
||||
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
|
||||
let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
|
||||
row_sum += db * b * s * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ2_XS
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 74
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
|
||||
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
|
||||
let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
|
||||
let scales_byte = get_byte(scales_word, sub_blk % 4u);
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_word = select(qs_hi, qs_lo, l < 2u);
|
||||
let half2 = (l % 2u) * 16u;
|
||||
let qs_val = (qs_word >> half2) & 0xFFFFu;
|
||||
let grid_idx = qs_val & 0x1FFu;
|
||||
let signs_idx = (qs_val >> 9u) & 0x7Fu;
|
||||
let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
|
||||
let db = d * (0.5 + f32(sub_scale)) * 0.25;
|
||||
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
|
||||
let gw_lo = iq2xs_grid[grid_idx * 2u];
|
||||
let gw_hi = iq2xs_grid[grid_idx * 2u + 1u];
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let gw = select(gw_hi, gw_lo, j < 4u);
|
||||
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
|
||||
let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
|
||||
row_sum += db * b * s * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ2_S
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 82
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
|
||||
let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u);
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
|
||||
let qh_byte = get_byte(qh_word, sub_blk % 4u);
|
||||
let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u);
|
||||
let scales_byte = get_byte(sc_word, sub_blk % 4u);
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_byte = get_byte(qs_w, l);
|
||||
let sign_byte = get_byte(sg_w, l);
|
||||
let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u);
|
||||
let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
|
||||
let db = d * (0.5 + f32(sub_scale)) * 0.25;
|
||||
let gw_lo = iq2s_grid[grid_idx * 2u];
|
||||
let gw_hi = iq2s_grid[grid_idx * 2u + 1u];
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let gw = select(gw_hi, gw_lo, j < 4u);
|
||||
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
|
||||
let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
|
||||
row_sum += db * b * s * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ3_XXS
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 98
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
|
||||
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
|
||||
let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u);
|
||||
let ls = aux >> 28u;
|
||||
let db = d * (0.5 + f32(ls)) * 0.5;
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_word = select(qs_hi, qs_lo, l < 2u);
|
||||
let byte_pos = (l % 2u) * 2u;
|
||||
let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
|
||||
let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
|
||||
let signs_idx = (aux >> (7u * l)) & 0x7Fu;
|
||||
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
|
||||
let grid1 = iq3xxs_grid[grid_idx_0];
|
||||
let grid2 = iq3xxs_grid[grid_idx_1];
|
||||
for (var j = 0u; j < 4u; j++) {
|
||||
let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
|
||||
let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
|
||||
let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
|
||||
let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u);
|
||||
row_sum += db * b1 * s1 * x_block[ll * 8u + j];
|
||||
row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ3_S
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 110
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
|
||||
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
|
||||
let qh_byte = get_byte(qh_word, sub_blk % 4u);
|
||||
let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u);
|
||||
let sc_word = load_u32_at_src0(block_byte_base + 106u);
|
||||
let scales_byte = get_byte(sc_word, sub_blk / 2u);
|
||||
let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu;
|
||||
let db = d * (1.0 + 2.0 * f32(sub_scale));
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_word = select(qs_hi, qs_lo, l < 2u);
|
||||
let byte_pos = (l % 2u) * 2u;
|
||||
let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
|
||||
let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
|
||||
let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u);
|
||||
let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u);
|
||||
let sign_byte = get_byte(sg_w, l);
|
||||
let grid1 = iq3s_grid[grid_idx_1];
|
||||
let grid2 = iq3s_grid[grid_idx_2];
|
||||
for (var j = 0u; j < 4u; j++) {
|
||||
let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
|
||||
let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
|
||||
let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
|
||||
let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u);
|
||||
row_sum += db * b1 * s1 * x_block[ll * 8u + j];
|
||||
row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ4_NL
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCK_SIZE_BYTES 18
|
||||
#define THREADS_PER_BLOCK 4
|
||||
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
let thread_within_block = thread_id % THREADS_PER_BLOCK;
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u;
|
||||
var x_block: array<f32, ELEMS_PER_THREAD>;
|
||||
for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
x_block[i + 4u] = f32(src1[x_base + i + 16u]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
var row_sum = 0.0;
|
||||
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block);
|
||||
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
|
||||
let q_byte = get_byte(q_packed, byte_idx);
|
||||
let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d;
|
||||
let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d;
|
||||
row_sum += q_lo * x_block[byte_idx];
|
||||
row_sum += q_hi * x_block[byte_idx + 4u];
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ4_XS
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 136
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let y_offset = sub_blk * 32u + half * 16u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let scales_h = load_u16_at_src0(block_byte_base + 2u);
|
||||
let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
|
||||
let sl_byte = get_byte(scales_l_word, sub_blk / 2u);
|
||||
let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu;
|
||||
let sh_bits = (scales_h >> (2u * sub_blk)) & 3u;
|
||||
let ls = i32(sl | (sh_bits << 4u));
|
||||
let dl = d * f32(ls - 32);
|
||||
|
||||
let qs_byte_off = 8u + sub_blk * 16u;
|
||||
let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off);
|
||||
let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u);
|
||||
let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u);
|
||||
let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u);
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
let q_word = select(
|
||||
select(q_w0, q_w1, i >= 4u),
|
||||
select(q_w2, q_w3, i >= 12u),
|
||||
i >= 8u);
|
||||
let q_byte = get_byte(q_word, i % 4u);
|
||||
let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u);
|
||||
row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i];
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
|
||||
@@ -66,8 +66,6 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32)
|
||||
struct Params {
|
||||
offset_rn_src: u32,
|
||||
offset_mul_src: u32,
|
||||
offset_merged_rn_src: u32,
|
||||
offset_merged_mul_src: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
stride_rn_src1: u32,
|
||||
@@ -107,8 +105,8 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
||||
i = i % (params.ne2 * params.ne1);
|
||||
let i2 = i / params.ne1;
|
||||
let i1 = i % params.ne1;
|
||||
let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
|
||||
let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
|
||||
let i_rn_src_row = params.offset_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
|
||||
let i_mul_src_row = params.offset_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
|
||||
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
|
||||
|
||||
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
|
||||
|
||||
@@ -45,6 +45,14 @@ struct Params {
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> s_in: array<f32>;
|
||||
#ifdef XBC_OVERLAP
|
||||
@group(0) @binding(1) var<storage, read_write> x_B_C_merged: array<f32>;
|
||||
@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
|
||||
@group(0) @binding(3) var<storage, read_write> A: array<f32>;
|
||||
@group(0) @binding(4) var<storage, read_write> ids: array<i32>;
|
||||
@group(0) @binding(5) var<storage, read_write> dst: array<f32>;
|
||||
@group(0) @binding(6) var<uniform> params: Params;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> x: array<f32>;
|
||||
@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
|
||||
@group(0) @binding(3) var<storage, read_write> A: array<f32>;
|
||||
@@ -53,6 +61,7 @@ struct Params {
|
||||
@group(0) @binding(6) var<storage, read_write> ids: array<i32>;
|
||||
@group(0) @binding(7) var<storage, read_write> dst: array<f32>;
|
||||
@group(0) @binding(8) var<uniform> params: Params;
|
||||
#endif
|
||||
|
||||
var<workgroup> shared_x_dt: array<f32, TOKENS_PER_TILE>;
|
||||
var<workgroup> shared_dtsp: array<f32, TOKENS_PER_TILE>;
|
||||
@@ -98,7 +107,11 @@ fn main(
|
||||
let dt0 = dt[dt_idx];
|
||||
let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0);
|
||||
shared_dtsp[tid] = dtsp;
|
||||
#ifdef XBC_OVERLAP
|
||||
shared_x_dt[tid] = x_B_C_merged[x_idx] * dtsp;
|
||||
#else
|
||||
shared_x_dt[tid] = x[x_idx] * dtsp;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,16 +129,28 @@ fn main(
|
||||
|
||||
let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3;
|
||||
let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3;
|
||||
#ifdef XBC_OVERLAP
|
||||
let s = s_prev * dA + x_B_C_merged[b_idx] * x_dt;
|
||||
#else
|
||||
let s = s_prev * dA + B[b_idx] * x_dt;
|
||||
#endif
|
||||
s_prev = s;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
#ifdef XBC_OVERLAP
|
||||
let subgroup_partial = subgroupAdd(s * x_B_C_merged[c_idx]);
|
||||
#else
|
||||
let subgroup_partial = subgroupAdd(s * C[c_idx]);
|
||||
#endif
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial;
|
||||
}
|
||||
#else
|
||||
#ifdef XBC_OVERLAP
|
||||
shared_reduce[reduce_idx] = s * x_B_C_merged[c_idx];
|
||||
#else
|
||||
shared_reduce[reduce_idx] = s * C[c_idx];
|
||||
#endif
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -72,9 +72,6 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
if (model.layers[il].wo_s) {
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
|
||||
}
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
|
||||
@@ -58,9 +58,6 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
if (model.layers[il].wo_s) {
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
|
||||
}
|
||||
}
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
||||
@@ -58,9 +58,6 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
if (model.layers[il].wo_s) {
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
|
||||
}
|
||||
}
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
||||
@@ -40,8 +40,12 @@ int main(void) {
|
||||
}
|
||||
}
|
||||
|
||||
// exclude spec args from this check
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/22397
|
||||
const bool skip = opt.is_spec;
|
||||
|
||||
// ensure shorter argument precedes longer argument
|
||||
if (opt.args.size() > 1) {
|
||||
if (!skip && opt.args.size() > 1) {
|
||||
const std::string first(opt.args.front());
|
||||
const std::string last(opt.args.back());
|
||||
|
||||
@@ -124,9 +128,9 @@ int main(void) {
|
||||
assert(params.n_batch == 9090);
|
||||
|
||||
// --draft cannot be used outside llama-speculative
|
||||
argv = {"binary_name", "--draft", "123"};
|
||||
argv = {"binary_name", "--spec-draft-n-max", "123"};
|
||||
assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE));
|
||||
assert(params.speculative.n_max == 123);
|
||||
assert(params.speculative.draft.n_max == 123);
|
||||
|
||||
// multi-value args (CSV)
|
||||
argv = {"binary_name", "--lora", "file1.gguf,\"file2,2.gguf\",\"file3\"\"3\"\".gguf\",file4\".gguf"};
|
||||
|
||||
@@ -2984,7 +2984,7 @@ struct test_bin_bcast : public test_case {
|
||||
bool run_whole_graph() override { return nf > 1; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, ne, nr, nf, perm1);
|
||||
return VARS_TO_STR6(type, ne, nr, nf, perm1, src_overlap);
|
||||
}
|
||||
|
||||
size_t op_size(ggml_tensor * t) override {
|
||||
@@ -3589,9 +3589,10 @@ struct test_ssm_scan : public test_case {
|
||||
const int64_t n_group;
|
||||
const int64_t n_seq_tokens;
|
||||
const int64_t n_seqs;
|
||||
const bool xbc_overlap;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
|
||||
return VARS_TO_STR8(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs, xbc_overlap);
|
||||
}
|
||||
|
||||
test_ssm_scan(ggml_type type = GGML_TYPE_F32,
|
||||
@@ -3600,16 +3601,31 @@ struct test_ssm_scan : public test_case {
|
||||
int64_t n_head = 32,
|
||||
int64_t n_group = 1,
|
||||
int64_t n_seq_tokens = 32,
|
||||
int64_t n_seqs = 32)
|
||||
: type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
||||
int64_t n_seqs = 32,
|
||||
bool xbc_overlap = false)
|
||||
: type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), xbc_overlap(xbc_overlap) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
|
||||
ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
|
||||
ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * x;
|
||||
ggml_tensor * B;
|
||||
ggml_tensor * C;
|
||||
|
||||
if (xbc_overlap) {
|
||||
ggml_tensor * xbc = ggml_new_tensor_4d(ctx, type, d_state, n_head, n_seq_tokens, 2 * n_seqs);
|
||||
x = ggml_view_4d(ctx, xbc, head_dim, n_head, n_seq_tokens, n_seqs,
|
||||
xbc->nb[1], xbc->nb[2], xbc->nb[3], xbc->nb[3]);
|
||||
B = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
|
||||
xbc->nb[1], xbc->nb[2], xbc->nb[3], 0);
|
||||
C = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
|
||||
xbc->nb[1], xbc->nb[2], xbc->nb[3], 2 * xbc->nb[3]);
|
||||
} else {
|
||||
x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
|
||||
B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
|
||||
C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
|
||||
}
|
||||
ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
|
||||
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
|
||||
return out;
|
||||
@@ -3799,7 +3815,7 @@ struct test_mul_mat : public test_case {
|
||||
|
||||
double max_nmse_err(ggml_backend_t backend) override {
|
||||
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
|
||||
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
|
||||
if ((type_a == GGML_TYPE_MXFP4 || type_a == GGML_TYPE_NVFP4) && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
|
||||
return 2e-2;
|
||||
}
|
||||
return max_nmse_err();
|
||||
@@ -3935,7 +3951,7 @@ struct test_mul_mat_id : public test_case {
|
||||
|
||||
double max_nmse_err(ggml_backend_t backend) override {
|
||||
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
|
||||
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
|
||||
if ((type_a == GGML_TYPE_MXFP4 || type_a == GGML_TYPE_NVFP4) && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
|
||||
return 2e-2;
|
||||
}
|
||||
return max_nmse_err();
|
||||
@@ -7964,6 +7980,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 128, 4, 4, 16, 2, true)); // x/B/C overlap
|
||||
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
|
||||
@@ -2249,6 +2249,46 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
{
|
||||
// additional tests for https://github.com/ggml-org/llama.cpp/pull/21760
|
||||
auto tmpls = read_templates("models/templates/google-gemma-4-31B-it.jinja");
|
||||
|
||||
common_chat_msg tool_call_msg = simple_assist_msg(
|
||||
"Let me check.", "", "special_function", "{\"arg1\": 1}","c0");
|
||||
|
||||
common_chat_msg tool_msg;
|
||||
tool_msg.role = "tool";
|
||||
tool_msg.tool_name = "special_function";
|
||||
tool_msg.tool_call_id = "c0";
|
||||
tool_msg.content = "{\"r\":\"ok\"}";
|
||||
|
||||
{
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = { message_user, tool_call_msg, tool_msg };
|
||||
inputs.tools = { special_function_tool };
|
||||
inputs.add_generation_prompt = true;
|
||||
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs);
|
||||
|
||||
if (!string_ends_with(params.prompt, "<turn|>\n<|turn>model\n")) {
|
||||
throw std::runtime_error("Missing generation prompt for Gemma 4");
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = { message_user, tool_call_msg, tool_msg };
|
||||
inputs.tools = { special_function_tool };
|
||||
inputs.add_generation_prompt = false;
|
||||
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs);
|
||||
|
||||
if (string_ends_with(params.prompt, "<|turn>model\n")) {
|
||||
throw std::runtime_error("Gemma 4: generation prompt was modified despite add_generation_prompt=false");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
@@ -35,5 +35,9 @@ int main() {
|
||||
threads[i].join();
|
||||
}
|
||||
|
||||
common_log_flush(common_log_main());
|
||||
// We explicitly free the logger singleton to avoid hanging on Windows
|
||||
// related to timing issues of thread startup and DLL teardown
|
||||
common_log_free(common_log_main());
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -227,7 +227,30 @@ int main(void) {
|
||||
3); // forcing continues through i=3
|
||||
}
|
||||
|
||||
printf("OK (5 tests passed)\n");
|
||||
// Test 6: Multi-block thinking. First block ends naturally at i=2, second
|
||||
// start tag at i=3 re-arms the budget, which then exhausts at i=5.
|
||||
// Regression: before this fix, DONE absorbed all subsequent tokens and a
|
||||
// second <think> block ran unbudgeted.
|
||||
// Flow: i=0 accept(100)->COUNTING rem=2; i=1 accept(50)->rem=1;
|
||||
// i=2 accept(101)->end_matcher matches, DONE;
|
||||
// i=3 accept(100)->re-arm, COUNTING rem=2;
|
||||
// i=4 accept(60)->rem=1; i=5 accept(61)->rem=0->FORCING;
|
||||
// i=6 apply()->forces token[0]=102, accept(62)->force_pos=1, stay FORCING;
|
||||
// i=7 apply()->forces token[1]=101, accept(63)->force_pos=2->DONE.
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
const std::vector<llama_token> sequence = {100, 50, 101, 100, 60, 61, 62, 63};
|
||||
|
||||
test_reasoning_budget("multi-block re-arms budget after DONE", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens (per block)
|
||||
REASONING_BUDGET_IDLE,
|
||||
6, // forcing starts at i=6 (after second block exhausts at i=5)
|
||||
7); // forcing continues through i=7
|
||||
}
|
||||
|
||||
printf("OK (6 tests passed)\n");
|
||||
|
||||
printf("Testing UTF-8 boundary detection... ");
|
||||
test_utf8_boundary_detection();
|
||||
|
||||
@@ -372,7 +372,7 @@ static const cmd_params cmd_params_defaults = {
|
||||
/* n_ubatch */ { 512 },
|
||||
/* type_k */ { GGML_TYPE_F16 },
|
||||
/* type_v */ { GGML_TYPE_F16 },
|
||||
/* n_threads */ { cpu_get_num_math() },
|
||||
/* n_threads */ { common_cpu_get_num_math() },
|
||||
/* cpu_mask */ { "0x0" },
|
||||
/* cpu_strict */ { false },
|
||||
/* poll */ { 50 },
|
||||
|
||||
@@ -317,7 +317,7 @@ int main(int argc, char * argv[]) {
|
||||
const char * cache_dir = nullptr;
|
||||
std::string cache_dir_str;
|
||||
if (params.use_cache) {
|
||||
cache_dir_str = fs_get_cache_directory() + "rpc/";
|
||||
cache_dir_str = fs_get_cache_directory() + "rpc" + DIRECTORY_SEPARATOR;
|
||||
if (!fs_create_directory_with_parents(cache_dir_str)) {
|
||||
fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
|
||||
return 1;
|
||||
|
||||
File diff suppressed because one or more lines are too long
+5618
-5478
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -575,14 +575,14 @@ json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & inp_body,
|
||||
const common_chat_templates * tmpls,
|
||||
const std::map<std::string, raw_buffer> & in_files,
|
||||
const std::map<std::string, uploaded_file> & in_files,
|
||||
std::vector<raw_buffer> & out_files) {
|
||||
// TODO @ngxson : this function may need to be improved in the future
|
||||
// handle input files
|
||||
out_files.clear();
|
||||
auto it = in_files.find("file");
|
||||
if (it != in_files.end()) {
|
||||
out_files.push_back(it->second);
|
||||
out_files.push_back(it->second.data);
|
||||
} else {
|
||||
throw std::invalid_argument("No input file found for transcription");
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "chat.h"
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
@@ -19,7 +20,7 @@ json server_chat_convert_anthropic_to_oai(const json & body);
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & body,
|
||||
const common_chat_templates * tmpls,
|
||||
const std::map<std::string, raw_buffer> & in_files,
|
||||
const std::map<std::string, uploaded_file> & in_files,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
||||
@@ -309,8 +309,10 @@ struct server_slot {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative);
|
||||
|
||||
// determine the max draft that fits the current slot state
|
||||
int n_draft_max = task->params.speculative.n_max;
|
||||
int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative);
|
||||
|
||||
// note: slot.prompt is not yet expanded with the `id` token sampled above
|
||||
// also, need to leave space for 1 extra token to allow context shifts
|
||||
@@ -322,8 +324,8 @@ struct server_slot {
|
||||
|
||||
SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
|
||||
|
||||
if (n_draft_max < task->params.speculative.n_min) {
|
||||
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
|
||||
if (n_draft_max < n_draft_min) {
|
||||
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min);
|
||||
n_draft_max = 0;
|
||||
}
|
||||
|
||||
@@ -358,11 +360,6 @@ struct server_slot {
|
||||
spec_draft.resize(n_draft_max);
|
||||
}
|
||||
|
||||
if (spec_draft.size() < (size_t) params_spec.n_min) {
|
||||
SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min);
|
||||
spec_draft.clear();
|
||||
}
|
||||
|
||||
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
const auto n_tokens = prompt.tokens.size();
|
||||
|
||||
@@ -770,9 +767,9 @@ private:
|
||||
|
||||
if (params_base.speculative.has_dft()) {
|
||||
// TODO speculative: move to common/speculative.cpp?
|
||||
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
|
||||
const auto & params_spec = params_base.speculative.draft;
|
||||
|
||||
const auto & params_spec = params_base.speculative;
|
||||
SRV_INF("loading draft model '%s'\n", params_spec.mparams.path.c_str());
|
||||
|
||||
auto params_dft = params_base;
|
||||
|
||||
@@ -780,7 +777,7 @@ private:
|
||||
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
|
||||
params_dft.n_batch = llama_n_ctx_seq(ctx);
|
||||
params_dft.devices = params_spec.devices;
|
||||
params_dft.model = params_spec.mparams_dft;
|
||||
params_dft.model = params_spec.mparams;
|
||||
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
|
||||
params_dft.cache_type_k = params_spec.cache_type_k;
|
||||
params_dft.cache_type_v = params_spec.cache_type_v;
|
||||
@@ -800,8 +797,8 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
params_base.speculative.model_dft = model_dft.get();
|
||||
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
|
||||
params_base.speculative.draft.model = model_dft.get();
|
||||
params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft);
|
||||
}
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
@@ -1310,7 +1307,7 @@ private:
|
||||
backend_sampling &= task.params.sampling.backend_sampling;
|
||||
|
||||
// TODO: speculative decoding requires multiple samples per batch - not supported yet
|
||||
backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0);
|
||||
backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
|
||||
|
||||
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
|
||||
backend_sampling &= !need_logits;
|
||||
@@ -3031,7 +3028,7 @@ private:
|
||||
slot.sampled = ids.back(); // last accepted token
|
||||
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.n_tokens(), -1);
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
completion_token_output result;
|
||||
|
||||
@@ -49,6 +49,7 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
|
||||
parsed_url.path,
|
||||
headers,
|
||||
req.body,
|
||||
req.files,
|
||||
req.should_stop,
|
||||
600, // timeout_read (default to 10 minutes)
|
||||
600 // timeout_write (default to 10 minutes)
|
||||
|
||||
@@ -438,7 +438,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
|
||||
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
std::string body = req.body;
|
||||
std::map<std::string, raw_buffer> files;
|
||||
std::map<std::string, uploaded_file> files;
|
||||
|
||||
if (req.is_multipart_form_data()) {
|
||||
// translate text fields to a JSON object and use it as the body
|
||||
@@ -459,7 +459,11 @@ void server_http_context::post(const std::string & path, const server_http_conte
|
||||
|
||||
// populate files from multipart form
|
||||
for (const auto & [key, file] : req.form.files) {
|
||||
files[key] = raw_buffer(file.content.begin(), file.content.end());
|
||||
files[key] = uploaded_file{
|
||||
raw_buffer(file.content.begin(), file.content.end()),
|
||||
file.filename,
|
||||
file.content_type,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,13 +36,19 @@ struct server_http_res {
|
||||
using server_http_res_ptr = std::unique_ptr<server_http_res>;
|
||||
using raw_buffer = std::vector<uint8_t>;
|
||||
|
||||
struct uploaded_file {
|
||||
raw_buffer data;
|
||||
std::string filename;
|
||||
std::string content_type;
|
||||
};
|
||||
|
||||
struct server_http_req {
|
||||
std::map<std::string, std::string> params; // path_params + query_params
|
||||
std::map<std::string, std::string> headers; // used by MCP proxy
|
||||
std::string path;
|
||||
std::string query_string; // query parameters string (e.g. "action=save")
|
||||
std::string body;
|
||||
std::map<std::string, raw_buffer> files; // used for file uploads (form data)
|
||||
std::map<std::string, uploaded_file> files; // used for file uploads (form data)
|
||||
const std::function<bool()> & should_stop;
|
||||
|
||||
std::string get_param(const std::string & key, const std::string & def = "") const {
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
#include <chrono>
|
||||
#include <queue>
|
||||
#include <filesystem>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <cstring>
|
||||
|
||||
#ifdef _WIN32
|
||||
@@ -823,6 +825,7 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
|
||||
proxy_path,
|
||||
req.headers,
|
||||
req.body,
|
||||
req.files,
|
||||
req.should_stop,
|
||||
base_params.timeout_read,
|
||||
base_params.timeout_write
|
||||
@@ -1126,6 +1129,77 @@ static bool should_strip_proxy_header(const std::string & header_name) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::string generate_multipart_boundary() {
|
||||
thread_local std::mt19937 gen(std::random_device{}());
|
||||
static const char chars[] = "0123456789abcdefghijklmnopqrstuvwxyz";
|
||||
std::uniform_int_distribution<> dis(0, sizeof(chars) - 2);
|
||||
std::string boundary = "----llama-cpp-proxy-";
|
||||
for (int i = 0; i < 16; i++) {
|
||||
boundary += chars[dis(gen)];
|
||||
}
|
||||
return boundary;
|
||||
}
|
||||
|
||||
static std::string build_multipart_body(
|
||||
const json & form_fields,
|
||||
const std::map<std::string, uploaded_file> & files,
|
||||
const std::string & boundary) {
|
||||
static auto sanitize_field = [](const std::string & text) {
|
||||
std::string result;
|
||||
result.reserve(text.size());
|
||||
for (char c : text) {
|
||||
if (c != '\n' && c != '\r' && c != '"') {
|
||||
result += c;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
std::ostringstream body;
|
||||
|
||||
for (const auto & [key, value] : form_fields.items()) {
|
||||
if (value.is_array()) {
|
||||
for (const auto & item : value) {
|
||||
body << "--" << boundary << "\r\n";
|
||||
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
|
||||
body << "\r\n";
|
||||
if (!item.is_string()) {
|
||||
throw std::invalid_argument("expected string");
|
||||
}
|
||||
body << item.get<std::string>() << "\r\n";
|
||||
}
|
||||
} else {
|
||||
body << "--" << boundary << "\r\n";
|
||||
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
|
||||
body << "\r\n";
|
||||
if (!value.is_string()) {
|
||||
throw std::invalid_argument("expected string");
|
||||
}
|
||||
body << value.get<std::string>() << "\r\n";
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & [key, file] : files) {
|
||||
body << "--" << boundary << "\r\n";
|
||||
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"";
|
||||
if (!file.filename.empty()) {
|
||||
body << "; filename=\"" << sanitize_field(file.filename) << "\"";
|
||||
}
|
||||
body << "\r\n";
|
||||
if (!file.content_type.empty()) {
|
||||
body << "Content-Type: " << sanitize_field(file.content_type) << "\r\n";
|
||||
} else {
|
||||
body << "Content-Type: application/octet-stream\r\n";
|
||||
}
|
||||
body << "\r\n";
|
||||
body.write(reinterpret_cast<const char*>(file.data.data()), file.data.size());
|
||||
body << "\r\n";
|
||||
}
|
||||
|
||||
body << "--" << boundary << "--\r\n";
|
||||
return body.str();
|
||||
}
|
||||
|
||||
server_http_proxy::server_http_proxy(
|
||||
const std::string & method,
|
||||
const std::string & scheme,
|
||||
@@ -1134,6 +1208,7 @@ server_http_proxy::server_http_proxy(
|
||||
const std::string & path,
|
||||
const std::map<std::string, std::string> & headers,
|
||||
const std::string & body,
|
||||
const std::map<std::string, uploaded_file> & files,
|
||||
const std::function<bool()> should_stop,
|
||||
int32_t timeout_read,
|
||||
int32_t timeout_write
|
||||
@@ -1195,28 +1270,65 @@ server_http_proxy::server_http_proxy(
|
||||
return pipe->write({{}, 0, std::string(data, data_length), ""});
|
||||
};
|
||||
|
||||
// when files are present, the body was converted from multipart form data to JSON
|
||||
// we need to reconstruct the multipart body for the downstream server
|
||||
std::string effective_body = body;
|
||||
std::string override_content_type;
|
||||
bool has_files = !files.empty();
|
||||
|
||||
if (has_files) {
|
||||
json form_fields = json::parse(body, nullptr, false);
|
||||
if (!form_fields.is_discarded()) {
|
||||
auto boundary = generate_multipart_boundary();
|
||||
effective_body = build_multipart_body(form_fields, files, boundary);
|
||||
override_content_type = "multipart/form-data; boundary=" + boundary;
|
||||
} else {
|
||||
throw std::runtime_error("failed to parse multipart form fields JSON");
|
||||
}
|
||||
}
|
||||
|
||||
// prepare the request to destination server
|
||||
httplib::Request req;
|
||||
{
|
||||
req.method = method;
|
||||
req.path = path;
|
||||
for (const auto & [key, value] : headers) {
|
||||
if (key == "Accept-Encoding") {
|
||||
const auto lowered = to_lower_copy(key);
|
||||
if (lowered == "accept-encoding") {
|
||||
// disable Accept-Encoding to avoid compressed responses
|
||||
continue;
|
||||
}
|
||||
if (key == "Transfer-Encoding") {
|
||||
if (lowered == "transfer-encoding") {
|
||||
// the body is already decoded
|
||||
continue;
|
||||
}
|
||||
if (key == "Host" || key == "host") {
|
||||
if (lowered == "content-length") {
|
||||
// let httplib calculate Content-Length from the actual body
|
||||
continue;
|
||||
}
|
||||
if (lowered == "content-type") {
|
||||
if (has_files) {
|
||||
// we set our own Content-Type with the new boundary
|
||||
continue;
|
||||
}
|
||||
// when no files but the original request was multipart,
|
||||
// the body is now JSON, so correct the Content-Type
|
||||
if (value.find("multipart/form-data") != std::string::npos) {
|
||||
override_content_type = "application/json; charset=utf-8";
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (lowered == "host") {
|
||||
bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
|
||||
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
|
||||
} else {
|
||||
req.set_header(key, value);
|
||||
}
|
||||
}
|
||||
req.body = body;
|
||||
req.body = effective_body;
|
||||
if (!override_content_type.empty()) {
|
||||
req.set_header("Content-Type", override_content_type);
|
||||
}
|
||||
req.response_handler = response_handler;
|
||||
req.content_receiver = content_receiver;
|
||||
}
|
||||
|
||||
@@ -202,6 +202,7 @@ public:
|
||||
const std::string & path,
|
||||
const std::map<std::string, std::string> & headers,
|
||||
const std::string & body,
|
||||
const std::map<std::string, uploaded_file> & files,
|
||||
const std::function<bool()> should_stop,
|
||||
int32_t timeout_read,
|
||||
int32_t timeout_write
|
||||
|
||||
@@ -76,13 +76,7 @@ json task_params::to_json(bool only_metrics) const {
|
||||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
{"speculative.p_min", speculative.p_min},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.ngram_size_n", speculative.ngram_size_n},
|
||||
{"speculative.ngram_size_m", speculative.ngram_size_m},
|
||||
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
@@ -139,13 +133,7 @@ json task_params::to_json(bool only_metrics) const {
|
||||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
{"speculative.p_min", speculative.p_min},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.ngram_size_n", speculative.ngram_size_n},
|
||||
{"speculative.ngram_size_m", speculative.ngram_size_m},
|
||||
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
@@ -308,14 +296,17 @@ task_params server_task::params_from_json_cmpl(
|
||||
|
||||
params.speculative = defaults.speculative;
|
||||
|
||||
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
|
||||
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
|
||||
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
||||
// TODO: for now, be able to adjust only the draft-model based speculative parameters
|
||||
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
|
||||
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
|
||||
params.speculative.draft.p_min = json_value(data, "speculative.p_min", defaults.speculative.draft.p_min);
|
||||
|
||||
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
|
||||
params.speculative.n_min = std::max(params.speculative.n_min, 0);
|
||||
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
||||
params.speculative.draft.n_min = std::min(params.speculative.draft.n_max, params.speculative.draft.n_min);
|
||||
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
|
||||
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
|
||||
|
||||
#if 0
|
||||
// for debugging and research purposes
|
||||
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
|
||||
|
||||
params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
|
||||
@@ -325,6 +316,7 @@ task_params server_task::params_from_json_cmpl(
|
||||
params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024);
|
||||
params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024);
|
||||
params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024);
|
||||
#endif
|
||||
|
||||
// Use OpenAI API logprobs only if n_probs wasn't provided
|
||||
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
|
||||
|
||||
@@ -83,15 +83,14 @@ class ServerProcess:
|
||||
kv_unified: bool | None = False
|
||||
server_slots: bool | None = False
|
||||
pooling: str | None = None
|
||||
draft: int | None = None
|
||||
api_key: str | None = None
|
||||
models_dir: str | None = None
|
||||
models_max: int | None = None
|
||||
no_models_autoload: bool | None = None
|
||||
lora_files: List[str] | None = None
|
||||
enable_ctx_shift: int | None = False
|
||||
draft_min: int | None = None
|
||||
draft_max: int | None = None
|
||||
spec_draft_n_min: int | None = None
|
||||
spec_draft_n_max: int | None = None
|
||||
no_webui: bool | None = None
|
||||
jinja: bool | None = None
|
||||
reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
|
||||
@@ -165,8 +164,6 @@ class ServerProcess:
|
||||
server_args.extend(["--threads", self.n_threads])
|
||||
if self.n_gpu_layer:
|
||||
server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
|
||||
if self.draft is not None:
|
||||
server_args.extend(["--draft", self.draft])
|
||||
if self.server_continuous_batching:
|
||||
server_args.append("--cont-batching")
|
||||
if self.server_embeddings:
|
||||
@@ -214,10 +211,10 @@ class ServerProcess:
|
||||
server_args.append("--context-shift")
|
||||
if self.api_key:
|
||||
server_args.extend(["--api-key", self.api_key])
|
||||
if self.draft_max:
|
||||
server_args.extend(["--draft-max", self.draft_max])
|
||||
if self.draft_min:
|
||||
server_args.extend(["--draft-min", self.draft_min])
|
||||
if self.spec_draft_n_max:
|
||||
server_args.extend(["--spec-draft-n-max", self.spec_draft_n_max])
|
||||
if self.spec_draft_n_min:
|
||||
server_args.extend(["--spec-draft-n-min", self.spec_draft_n_min])
|
||||
if self.no_webui:
|
||||
server_args.append("--no-webui")
|
||||
if self.no_models_autoload:
|
||||
|
||||
Generated
+392
-500
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,7 @@
|
||||
--chart-3: oklch(0.398 0.07 227.392);
|
||||
--chart-4: oklch(0.828 0.189 84.429);
|
||||
--chart-5: oklch(0.769 0.188 70.08);
|
||||
--sidebar: oklch(0.987 0 0);
|
||||
--sidebar: oklch(0.985 0 0);
|
||||
--sidebar-foreground: oklch(0.145 0 0);
|
||||
--sidebar-primary: oklch(0.205 0 0);
|
||||
--sidebar-primary-foreground: oklch(0.985 0 0);
|
||||
@@ -77,7 +77,7 @@
|
||||
--chart-3: oklch(0.769 0.188 70.08);
|
||||
--chart-4: oklch(0.627 0.265 303.9);
|
||||
--chart-5: oklch(0.645 0.246 16.439);
|
||||
--sidebar: oklch(0.19 0 0);
|
||||
--sidebar: oklch(0.2 0 0);
|
||||
--sidebar-foreground: oklch(0.985 0 0);
|
||||
--sidebar-primary: oklch(0.488 0.243 264.376);
|
||||
--sidebar-primary-foreground: oklch(0.985 0 0);
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Button, type ButtonVariant, type ButtonSize } from '$lib/components/ui/button';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import type { Component } from 'svelte';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
icon: Component;
|
||||
tooltip: string;
|
||||
variant?: 'default' | 'destructive' | 'outline' | 'secondary' | 'ghost' | 'link';
|
||||
size?: 'default' | 'sm' | 'lg' | 'icon';
|
||||
variant?: ButtonVariant;
|
||||
size?: ButtonSize;
|
||||
iconSize?: string;
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
onclick: (e?: MouseEvent) => void;
|
||||
'aria-label'?: string;
|
||||
tooltipSide?: TooltipSide;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -23,6 +25,7 @@
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
iconSize = 'h-3 w-3',
|
||||
tooltipSide = TooltipSide.TOP,
|
||||
onclick,
|
||||
'aria-label': ariaLabel
|
||||
}: Props = $props();
|
||||
@@ -35,7 +38,7 @@
|
||||
{size}
|
||||
{disabled}
|
||||
{onclick}
|
||||
class="h-6 w-6 p-0 {className} flex"
|
||||
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{@const IconComponent = icon}
|
||||
@@ -44,7 +47,7 @@
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<Tooltip.Content side={tooltipSide}>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
|
||||
@@ -300,7 +300,7 @@
|
||||
if (sendOnEnter || isModifier) {
|
||||
event.preventDefault();
|
||||
|
||||
if (!canSubmit || disabled || isLoading || hasLoadingAttachments) return;
|
||||
if (!canSubmit || disabled || hasLoadingAttachments) return;
|
||||
|
||||
onSubmit?.();
|
||||
}
|
||||
@@ -555,7 +555,7 @@
|
||||
class="relative {className}"
|
||||
onsubmit={(e) => {
|
||||
e.preventDefault();
|
||||
if (!canSubmit || disabled || isLoading || hasLoadingAttachments) return;
|
||||
if (!canSubmit || disabled || hasLoadingAttachments) return;
|
||||
onSubmit?.();
|
||||
}}
|
||||
>
|
||||
|
||||
-333
@@ -1,333 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { page } from '$app/state';
|
||||
import { Plus, MessageSquare, Settings, Zap, FolderOpen } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { Switch } from '$lib/components/ui/switch';
|
||||
import { FILE_TYPE_ICONS, TOOLTIP_DELAY_DURATION } from '$lib/constants';
|
||||
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import type { MCPServerSettingsEntry } from '$lib/types';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
hasMcpPromptsSupport?: boolean;
|
||||
hasMcpResourcesSupport?: boolean;
|
||||
onFileUpload?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
onMcpPromptClick?: () => void;
|
||||
onMcpSettingsClick?: () => void;
|
||||
onMcpResourcesClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVisionModality = false,
|
||||
hasMcpPromptsSupport = false,
|
||||
hasMcpResourcesSupport = false,
|
||||
onFileUpload,
|
||||
onSystemPromptClick,
|
||||
onMcpPromptClick,
|
||||
onMcpSettingsClick,
|
||||
onMcpResourcesClick
|
||||
}: Props = $props();
|
||||
|
||||
let isNewChat = $derived(!page.params.id);
|
||||
|
||||
let systemMessageTooltip = $derived(
|
||||
isNewChat
|
||||
? 'Add custom system message for a new conversation'
|
||||
: 'Inject custom system message at the beginning of the conversation'
|
||||
);
|
||||
|
||||
let dropdownOpen = $state(false);
|
||||
|
||||
let mcpServers = $derived(mcpStore.getServersSorted().filter((s) => s.enabled));
|
||||
let hasMcpServers = $derived(mcpServers.length > 0);
|
||||
let mcpSearchQuery = $state('');
|
||||
let filteredMcpServers = $derived.by(() => {
|
||||
const query = mcpSearchQuery.toLowerCase().trim();
|
||||
if (!query) return mcpServers;
|
||||
return mcpServers.filter((s) => {
|
||||
const name = getServerLabel(s).toLowerCase();
|
||||
const url = s.url.toLowerCase();
|
||||
return name.includes(query) || url.includes(query);
|
||||
});
|
||||
});
|
||||
|
||||
function getServerLabel(server: MCPServerSettingsEntry): string {
|
||||
return mcpStore.getServerLabel(server);
|
||||
}
|
||||
|
||||
function isServerEnabledForChat(serverId: string): boolean {
|
||||
return conversationsStore.isMcpServerEnabledForChat(serverId);
|
||||
}
|
||||
|
||||
async function toggleServerForChat(serverId: string) {
|
||||
await conversationsStore.toggleMcpServerForChat(serverId);
|
||||
}
|
||||
|
||||
function handleMcpSubMenuOpen(open: boolean) {
|
||||
if (open) {
|
||||
mcpSearchQuery = '';
|
||||
mcpStore.runHealthChecksForServers(mcpServers);
|
||||
}
|
||||
}
|
||||
|
||||
function handleMcpPromptClick() {
|
||||
dropdownOpen = false;
|
||||
onMcpPromptClick?.();
|
||||
}
|
||||
|
||||
function handleMcpSettingsClick() {
|
||||
dropdownOpen = false;
|
||||
onMcpSettingsClick?.();
|
||||
}
|
||||
|
||||
function handleMcpResourcesClick() {
|
||||
dropdownOpen = false;
|
||||
onMcpResourcesClick?.();
|
||||
}
|
||||
|
||||
const fileUploadTooltipText = 'Add files, system prompt or MCP Servers';
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<DropdownMenu.Root bind:open={dropdownOpen}>
|
||||
<DropdownMenu.Trigger name="Attach files" {disabled}>
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
{disabled}
|
||||
variant="secondary"
|
||||
type="button"
|
||||
>
|
||||
<span class="sr-only">{fileUploadTooltipText}</span>
|
||||
|
||||
<Plus class="h-4 w-4" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>{fileUploadTooltipText}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content align="start" class="w-48">
|
||||
{#if hasVisionModality}
|
||||
<DropdownMenu.Item
|
||||
class="images-button flex cursor-pointer items-center gap-2"
|
||||
onclick={() => onFileUpload?.()}
|
||||
>
|
||||
<FILE_TYPE_ICONS.image class="h-4 w-4" />
|
||||
|
||||
<span>Images</span>
|
||||
</DropdownMenu.Item>
|
||||
{:else}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="images-button flex cursor-pointer items-center gap-2"
|
||||
disabled
|
||||
>
|
||||
<FILE_TYPE_ICONS.image class="h-4 w-4" />
|
||||
|
||||
<span>Images</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>Image processing requires a vision model</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
|
||||
{#if hasAudioModality}
|
||||
<DropdownMenu.Item
|
||||
class="audio-button flex cursor-pointer items-center gap-2"
|
||||
onclick={() => onFileUpload?.()}
|
||||
>
|
||||
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
|
||||
|
||||
<span>Audio Files</span>
|
||||
</DropdownMenu.Item>
|
||||
{:else}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item class="audio-button flex cursor-pointer items-center gap-2" disabled>
|
||||
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
|
||||
|
||||
<span>Audio Files</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>Audio files processing requires an audio model</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => onFileUpload?.()}
|
||||
>
|
||||
<FILE_TYPE_ICONS.text class="h-4 w-4" />
|
||||
|
||||
<span>Text Files</span>
|
||||
</DropdownMenu.Item>
|
||||
|
||||
{#if hasVisionModality}
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => onFileUpload?.()}
|
||||
>
|
||||
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
|
||||
|
||||
<span>PDF Files</span>
|
||||
</DropdownMenu.Item>
|
||||
{:else}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => onFileUpload?.()}
|
||||
>
|
||||
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
|
||||
|
||||
<span>PDF Files</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>PDFs will be converted to text. Image-based PDFs may not work properly.</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => onSystemPromptClick?.()}
|
||||
>
|
||||
<MessageSquare class="h-4 w-4" />
|
||||
|
||||
<span>System Message</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>{systemMessageTooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
|
||||
<DropdownMenu.Separator />
|
||||
|
||||
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
|
||||
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
|
||||
<McpLogo class="h-4 w-4" />
|
||||
|
||||
<span>MCP Servers</span>
|
||||
</DropdownMenu.SubTrigger>
|
||||
|
||||
<DropdownMenu.SubContent class="w-72 pt-0">
|
||||
<DropdownMenuSearchable
|
||||
placeholder="Search servers..."
|
||||
bind:searchValue={mcpSearchQuery}
|
||||
emptyMessage={hasMcpServers ? 'No servers found' : 'No MCP servers configured'}
|
||||
isEmpty={filteredMcpServers.length === 0}
|
||||
>
|
||||
<div class="max-h-64 overflow-y-auto">
|
||||
{#each filteredMcpServers as server (server.id)}
|
||||
{@const healthState = mcpStore.getHealthCheckState(server.id)}
|
||||
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
|
||||
{@const isEnabledForChat = isServerEnabledForChat(server.id)}
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
|
||||
onclick={() => !hasError && toggleServerForChat(server.id)}
|
||||
disabled={hasError}
|
||||
>
|
||||
<div class="flex min-w-0 flex-1 items-center gap-2">
|
||||
{#if mcpStore.getServerFavicon(server.id)}
|
||||
<img
|
||||
src={mcpStore.getServerFavicon(server.id)}
|
||||
alt=""
|
||||
class="h-4 w-4 shrink-0 rounded-sm"
|
||||
onerror={(e) => {
|
||||
(e.currentTarget as HTMLImageElement).style.display = 'none';
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<span class="truncate text-sm">{getServerLabel(server)}</span>
|
||||
|
||||
{#if hasError}
|
||||
<span
|
||||
class="shrink-0 rounded bg-destructive/15 px-1.5 py-0.5 text-xs text-destructive"
|
||||
>
|
||||
Error
|
||||
</span>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<Switch
|
||||
checked={isEnabledForChat}
|
||||
disabled={hasError}
|
||||
onclick={(e: MouseEvent) => e.stopPropagation()}
|
||||
onCheckedChange={() => toggleServerForChat(server.id)}
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
{#snippet footer()}
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={handleMcpSettingsClick}
|
||||
>
|
||||
<Settings class="h-4 w-4" />
|
||||
|
||||
<span>Manage MCP Servers</span>
|
||||
</DropdownMenu.Item>
|
||||
{/snippet}
|
||||
</DropdownMenuSearchable>
|
||||
</DropdownMenu.SubContent>
|
||||
</DropdownMenu.Sub>
|
||||
|
||||
{#if hasMcpPromptsSupport}
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={handleMcpPromptClick}
|
||||
>
|
||||
<Zap class="h-4 w-4" />
|
||||
|
||||
<span>MCP Prompt</span>
|
||||
</DropdownMenu.Item>
|
||||
{/if}
|
||||
|
||||
{#if hasMcpResourcesSupport}
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={handleMcpResourcesClick}
|
||||
>
|
||||
<FolderOpen class="h-4 w-4" />
|
||||
|
||||
<span>MCP Resources</span>
|
||||
</DropdownMenu.Item>
|
||||
{/if}
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
</div>
|
||||
-170
@@ -1,170 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { Plus, MessageSquare, Zap, FolderOpen } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Sheet from '$lib/components/ui/sheet';
|
||||
import { FILE_TYPE_ICONS } from '$lib/constants';
|
||||
import { McpLogo } from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
hasMcpPromptsSupport?: boolean;
|
||||
hasMcpResourcesSupport?: boolean;
|
||||
onFileUpload?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
onMcpPromptClick?: () => void;
|
||||
onMcpSettingsClick?: () => void;
|
||||
onMcpResourcesClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVisionModality = false,
|
||||
hasMcpPromptsSupport = false,
|
||||
hasMcpResourcesSupport = false,
|
||||
onFileUpload,
|
||||
onSystemPromptClick,
|
||||
onMcpPromptClick,
|
||||
onMcpSettingsClick,
|
||||
onMcpResourcesClick
|
||||
}: Props = $props();
|
||||
|
||||
let sheetOpen = $state(false);
|
||||
|
||||
function handleMcpPromptClick() {
|
||||
sheetOpen = false;
|
||||
onMcpPromptClick?.();
|
||||
}
|
||||
|
||||
function handleMcpSettingsClick() {
|
||||
onMcpSettingsClick?.();
|
||||
}
|
||||
|
||||
function handleMcpResourcesClick() {
|
||||
sheetOpen = false;
|
||||
onMcpResourcesClick?.();
|
||||
}
|
||||
|
||||
function handleSheetFileUpload() {
|
||||
sheetOpen = false;
|
||||
onFileUpload?.();
|
||||
}
|
||||
|
||||
function handleSheetSystemPromptClick() {
|
||||
sheetOpen = false;
|
||||
onSystemPromptClick?.();
|
||||
}
|
||||
|
||||
const fileUploadTooltipText = 'Add files, system prompt or MCP Servers';
|
||||
|
||||
const sheetItemClass =
|
||||
'flex w-full items-center gap-3 rounded-md px-3 py-2.5 text-left text-sm transition-colors hover:bg-accent active:bg-accent disabled:cursor-not-allowed disabled:opacity-50';
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<Sheet.Root bind:open={sheetOpen}>
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
{disabled}
|
||||
variant="secondary"
|
||||
type="button"
|
||||
onclick={() => (sheetOpen = true)}
|
||||
>
|
||||
<span class="sr-only">{fileUploadTooltipText}</span>
|
||||
|
||||
<Plus class="h-4 w-4" />
|
||||
</Button>
|
||||
|
||||
<Sheet.Content side="bottom" class="max-h-[85vh] gap-0">
|
||||
<Sheet.Header>
|
||||
<Sheet.Title>Add to chat</Sheet.Title>
|
||||
|
||||
<Sheet.Description class="sr-only">
|
||||
Add files, system prompt or configure MCP servers
|
||||
</Sheet.Description>
|
||||
</Sheet.Header>
|
||||
|
||||
<div class="flex flex-col gap-1 overflow-y-auto px-1.5 pb-2">
|
||||
<!-- Images -->
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
disabled={!hasVisionModality}
|
||||
onclick={handleSheetFileUpload}
|
||||
>
|
||||
<FILE_TYPE_ICONS.image class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>Images</span>
|
||||
|
||||
{#if !hasVisionModality}
|
||||
<span class="ml-auto text-xs text-muted-foreground">Requires vision model</span>
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
<!-- Audio -->
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
disabled={!hasAudioModality}
|
||||
onclick={handleSheetFileUpload}
|
||||
>
|
||||
<FILE_TYPE_ICONS.audio class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>Audio Files</span>
|
||||
|
||||
{#if !hasAudioModality}
|
||||
<span class="ml-auto text-xs text-muted-foreground">Requires audio model</span>
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
<button type="button" class={sheetItemClass} onclick={handleSheetFileUpload}>
|
||||
<FILE_TYPE_ICONS.text class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>Text Files</span>
|
||||
</button>
|
||||
|
||||
<button type="button" class={sheetItemClass} onclick={handleSheetFileUpload}>
|
||||
<FILE_TYPE_ICONS.pdf class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>PDF Files</span>
|
||||
|
||||
{#if !hasVisionModality}
|
||||
<span class="ml-auto text-xs text-muted-foreground">Text-only</span>
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
<button type="button" class={sheetItemClass} onclick={handleSheetSystemPromptClick}>
|
||||
<MessageSquare class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>System Message</span>
|
||||
</button>
|
||||
|
||||
<button type="button" class={sheetItemClass} onclick={handleMcpSettingsClick}>
|
||||
<McpLogo class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>MCP Servers</span>
|
||||
</button>
|
||||
|
||||
{#if hasMcpPromptsSupport}
|
||||
<button type="button" class={sheetItemClass} onclick={handleMcpPromptClick}>
|
||||
<Zap class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>MCP Prompt</span>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
{#if hasMcpResourcesSupport}
|
||||
<button type="button" class={sheetItemClass} onclick={handleMcpResourcesClick}>
|
||||
<FolderOpen class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>MCP Resources</span>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</Sheet.Content>
|
||||
</Sheet.Root>
|
||||
</div>
|
||||
+59
-72
@@ -1,62 +1,39 @@
|
||||
<script lang="ts">
|
||||
import { Settings } from '@lucide/svelte';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import { Settings, Plus } from '@lucide/svelte';
|
||||
import { Switch } from '$lib/components/ui/switch';
|
||||
import { DropdownMenuSearchable, McpActiveServersAvatars } from '$lib/components/app';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import type { MCPServerSettingsEntry } from '$lib/types';
|
||||
import { goto } from '$app/navigation';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
onSettingsClick?: () => void;
|
||||
onMcpSettingsClick?: () => void;
|
||||
}
|
||||
|
||||
let { class: className = '', disabled = false, onSettingsClick }: Props = $props();
|
||||
let { onMcpSettingsClick }: Props = $props();
|
||||
|
||||
let searchQuery = $state('');
|
||||
let mcpServers = $derived(mcpStore.getServersSorted().filter((s) => s.enabled));
|
||||
let mcpSearchQuery = $state('');
|
||||
let allMcpServers = $derived(mcpStore.getServersSorted());
|
||||
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
|
||||
let hasMcpServers = $derived(mcpServers.length > 0);
|
||||
let enabledMcpServersForChat = $derived(
|
||||
mcpServers.filter((s) => conversationsStore.isMcpServerEnabledForChat(s.id) && s.url.trim())
|
||||
);
|
||||
let healthyEnabledMcpServers = $derived(
|
||||
enabledMcpServersForChat.filter((s) => {
|
||||
const healthState = mcpStore.getHealthCheckState(s.id);
|
||||
return healthState.status !== HealthCheckStatus.ERROR;
|
||||
})
|
||||
);
|
||||
let hasEnabledMcpServers = $derived(enabledMcpServersForChat.length > 0);
|
||||
let mcpFavicons = $derived(
|
||||
healthyEnabledMcpServers
|
||||
.slice(0, 3)
|
||||
.map((s) => ({ id: s.id, url: mcpStore.getServerFavicon(s.id) }))
|
||||
.filter((f) => f.url !== null)
|
||||
);
|
||||
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
|
||||
let filteredMcpServers = $derived.by(() => {
|
||||
const query = searchQuery.toLowerCase().trim();
|
||||
if (query) {
|
||||
return mcpServers.filter((s) => {
|
||||
const name = getServerLabel(s).toLowerCase();
|
||||
const url = s.url.toLowerCase();
|
||||
return name.includes(query) || url.includes(query);
|
||||
});
|
||||
}
|
||||
return mcpServers;
|
||||
const query = mcpSearchQuery.toLowerCase().trim();
|
||||
if (!query) return mcpServers;
|
||||
return mcpServers.filter((s) => {
|
||||
const name = getServerLabel(s).toLowerCase();
|
||||
const url = s.url.toLowerCase();
|
||||
return name.includes(query) || url.includes(query);
|
||||
});
|
||||
});
|
||||
|
||||
function getServerLabel(server: MCPServerSettingsEntry): string {
|
||||
return mcpStore.getServerLabel(server);
|
||||
}
|
||||
|
||||
function handleDropdownOpen(open: boolean) {
|
||||
if (open) {
|
||||
mcpStore.runHealthChecksForServers(mcpServers);
|
||||
}
|
||||
}
|
||||
|
||||
function isServerEnabledForChat(serverId: string): boolean {
|
||||
return conversationsStore.isMcpServerEnabledForChat(serverId);
|
||||
}
|
||||
@@ -64,38 +41,33 @@
|
||||
async function toggleServerForChat(serverId: string) {
|
||||
await conversationsStore.toggleMcpServerForChat(serverId);
|
||||
}
|
||||
|
||||
function handleMcpSubMenuOpen(open: boolean) {
|
||||
if (open) {
|
||||
mcpSearchQuery = '';
|
||||
mcpStore.runHealthChecksForServers(allMcpServers);
|
||||
}
|
||||
}
|
||||
|
||||
function handleMcpSettingsClick() {
|
||||
onMcpSettingsClick?.();
|
||||
|
||||
goto(`${hasMcpServers ? '' : '?add'}#/settings/mcp`);
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if hasMcpServers && hasEnabledMcpServers && mcpFavicons.length > 0}
|
||||
<DropdownMenu.Root
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
searchQuery = '';
|
||||
}
|
||||
handleDropdownOpen(open);
|
||||
}}
|
||||
>
|
||||
<DropdownMenu.Trigger
|
||||
{disabled}
|
||||
onclick={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
}}
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
class="inline-flex cursor-pointer items-center rounded-sm py-1 disabled:cursor-not-allowed disabled:opacity-60"
|
||||
{disabled}
|
||||
aria-label="MCP Servers"
|
||||
>
|
||||
<McpActiveServersAvatars class={className} />
|
||||
</button>
|
||||
</DropdownMenu.Trigger>
|
||||
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
|
||||
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
|
||||
<McpLogo class="h-4 w-4" />
|
||||
|
||||
<DropdownMenu.Content align="start" class="w-72 pt-0">
|
||||
<span>MCP Servers</span>
|
||||
</DropdownMenu.SubTrigger>
|
||||
|
||||
<DropdownMenu.SubContent class="w-72 pt-0">
|
||||
{#if hasMcpServers}
|
||||
<DropdownMenuSearchable
|
||||
bind:searchValue={searchQuery}
|
||||
placeholder="Search servers..."
|
||||
bind:searchValue={mcpSearchQuery}
|
||||
emptyMessage="No servers found"
|
||||
isEmpty={filteredMcpServers.length === 0}
|
||||
>
|
||||
@@ -107,7 +79,7 @@
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center justify-between gap-2 px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
|
||||
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
|
||||
onclick={() => !hasError && toggleServerForChat(server.id)}
|
||||
disabled={hasError}
|
||||
>
|
||||
@@ -147,7 +119,7 @@
|
||||
{#snippet footer()}
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={onSettingsClick}
|
||||
onclick={handleMcpSettingsClick}
|
||||
>
|
||||
<Settings class="h-4 w-4" />
|
||||
|
||||
@@ -155,6 +127,21 @@
|
||||
</DropdownMenu.Item>
|
||||
{/snippet}
|
||||
</DropdownMenuSearchable>
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
{/if}
|
||||
{:else}
|
||||
<div class="px-2 py-3 text-center text-sm text-muted-foreground">
|
||||
No MCP servers configured
|
||||
</div>
|
||||
|
||||
<DropdownMenu.Separator />
|
||||
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={handleMcpSettingsClick}
|
||||
>
|
||||
<Plus class="h-4 w-4" />
|
||||
|
||||
<span>Add MCP Servers</span>
|
||||
</DropdownMenu.Item>
|
||||
{/if}
|
||||
</DropdownMenu.SubContent>
|
||||
</DropdownMenu.Sub>
|
||||
+2
-9
@@ -7,20 +7,13 @@
|
||||
interface Props {
|
||||
canSend?: boolean;
|
||||
disabled?: boolean;
|
||||
isLoading?: boolean;
|
||||
showErrorState?: boolean;
|
||||
tooltipLabel?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
canSend = false,
|
||||
disabled = false,
|
||||
isLoading = false,
|
||||
showErrorState = false,
|
||||
tooltipLabel
|
||||
}: Props = $props();
|
||||
let { canSend = false, disabled = false, showErrorState = false, tooltipLabel }: Props = $props();
|
||||
|
||||
let isDisabled = $derived(!canSend || disabled || isLoading);
|
||||
let isDisabled = $derived(!canSend || disabled);
|
||||
</script>
|
||||
|
||||
{#snippet submitButton(props = {})}
|
||||
|
||||
+146
@@ -0,0 +1,146 @@
|
||||
<script lang="ts">
|
||||
import { PencilRuler, ChevronDown, ChevronRight, Loader2, Info } from '@lucide/svelte';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import * as Collapsible from '$lib/components/ui/collapsible';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { toolsStore } from '$lib/stores/tools.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { useToolsPanel } from '$lib/hooks/use-tools-panel.svelte';
|
||||
|
||||
const toolsPanel = useToolsPanel();
|
||||
const hasMcpServersAvailable = $derived(mcpStore.getServersSorted().length > 0);
|
||||
</script>
|
||||
|
||||
<DropdownMenu.Sub onOpenChange={(open) => open && toolsPanel.handleOpen()}>
|
||||
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
|
||||
<PencilRuler class="h-4 w-4" />
|
||||
|
||||
<span>Tools</span>
|
||||
</DropdownMenu.SubTrigger>
|
||||
|
||||
<DropdownMenu.SubContent class="w-72 p-0">
|
||||
{#if toolsPanel.totalToolCount === 0}
|
||||
{#if toolsStore.loading}
|
||||
<div class="px-3 py-4 text-center text-sm text-muted-foreground">
|
||||
<Loader2 class="mx-auto mb-1 h-4 w-4 animate-spin" />
|
||||
Loading tools...
|
||||
</div>
|
||||
{:else if toolsStore.isToolsEndpointUnreachable}
|
||||
<div class="grid gap-2.5 px-3 py-4 text-sm text-muted-foreground">
|
||||
<span class="flex gap-2">
|
||||
<Info class="mt-0.5 h-4 w-4 shrink-0" />
|
||||
|
||||
<span
|
||||
>Run llama-server with <code>--tools</code> flag to enable
|
||||
<strong>Built-in Tools</strong>.</span
|
||||
>
|
||||
</span>
|
||||
|
||||
<span class="flex gap-2">
|
||||
<Info class="mt-0.5 h-4 w-4 shrink-0" />
|
||||
|
||||
<span
|
||||
>{hasMcpServersAvailable ? 'Enable' : 'Add'} MCP Server(s) to access
|
||||
<strong>MCP Tools</strong>.</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
{:else if toolsStore.error}
|
||||
<div class="px-3 py-4 text-center text-sm text-muted-foreground">Failed to load tools</div>
|
||||
{:else if toolsPanel.noToolsInfoMessage}
|
||||
<div class="flex gap-2 px-3 py-4 text-sm text-muted-foreground">
|
||||
<Info class="mt-0.5 h-4 w-4 shrink-0" />
|
||||
|
||||
<span>{toolsPanel.noToolsInfoMessage}</span>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="px-3 py-4 text-center text-sm text-muted-foreground">No tools available</div>
|
||||
{/if}
|
||||
{:else}
|
||||
<div class="max-h-80 overflow-y-auto p-2 pr-1">
|
||||
{#each toolsPanel.activeGroups as group (group.label)}
|
||||
{@const isExpanded = toolsPanel.expandedGroups.has(group.label)}
|
||||
{@const { checked, indeterminate } = toolsPanel.getGroupCheckedState(group)}
|
||||
{@const favicon = toolsPanel.getFavicon(group)}
|
||||
|
||||
<Collapsible.Root
|
||||
open={isExpanded}
|
||||
onOpenChange={() => toolsPanel.toggleGroupExpanded(group.label)}
|
||||
>
|
||||
<div class="flex items-center gap-1">
|
||||
<Collapsible.Trigger
|
||||
class="flex min-w-0 flex-1 items-center gap-2 rounded px-2 py-1.5 text-sm hover:bg-muted/50"
|
||||
>
|
||||
{#if isExpanded}
|
||||
<ChevronDown class="h-3.5 w-3.5 shrink-0" />
|
||||
{:else}
|
||||
<ChevronRight class="h-3.5 w-3.5 shrink-0" />
|
||||
{/if}
|
||||
|
||||
<span class="inline-flex min-w-0 items-center gap-1.5 font-medium">
|
||||
{#if favicon}
|
||||
<img
|
||||
src={favicon}
|
||||
alt=""
|
||||
class="h-4 w-4 shrink-0 rounded-sm"
|
||||
onerror={(e) => {
|
||||
(e.currentTarget as HTMLImageElement).style.display = 'none';
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<span class="truncate">{group.label}</span>
|
||||
</span>
|
||||
|
||||
<span class="ml-auto shrink-0 text-xs text-muted-foreground">
|
||||
{toolsPanel.getEnabledToolCount(group)}/{group.tools.length}
|
||||
</span>
|
||||
</Collapsible.Trigger>
|
||||
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<Checkbox
|
||||
{checked}
|
||||
{indeterminate}
|
||||
onCheckedChange={() => toolsStore.toggleGroup(group)}
|
||||
class="mr-2 h-4 w-4 shrink-0"
|
||||
/>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>
|
||||
{checked ? 'Disable' : 'Enable'}
|
||||
{group.tools.length} tool{group.tools.length !== 1 ? 's' : ''}
|
||||
</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
</div>
|
||||
|
||||
<Collapsible.Content>
|
||||
<div class="ml-4 flex flex-col gap-0.5 border-l border-border/50 pl-2">
|
||||
{#each group.tools as tool (tool.function.name)}
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center gap-2 rounded px-2 py-1.5 text-left text-sm transition-colors hover:bg-muted/50"
|
||||
onclick={() => toolsStore.toggleTool(tool.function.name)}
|
||||
>
|
||||
<Checkbox
|
||||
checked={toolsStore.isToolEnabled(tool.function.name)}
|
||||
onCheckedChange={() => toolsStore.toggleTool(tool.function.name)}
|
||||
class="h-4 w-4 shrink-0"
|
||||
/>
|
||||
|
||||
<span class="min-w-0 flex-1 truncate font-mono text-[12px]">
|
||||
{tool.function.name}
|
||||
</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</Collapsible.Content>
|
||||
</Collapsible.Root>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</DropdownMenu.SubContent>
|
||||
</DropdownMenu.Sub>
|
||||
+14
-23
@@ -6,21 +6,19 @@
|
||||
ChatFormActionAttachmentsSheet,
|
||||
ChatFormActionRecord,
|
||||
ChatFormActionSubmit,
|
||||
McpServersSelector,
|
||||
ModelsSelector,
|
||||
ModelsSelectorDropdown,
|
||||
ModelsSelectorSheet
|
||||
} from '$lib/components/app';
|
||||
import { SETTINGS_SECTION_TITLES } from '$lib/constants';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { getChatSettingsDialogContext } from '$lib/contexts';
|
||||
import { FileTypeCategory } from '$lib/enums';
|
||||
import { getFileTypeCategory } from '$lib/utils';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
|
||||
import { isRouterMode, serverError } from '$lib/stores/server.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { activeMessages, conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
|
||||
import { getFileTypeCategory } from '$lib/utils';
|
||||
import { goto } from '$app/navigation';
|
||||
|
||||
interface Props {
|
||||
canSend?: boolean;
|
||||
@@ -165,7 +163,8 @@
|
||||
return '';
|
||||
});
|
||||
|
||||
let selectorModelRef: ModelsSelector | ModelsSelectorSheet | undefined = $state(undefined);
|
||||
let selectorModelRef: ModelsSelectorDropdown | ModelsSelectorSheet | undefined =
|
||||
$state(undefined);
|
||||
|
||||
let isMobile = new IsMobile();
|
||||
|
||||
@@ -173,8 +172,6 @@
|
||||
selectorModelRef?.open();
|
||||
}
|
||||
|
||||
const chatSettingsDialog = getChatSettingsDialogContext();
|
||||
|
||||
let hasMcpPromptsSupport = $derived.by(() => {
|
||||
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
|
||||
|
||||
@@ -200,8 +197,8 @@
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
{onMcpPromptClick}
|
||||
onMcpSettingsClick={() => goto('#/settings/mcp')}
|
||||
{onMcpResourcesClick}
|
||||
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatFormActionAttachmentsDropdown
|
||||
@@ -214,17 +211,12 @@
|
||||
{onSystemPromptClick}
|
||||
{onMcpPromptClick}
|
||||
{onMcpResourcesClick}
|
||||
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
|
||||
onMcpSettingsClick={() => goto('#/settings/mcp')}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<McpServersSelector
|
||||
{disabled}
|
||||
onSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="ml-auto flex items-center gap-1.5">
|
||||
<div class="ml-auto flex items-center gap-2">
|
||||
{#if isMobile.current}
|
||||
<ModelsSelectorSheet
|
||||
disabled={disabled || isOffline}
|
||||
@@ -234,7 +226,7 @@
|
||||
useGlobalSelection
|
||||
/>
|
||||
{:else}
|
||||
<ModelsSelector
|
||||
<ModelsSelectorDropdown
|
||||
disabled={disabled || isOffline}
|
||||
bind:this={selectorModelRef}
|
||||
currentModel={conversationModel}
|
||||
@@ -244,7 +236,7 @@
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
{#if isLoading && !hasText}
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
@@ -263,7 +255,6 @@
|
||||
<ChatFormActionSubmit
|
||||
canSend={canSend && hasModelSelected && isSelectedModelInCache}
|
||||
{disabled}
|
||||
{isLoading}
|
||||
tooltipLabel={submitTooltip}
|
||||
showErrorState={hasModelSelected && !isSelectedModelInCache}
|
||||
/>
|
||||
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
<script lang="ts">
|
||||
import { Plus } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import {
|
||||
ATTACHMENT_FILE_ITEMS,
|
||||
ATTACHMENT_EXTRA_ITEMS,
|
||||
ATTACHMENT_MCP_ITEMS,
|
||||
ATTACHMENT_TOOLTIP_TEXT,
|
||||
TOOLTIP_DELAY_DURATION
|
||||
} from '$lib/constants';
|
||||
import { AttachmentMenuItemId } from '$lib/enums';
|
||||
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
|
||||
|
||||
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
hasMcpPromptsSupport?: boolean;
|
||||
hasMcpResourcesSupport?: boolean;
|
||||
onFileUpload?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
onMcpPromptClick?: () => void;
|
||||
onMcpSettingsClick?: () => void;
|
||||
onMcpResourcesClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVisionModality = false,
|
||||
hasMcpPromptsSupport = false,
|
||||
hasMcpResourcesSupport = false,
|
||||
onFileUpload,
|
||||
onSystemPromptClick,
|
||||
onMcpPromptClick,
|
||||
onMcpSettingsClick,
|
||||
onMcpResourcesClick
|
||||
}: Props = $props();
|
||||
|
||||
let dropdownOpen = $state(false);
|
||||
|
||||
function handleMcpSettingsClick() {
|
||||
dropdownOpen = false;
|
||||
onMcpSettingsClick?.();
|
||||
}
|
||||
|
||||
const attachmentMenu = useAttachmentMenu(
|
||||
() => ({ hasVisionModality, hasAudioModality, hasMcpPromptsSupport, hasMcpResourcesSupport }),
|
||||
() => ({ onFileUpload, onSystemPromptClick, onMcpPromptClick, onMcpResourcesClick }),
|
||||
() => {
|
||||
dropdownOpen = false;
|
||||
}
|
||||
);
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<DropdownMenu.Root bind:open={dropdownOpen}>
|
||||
<DropdownMenu.Trigger name="Attach files" {disabled}>
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
{disabled}
|
||||
variant="secondary"
|
||||
type="button"
|
||||
>
|
||||
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
|
||||
|
||||
<Plus class="h-4 w-4" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>{ATTACHMENT_TOOLTIP_TEXT}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content align="start" class="w-48">
|
||||
{#each ATTACHMENT_FILE_ITEMS as item (item.id)}
|
||||
{@const enabled = attachmentMenu.isItemEnabled(item.enabledWhen)}
|
||||
{#if enabled}
|
||||
<DropdownMenu.Item
|
||||
class="{item.class ?? ''} flex cursor-pointer items-center gap-2"
|
||||
onclick={() => attachmentMenu.callbacks[item.action]()}
|
||||
>
|
||||
<item.icon class="h-4 w-4" />
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
{:else if item.disabledTooltip}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="{item.class ?? ''} flex cursor-pointer items-center gap-2"
|
||||
disabled
|
||||
>
|
||||
<item.icon class="h-4 w-4" />
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>{item.disabledTooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
{/each}
|
||||
|
||||
{#if !attachmentMenu.isItemEnabled('hasVisionModality')}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={attachmentMenu.callbacks.onFileUpload}
|
||||
>
|
||||
{@const pdfItem = ATTACHMENT_FILE_ITEMS.find(
|
||||
(i) => i.id === AttachmentMenuItemId.PDF
|
||||
)}
|
||||
{#if pdfItem}
|
||||
<pdfItem.icon class="h-4 w-4" />
|
||||
|
||||
<span>{pdfItem.label}</span>
|
||||
{/if}
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>PDFs will be converted to text. Image-based PDFs may not work properly.</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
|
||||
<DropdownMenu.Separator />
|
||||
|
||||
{#each ATTACHMENT_EXTRA_ITEMS as item (item.id)}
|
||||
{#if item.id === AttachmentMenuItemId.SYSTEM_MESSAGE}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => attachmentMenu.callbacks[item.action]()}
|
||||
>
|
||||
<item.icon class="h-4 w-4" />
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>{attachmentMenu.getSystemMessageTooltip()}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
{/each}
|
||||
|
||||
<ChatFormActionToolsSubmenu />
|
||||
|
||||
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
|
||||
|
||||
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
|
||||
{#if attachmentMenu.isItemVisible(item.visibleWhen)}
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => attachmentMenu.callbacks[item.action]()}
|
||||
>
|
||||
<item.icon class="h-4 w-4" />
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
{/if}
|
||||
{/each}
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
</div>
|
||||
+184
@@ -0,0 +1,184 @@
|
||||
<script lang="ts">
|
||||
import { Plus } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import * as Sheet from '$lib/components/ui/sheet';
|
||||
import { TOOLTIP_DELAY_DURATION } from '$lib/constants';
|
||||
import {
|
||||
ATTACHMENT_FILE_ITEMS,
|
||||
ATTACHMENT_EXTRA_ITEMS,
|
||||
ATTACHMENT_MCP_ITEMS,
|
||||
ATTACHMENT_TOOLTIP_TEXT
|
||||
} from '$lib/constants/attachment-menu';
|
||||
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
|
||||
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
|
||||
import { AttachmentMenuItemId } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
hasMcpPromptsSupport?: boolean;
|
||||
hasMcpResourcesSupport?: boolean;
|
||||
onFileUpload?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
onMcpPromptClick?: () => void;
|
||||
onMcpSettingsClick?: () => void;
|
||||
onMcpResourcesClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVisionModality = false,
|
||||
hasMcpPromptsSupport = false,
|
||||
hasMcpResourcesSupport = false,
|
||||
onFileUpload,
|
||||
onSystemPromptClick,
|
||||
onMcpPromptClick,
|
||||
onMcpSettingsClick,
|
||||
onMcpResourcesClick
|
||||
}: Props = $props();
|
||||
|
||||
let sheetOpen = $state(false);
|
||||
|
||||
const attachmentMenu = useAttachmentMenu(
|
||||
() => ({ hasVisionModality, hasAudioModality, hasMcpPromptsSupport, hasMcpResourcesSupport }),
|
||||
() => ({ onFileUpload, onSystemPromptClick, onMcpPromptClick, onMcpResourcesClick }),
|
||||
() => {
|
||||
sheetOpen = false;
|
||||
}
|
||||
);
|
||||
|
||||
function handleMcpSettingsClick() {
|
||||
sheetOpen = false;
|
||||
onMcpSettingsClick?.();
|
||||
}
|
||||
|
||||
const sheetItemClass =
|
||||
'flex w-full items-center gap-3 rounded-md px-3 py-2.5 text-left text-sm transition-colors hover:bg-accent active:bg-accent disabled:cursor-not-allowed disabled:opacity-50';
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<Sheet.Root bind:open={sheetOpen}>
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
{disabled}
|
||||
variant="secondary"
|
||||
type="button"
|
||||
onclick={() => (sheetOpen = true)}
|
||||
>
|
||||
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
|
||||
|
||||
<Plus class="h-4 w-4" />
|
||||
</Button>
|
||||
|
||||
<Sheet.Content side="bottom" class="max-h-[85vh] gap-0 overflow-y-auto">
|
||||
<Sheet.Header>
|
||||
<Sheet.Title>Add to chat</Sheet.Title>
|
||||
|
||||
<Sheet.Description class="sr-only">
|
||||
Add files, system prompt or configure MCP servers
|
||||
</Sheet.Description>
|
||||
</Sheet.Header>
|
||||
|
||||
<div class="flex flex-col gap-1 px-1.5 pb-2">
|
||||
{#each ATTACHMENT_FILE_ITEMS as item (item.id)}
|
||||
{@const enabled = attachmentMenu.isItemEnabled(item.enabledWhen)}
|
||||
{#if enabled}
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[item.action]()}
|
||||
>
|
||||
<item.icon class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>{item.label}</span>
|
||||
</button>
|
||||
{:else if item.disabledTooltip}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger>
|
||||
<button type="button" class={sheetItemClass} disabled>
|
||||
<item.icon class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>{item.label}</span>
|
||||
</button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>{item.disabledTooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
{/each}
|
||||
|
||||
{#if !attachmentMenu.isItemEnabled('hasVisionModality')}
|
||||
{@const pdfItem = ATTACHMENT_FILE_ITEMS.find((i) => i.id === AttachmentMenuItemId.PDF)}
|
||||
{#if pdfItem}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[pdfItem.action]()}
|
||||
>
|
||||
<pdfItem.icon class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>{pdfItem.label}</span>
|
||||
</button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>PDFs will be converted to text. Image-based PDFs may not work properly.</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
{#each ATTACHMENT_EXTRA_ITEMS as item (item.id)}
|
||||
{#if item.id === AttachmentMenuItemId.SYSTEM_MESSAGE}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[item.action]()}
|
||||
>
|
||||
<item.icon class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>{item.label}</span>
|
||||
</button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>{attachmentMenu.getSystemMessageTooltip()}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
{/each}
|
||||
|
||||
<div class="my-2 border-t"></div>
|
||||
|
||||
<ChatFormActionToolsSubmenu />
|
||||
|
||||
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
|
||||
|
||||
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
|
||||
{#if attachmentMenu.isItemVisible(item.visibleWhen)}
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[item.action]()}
|
||||
>
|
||||
<item.icon class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>{item.label}</span>
|
||||
</button>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</Sheet.Content>
|
||||
</Sheet.Root>
|
||||
</div>
|
||||
@@ -1,12 +1,12 @@
|
||||
<script lang="ts">
|
||||
import { goto } from '$app/navigation';
|
||||
import { base } from '$app/paths';
|
||||
import { getChatActionsContext, setMessageEditContext } from '$lib/contexts';
|
||||
import { chatStore, pendingEditMessageId } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants';
|
||||
import { MessageRole, AttachmentType } from '$lib/enums';
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import {
|
||||
ChatMessageAssistant,
|
||||
ChatMessageUser,
|
||||
@@ -118,7 +118,7 @@
|
||||
const conversationDeleted = await chatStore.removeSystemPromptPlaceholder(message.id);
|
||||
|
||||
if (conversationDeleted) {
|
||||
goto(`${base}/`);
|
||||
goto(`#/`);
|
||||
}
|
||||
|
||||
return;
|
||||
@@ -138,7 +138,7 @@
|
||||
const conversationDeleted = await chatStore.removeSystemPromptPlaceholder(message.id);
|
||||
|
||||
if (conversationDeleted) {
|
||||
goto(`${base}/`);
|
||||
goto(`#/`);
|
||||
}
|
||||
} else {
|
||||
chatActions.delete(message);
|
||||
@@ -200,7 +200,7 @@
|
||||
const conversationDeleted = await chatStore.removeSystemPromptPlaceholder(message.id);
|
||||
isEditing = false;
|
||||
if (conversationDeleted) {
|
||||
goto(`${base}/`);
|
||||
goto(`#/`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -252,70 +252,72 @@
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if message.role === MessageRole.SYSTEM}
|
||||
<ChatMessageSystem
|
||||
bind:textareaElement
|
||||
class={className}
|
||||
{deletionInfo}
|
||||
{message}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{:else if mcpPromptExtra}
|
||||
<ChatMessageMcpPrompt
|
||||
class={className}
|
||||
{deletionInfo}
|
||||
{message}
|
||||
mcpPrompt={mcpPromptExtra}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{:else if message.role === MessageRole.USER}
|
||||
<ChatMessageUser
|
||||
class={className}
|
||||
{deletionInfo}
|
||||
{message}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onForkConversation={handleForkConversation}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{:else}
|
||||
<ChatMessageAssistant
|
||||
bind:textareaElement
|
||||
class={className}
|
||||
{deletionInfo}
|
||||
{isLastAssistantMessage}
|
||||
{message}
|
||||
{toolMessages}
|
||||
messageContent={message.content}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onContinue={handleContinue}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onForkConversation={handleForkConversation}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onRegenerate={handleRegenerate}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{/if}
|
||||
<div use:fadeInView>
|
||||
{#if message.role === MessageRole.SYSTEM}
|
||||
<ChatMessageSystem
|
||||
bind:textareaElement
|
||||
class={className}
|
||||
{deletionInfo}
|
||||
{message}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{:else if mcpPromptExtra}
|
||||
<ChatMessageMcpPrompt
|
||||
class={className}
|
||||
{deletionInfo}
|
||||
{message}
|
||||
mcpPrompt={mcpPromptExtra}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{:else if message.role === MessageRole.USER}
|
||||
<ChatMessageUser
|
||||
class={className}
|
||||
{deletionInfo}
|
||||
{message}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onForkConversation={handleForkConversation}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{:else}
|
||||
<ChatMessageAssistant
|
||||
bind:textareaElement
|
||||
class={className}
|
||||
{deletionInfo}
|
||||
{isLastAssistantMessage}
|
||||
{message}
|
||||
{toolMessages}
|
||||
messageContent={message.content}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onContinue={handleContinue}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onForkConversation={handleForkConversation}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onRegenerate={handleRegenerate}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
<script lang="ts">
|
||||
import type { Snippet, Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
icon: Component<{ class?: string }>;
|
||||
message: Snippet;
|
||||
actions: Snippet;
|
||||
}
|
||||
|
||||
let { icon: Icon, message, actions }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="my-2 rounded-lg border border-border bg-card p-3">
|
||||
<div class="mb-3 flex items-center gap-2 text-sm">
|
||||
<Icon class="h-4 w-4 shrink-0 text-muted-foreground" />
|
||||
<span>
|
||||
{@render message()}
|
||||
</span>
|
||||
</div>
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
{@render actions()}
|
||||
</div>
|
||||
</div>
|
||||
+94
-14
@@ -1,38 +1,104 @@
|
||||
<script lang="ts">
|
||||
import { Wrench, Loader2, Brain } from '@lucide/svelte';
|
||||
import {
|
||||
ChatMessageStatistics,
|
||||
CollapsibleContentBlock,
|
||||
MarkdownContent,
|
||||
SyntaxHighlightedCode
|
||||
SyntaxHighlightedCode,
|
||||
ChatMessagePermissionRequest,
|
||||
ChatMessageContinueRequest
|
||||
} from '$lib/components/app';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { Wrench, Loader2, Brain } from '@lucide/svelte';
|
||||
import { AgenticSectionType, FileTypeText } from '$lib/enums';
|
||||
import { formatJsonPretty } from '$lib/utils';
|
||||
|
||||
import {
|
||||
AgenticSectionType,
|
||||
ChatMessageStatsView,
|
||||
FileTypeText,
|
||||
ToolPermissionDecision
|
||||
} from '$lib/enums';
|
||||
import type {
|
||||
ChatMessageAgenticTimings,
|
||||
ChatMessageAgenticTurnStats,
|
||||
DatabaseMessage
|
||||
} from '$lib/types';
|
||||
import {
|
||||
deriveAgenticSections,
|
||||
formatJsonPretty,
|
||||
parseToolResultWithImages,
|
||||
type AgenticSection,
|
||||
type ToolResultLine
|
||||
} from '$lib/utils';
|
||||
import type { DatabaseMessage } from '$lib/types/database';
|
||||
import type { ChatMessageAgenticTimings, ChatMessageAgenticTurnStats } from '$lib/types/chat';
|
||||
import { ChatMessageStatsView } from '$lib/enums';
|
||||
import {
|
||||
agenticPendingPermissionRequest,
|
||||
agenticResolvePermission,
|
||||
agenticPendingContinueRequest,
|
||||
agenticResolveContinue
|
||||
} from '$lib/stores/agentic.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
|
||||
interface Props {
|
||||
message: DatabaseMessage;
|
||||
toolMessages?: DatabaseMessage[];
|
||||
isStreaming?: boolean;
|
||||
isLastAssistantMessage?: boolean;
|
||||
highlightTurns?: boolean;
|
||||
}
|
||||
|
||||
let { message, toolMessages = [], isStreaming = false, highlightTurns = false }: Props = $props();
|
||||
let {
|
||||
message,
|
||||
toolMessages = [],
|
||||
isStreaming = false,
|
||||
isLastAssistantMessage = false,
|
||||
highlightTurns = false
|
||||
}: Props = $props();
|
||||
|
||||
let expandedStates: Record<number, boolean> = $state({});
|
||||
|
||||
const showToolCallInProgress = $derived(config().showToolCallInProgress as boolean);
|
||||
const showThoughtInProgress = $derived(config().showThoughtInProgress as boolean);
|
||||
|
||||
let permissionDismissed = $state(false);
|
||||
|
||||
const pendingPermission = $derived(
|
||||
isStreaming && isLastAssistantMessage ? agenticPendingPermissionRequest(message.convId) : null
|
||||
);
|
||||
|
||||
// Reset dismissed when pendingPermission changes (new request or cleared)
|
||||
let prevPendingRef: typeof pendingPermission = null;
|
||||
$effect(() => {
|
||||
if (pendingPermission !== prevPendingRef) {
|
||||
prevPendingRef = pendingPermission;
|
||||
if (pendingPermission) {
|
||||
permissionDismissed = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
function handlePermission(decision: ToolPermissionDecision) {
|
||||
permissionDismissed = true;
|
||||
agenticResolvePermission(message.convId, decision);
|
||||
}
|
||||
|
||||
let continueDismissed = $state(false);
|
||||
|
||||
const pendingContinue = $derived(
|
||||
isStreaming && isLastAssistantMessage ? agenticPendingContinueRequest(message.convId) : false
|
||||
);
|
||||
|
||||
let prevContinueRef = false;
|
||||
$effect(() => {
|
||||
if (pendingContinue !== prevContinueRef) {
|
||||
prevContinueRef = pendingContinue;
|
||||
if (pendingContinue) {
|
||||
continueDismissed = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
function handleContinue(shouldContinue: boolean) {
|
||||
continueDismissed = true;
|
||||
agenticResolveContinue(message.convId, shouldContinue);
|
||||
}
|
||||
|
||||
const sections = $derived(deriveAgenticSections(message, toolMessages, [], isStreaming));
|
||||
|
||||
// Parse tool results with images
|
||||
@@ -201,7 +267,11 @@
|
||||
<Loader2 class="h-3 w-3 animate-spin" />
|
||||
{/if}
|
||||
</div>
|
||||
{#if section.toolResult}
|
||||
{#if isPending}
|
||||
<div class="rounded bg-muted/30 p-2 text-xs text-muted-foreground italic">
|
||||
Waiting for result...
|
||||
</div>
|
||||
{:else if section.toolResult}
|
||||
<div class="overflow-auto rounded-lg border border-border bg-muted p-4">
|
||||
{#each section.parsedLines as line, i (i)}
|
||||
<div class="font-mono text-xs leading-relaxed whitespace-pre-wrap">{line.text}</div>
|
||||
@@ -215,10 +285,8 @@
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
{:else if isPending}
|
||||
<div class="rounded bg-muted/30 p-2 text-xs text-muted-foreground italic">
|
||||
Waiting for result...
|
||||
</div>
|
||||
{:else}
|
||||
<div class="rounded bg-muted/30 p-2 text-xs text-muted-foreground italic">No output</div>
|
||||
{/if}
|
||||
</div>
|
||||
</CollapsibleContentBlock>
|
||||
@@ -289,6 +357,18 @@
|
||||
{@render renderSection(section, index)}
|
||||
{/each}
|
||||
{/if}
|
||||
|
||||
{#if pendingPermission && !permissionDismissed}
|
||||
<ChatMessagePermissionRequest
|
||||
toolName={pendingPermission.toolName}
|
||||
serverLabel={pendingPermission.serverLabel}
|
||||
onDecision={handlePermission}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
{#if pendingContinue && !continueDismissed}
|
||||
<ChatMessageContinueRequest onDecision={handleContinue} />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
|
||||
+4
-3
@@ -4,7 +4,7 @@
|
||||
ChatMessageActions,
|
||||
ChatMessageStatistics,
|
||||
ModelBadge,
|
||||
ModelsSelector
|
||||
ModelsSelectorDropdown
|
||||
} from '$lib/components/app';
|
||||
import { getMessageEditContext } from '$lib/contexts';
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
@@ -308,6 +308,7 @@
|
||||
{message}
|
||||
{toolMessages}
|
||||
isStreaming={isChatStreaming()}
|
||||
{isLastAssistantMessage}
|
||||
highlightTurns={highlightAgenticTurns}
|
||||
/>
|
||||
{/if}
|
||||
@@ -336,10 +337,10 @@
|
||||
class="inline-flex flex-wrap items-start gap-2 text-xs text-muted-foreground"
|
||||
>
|
||||
{#if isRouter}
|
||||
<ModelsSelector
|
||||
<ModelsSelectorDropdown
|
||||
currentModel={displayedModel}
|
||||
disabled={isLoading()}
|
||||
onModelChange={async (modelId, modelName) => {
|
||||
onModelChange={async (modelId: string, modelName: string) => {
|
||||
const status = modelsStore.getModelStatus(modelId);
|
||||
|
||||
if (status !== ServerModelStatus.LOADED) {
|
||||
|
||||
+30
@@ -0,0 +1,30 @@
|
||||
<script lang="ts">
|
||||
import { RotateCw } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import ChatMessageActionCard from './ChatMessageActionCard.svelte';
|
||||
|
||||
interface Props {
|
||||
onDecision: (shouldContinue: boolean) => void;
|
||||
}
|
||||
|
||||
let { onDecision }: Props = $props();
|
||||
</script>
|
||||
|
||||
<ChatMessageActionCard icon={RotateCw}>
|
||||
{#snippet message()}
|
||||
Agentic turn limit reached. Continue?
|
||||
{/snippet}
|
||||
|
||||
{#snippet actions()}
|
||||
<Button size="sm" onclick={() => onDecision(true)}>Continue</Button>
|
||||
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
class="text-destructive hover:text-destructive"
|
||||
onclick={() => onDecision(false)}
|
||||
>
|
||||
Stop
|
||||
</Button>
|
||||
{/snippet}
|
||||
</ChatMessageActionCard>
|
||||
+88
@@ -0,0 +1,88 @@
|
||||
<script lang="ts">
|
||||
import { ChevronDown, ShieldQuestion } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as ButtonGroup from '$lib/components/ui/button-group';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import { ToolSource, ToolPermissionDecision } from '$lib/enums';
|
||||
import { TOOL_SERVER_LABELS } from '$lib/constants';
|
||||
import { toolsStore } from '$lib/stores/tools.svelte';
|
||||
import ChatMessageActionCard from './ChatMessageActionCard.svelte';
|
||||
|
||||
interface Props {
|
||||
toolName: string;
|
||||
serverLabel: string;
|
||||
onDecision: (decision: ToolPermissionDecision) => void;
|
||||
}
|
||||
|
||||
let { toolName, serverLabel, onDecision }: Props = $props();
|
||||
</script>
|
||||
|
||||
<ChatMessageActionCard icon={ShieldQuestion}>
|
||||
{#snippet message()}
|
||||
Allow use of
|
||||
|
||||
<span class="font-semibold">{toolName}</span>
|
||||
|
||||
{#if serverLabel}
|
||||
from <span class="font-semibold">{serverLabel}</span>
|
||||
{/if}
|
||||
|
||||
?
|
||||
{/snippet}
|
||||
|
||||
{#snippet actions()}
|
||||
<DropdownMenu.Root>
|
||||
<ButtonGroup.Root
|
||||
class="overflow-hidden rounded-md bg-foreground text-white shadow-sm dark:bg-secondary dark:text-foreground"
|
||||
>
|
||||
<Button
|
||||
class="rounded-none! shadow-none!"
|
||||
size="sm"
|
||||
onclick={() => onDecision(ToolPermissionDecision.ONCE)}
|
||||
>
|
||||
Allow once
|
||||
</Button>
|
||||
|
||||
<ButtonGroup.Separator />
|
||||
|
||||
<DropdownMenu.Trigger>
|
||||
<Button size="sm" class="rounded-none! !ps-2 shadow-none!">
|
||||
<ChevronDown class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</DropdownMenu.Trigger>
|
||||
</ButtonGroup.Root>
|
||||
|
||||
<DropdownMenu.Content align="start" class="min-w-[8rem]">
|
||||
<DropdownMenu.Item onclick={() => onDecision(ToolPermissionDecision.ALWAYS)}>
|
||||
Always allow <pre>{toolName}</pre>
|
||||
tool
|
||||
</DropdownMenu.Item>
|
||||
{#if serverLabel}
|
||||
<DropdownMenu.Item onclick={() => onDecision(ToolPermissionDecision.ALWAYS_SERVER)}>
|
||||
Always allow all tools from {serverLabel}
|
||||
</DropdownMenu.Item>
|
||||
{:else}
|
||||
{@const source = toolsStore.getToolSource(toolName)}
|
||||
{@const providerName =
|
||||
source === ToolSource.BUILTIN
|
||||
? TOOL_SERVER_LABELS[ToolSource.BUILTIN]
|
||||
: source === ToolSource.CUSTOM
|
||||
? TOOL_SERVER_LABELS[ToolSource.CUSTOM]
|
||||
: 'MCP Tools'}
|
||||
<DropdownMenu.Item onclick={() => onDecision(ToolPermissionDecision.ALWAYS_SERVER)}>
|
||||
Approve all tools from {providerName}
|
||||
</DropdownMenu.Item>
|
||||
{/if}
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
class="text-destructive hover:text-destructive"
|
||||
onclick={() => onDecision(ToolPermissionDecision.DENY)}
|
||||
>
|
||||
Deny
|
||||
</Button>
|
||||
{/snippet}
|
||||
</ChatMessageActionCard>
|
||||
@@ -1,11 +1,9 @@
|
||||
<script lang="ts">
|
||||
import { Card } from '$lib/components/ui/card';
|
||||
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
|
||||
import { getMessageEditContext } from '$lib/contexts';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import ChatMessageActions from './ChatMessageActions.svelte';
|
||||
import ChatMessageEditForm from './ChatMessageEditForm.svelte';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
import ChatMessageUserBubble from './ChatMessageUserBubble.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -44,34 +42,6 @@
|
||||
|
||||
// Get contexts
|
||||
const editCtx = getMessageEditContext();
|
||||
|
||||
let isMultiline = $state(false);
|
||||
let messageElement: HTMLElement | undefined = $state();
|
||||
const currentConfig = config();
|
||||
|
||||
$effect(() => {
|
||||
if (!messageElement || !message.content.trim()) return;
|
||||
|
||||
if (message.content.includes('\n')) {
|
||||
isMultiline = true;
|
||||
return;
|
||||
}
|
||||
|
||||
const resizeObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
const element = entry.target as HTMLElement;
|
||||
const estimatedSingleLineHeight = 24; // Typical line height for text-md
|
||||
|
||||
isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
|
||||
}
|
||||
});
|
||||
|
||||
resizeObserver.observe(messageElement);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
@@ -82,29 +52,11 @@
|
||||
{#if editCtx.isEditing}
|
||||
<ChatMessageEditForm />
|
||||
{:else}
|
||||
{#if message.extra && message.extra.length > 0}
|
||||
<div class="mb-2 max-w-[80%]">
|
||||
<ChatAttachmentsList attachments={message.extra} readonly imageHeight="h-80" />
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if message.content.trim()}
|
||||
<Card
|
||||
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 text-foreground backdrop-blur-md data-[multiline]:py-2.5 dark:bg-primary/15"
|
||||
data-multiline={isMultiline ? '' : undefined}
|
||||
style="max-height: var(--max-message-height); overflow-wrap: anywhere; word-break: break-word;"
|
||||
>
|
||||
{#if currentConfig.renderUserContentAsMarkdown}
|
||||
<div bind:this={messageElement}>
|
||||
<MarkdownContent class="markdown-user-content -my-4" content={message.content} />
|
||||
</div>
|
||||
{:else}
|
||||
<span bind:this={messageElement} class="text-md whitespace-pre-wrap">
|
||||
{message.content}
|
||||
</span>
|
||||
{/if}
|
||||
</Card>
|
||||
{/if}
|
||||
<ChatMessageUserBubble
|
||||
content={message.content}
|
||||
attachments={message.extra}
|
||||
renderMarkdown={true}
|
||||
/>
|
||||
|
||||
{#if message.timestamp}
|
||||
<div class="max-w-[80%]">
|
||||
|
||||
+76
@@ -0,0 +1,76 @@
|
||||
<script lang="ts">
|
||||
import { Card } from '$lib/components/ui/card';
|
||||
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import type { DatabaseMessageExtra } from '$lib/types/database';
|
||||
|
||||
interface Props {
|
||||
content: string;
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
renderMarkdown?: boolean;
|
||||
textColorClass?: string;
|
||||
cardBgClass?: string;
|
||||
maxHeightStyle?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
content,
|
||||
attachments = [],
|
||||
renderMarkdown = false,
|
||||
textColorClass = 'text-foreground',
|
||||
cardBgClass = 'dark:bg-primary/15',
|
||||
maxHeightStyle = 'max-height: var(--max-message-height);'
|
||||
}: Props = $props();
|
||||
|
||||
let isMultiline = $state(false);
|
||||
let messageElement: HTMLElement | undefined = $state();
|
||||
const currentConfig = config();
|
||||
|
||||
$effect(() => {
|
||||
if (!messageElement || !content.trim()) return;
|
||||
|
||||
if (content.includes('\n')) {
|
||||
isMultiline = true;
|
||||
return;
|
||||
}
|
||||
|
||||
const resizeObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
const element = entry.target as HTMLElement;
|
||||
const estimatedSingleLineHeight = 24; // Typical line height for text-md
|
||||
|
||||
isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
|
||||
}
|
||||
});
|
||||
|
||||
resizeObserver.observe(messageElement);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if attachments && attachments.length > 0}
|
||||
<div class="mb-2 max-w-[80%]">
|
||||
<ChatAttachmentsList {attachments} readonly imageHeight="h-40" />
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if content.trim()}
|
||||
<Card
|
||||
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-[multiline]:py-2.5 {cardBgClass}"
|
||||
data-multiline={isMultiline ? '' : undefined}
|
||||
style="{maxHeightStyle} overflow-wrap: anywhere; word-break: break-word;"
|
||||
>
|
||||
{#if renderMarkdown && currentConfig.renderUserContentAsMarkdown}
|
||||
<div bind:this={messageElement}>
|
||||
<MarkdownContent class="markdown-user-content -my-4" {content} />
|
||||
</div>
|
||||
{:else}
|
||||
<span bind:this={messageElement} class="text-md whitespace-pre-wrap">
|
||||
{content}
|
||||
</span>
|
||||
{/if}
|
||||
</Card>
|
||||
{/if}
|
||||
+71
@@ -0,0 +1,71 @@
|
||||
<script lang="ts">
|
||||
import { ActionIcon } from '$lib/components/app';
|
||||
import ChatMessageEditForm from './ChatMessageEditForm.svelte';
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import { ArrowUp, Edit, Trash2 } from '@lucide/svelte';
|
||||
import { getProcessingInfoContext } from '$lib/contexts';
|
||||
import { useMessageEditContext } from '$lib/hooks/use-message-edit-context.svelte';
|
||||
import ChatMessageUserBubble from './ChatMessageUserBubble.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
content: string;
|
||||
extras?: DatabaseMessageExtra[];
|
||||
onSendImmediately: () => void;
|
||||
onEdit: (newContent: string, extras?: DatabaseMessageExtra[]) => void;
|
||||
onDelete: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
content,
|
||||
extras = [],
|
||||
onSendImmediately,
|
||||
onEdit,
|
||||
onDelete
|
||||
}: Props = $props();
|
||||
|
||||
const processingInfoCtx = getProcessingInfoContext();
|
||||
let showProcessingInfo = $derived(processingInfoCtx.showProcessingInfo);
|
||||
|
||||
const editCtx = useMessageEditContext({
|
||||
getContent: () => content,
|
||||
getExtras: () => extras,
|
||||
onSave: (content, extras) => onEdit(content, extras)
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
use:fadeInView
|
||||
aria-label="Pending user message"
|
||||
class="group flex flex-col items-end gap-3 transition-opacity hover:opacity-80 md:gap-2 {className} sticky {showProcessingInfo
|
||||
? 'bottom-44'
|
||||
: 'bottom-32'}"
|
||||
role="group"
|
||||
>
|
||||
{#if editCtx.isEditing}
|
||||
<ChatMessageEditForm />
|
||||
{:else}
|
||||
<ChatMessageUserBubble
|
||||
{content}
|
||||
attachments={extras}
|
||||
textColorClass="text-muted-foreground"
|
||||
cardBgClass="dark:bg-primary/8"
|
||||
maxHeightStyle="overflow-wrap: anywhere; word-break: break-word;"
|
||||
/>
|
||||
|
||||
<div class="max-w-[80%]">
|
||||
<div class="relative flex h-6 items-center justify-between">
|
||||
<div class="right-0 flex items-center gap-2 opacity-100 transition-opacity">
|
||||
<div
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={Edit} tooltip="Edit" onclick={editCtx.handleEdit} />
|
||||
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
<ActionIcon icon={ArrowUp} tooltip="Send immediately" onclick={onSendImmediately} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -1,11 +1,23 @@
|
||||
<script lang="ts">
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import { ChatMessage } from '$lib/components/app';
|
||||
import ChatMessageUserPending from './ChatMessageUserPending.svelte';
|
||||
import { setChatActionsContext } from '$lib/contexts';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import {
|
||||
chatPendingMessageContent,
|
||||
chatPendingMessageExtras,
|
||||
chatClearPendingMessage,
|
||||
chatInjectPendingMessage
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import {
|
||||
agenticPendingSteeringMessageContent,
|
||||
agenticPendingSteeringMessageExtras,
|
||||
agenticClearSteeringMessage,
|
||||
agenticInjectSteeringMessage
|
||||
} from '$lib/stores/agentic.svelte';
|
||||
import {
|
||||
copyToClipboard,
|
||||
formatMessageForClipboard,
|
||||
@@ -14,12 +26,11 @@
|
||||
} from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
messages?: DatabaseMessage[];
|
||||
onUserAction?: () => void;
|
||||
}
|
||||
|
||||
let { class: className, messages = [], onUserAction }: Props = $props();
|
||||
let { messages = [], onUserAction }: Props = $props();
|
||||
|
||||
let allConversationMessages = $state<DatabaseMessage[]>([]);
|
||||
const currentConfig = config();
|
||||
@@ -196,19 +207,42 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="flex h-full flex-col space-y-10 pt-24 {className}"
|
||||
style="height: auto; min-height: calc(100dvh - 14rem);"
|
||||
>
|
||||
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
|
||||
<div use:fadeInView>
|
||||
<ChatMessage
|
||||
class="mx-auto w-full max-w-[48rem]"
|
||||
{message}
|
||||
{toolMessages}
|
||||
{isLastAssistantMessage}
|
||||
{siblingInfo}
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
|
||||
<ChatMessage
|
||||
class="mx-auto mt-12 w-full max-w-[48rem]"
|
||||
{message}
|
||||
{toolMessages}
|
||||
{isLastAssistantMessage}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{/each}
|
||||
|
||||
{#if activeConversation() && agenticPendingSteeringMessageContent(activeConversation()!.id)}
|
||||
{@const convId = activeConversation()!.id}
|
||||
{@const pendingContent = agenticPendingSteeringMessageContent(convId)}
|
||||
|
||||
{#if pendingContent}
|
||||
<ChatMessageUserPending
|
||||
class="mx-auto mt-12 w-full max-w-[48rem]"
|
||||
content={pendingContent}
|
||||
extras={agenticPendingSteeringMessageExtras(convId)}
|
||||
onSendImmediately={() => chatStore.abortCurrentFlow(convId)}
|
||||
onEdit={(newContent, extras) => agenticInjectSteeringMessage(convId, newContent, extras)}
|
||||
onDelete={() => agenticClearSteeringMessage(convId)}
|
||||
/>
|
||||
{/if}
|
||||
{:else if activeConversation() && chatPendingMessageContent(activeConversation()!.id)}
|
||||
{@const convId = activeConversation()!.id}
|
||||
{@const pendingContent = chatPendingMessageContent(convId)}
|
||||
|
||||
{#if pendingContent}
|
||||
<ChatMessageUserPending
|
||||
class="mx-auto mt-12 w-full max-w-[48rem]"
|
||||
content={pendingContent}
|
||||
extras={chatPendingMessageExtras(convId)}
|
||||
onSendImmediately={() => chatStore.abortCurrentFlow(convId)}
|
||||
onEdit={(newContent, extras) => chatInjectPendingMessage(convId, newContent, extras)}
|
||||
onDelete={() => chatClearPendingMessage(convId)}
|
||||
/>
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import {
|
||||
ChatScreenForm,
|
||||
ChatScreenHeader,
|
||||
ChatMessages,
|
||||
ChatScreenProcessingInfo,
|
||||
DialogEmptyFileAlert,
|
||||
@@ -12,15 +11,16 @@
|
||||
} from '$lib/components/app';
|
||||
import * as Alert from '$lib/components/ui/alert';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { KeyboardKey } from '$lib/enums';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
import {
|
||||
chatStore,
|
||||
errorDialog,
|
||||
isLoading,
|
||||
isChatStreaming,
|
||||
isEditing,
|
||||
getAddFilesHandler
|
||||
getAddFilesHandler,
|
||||
activeProcessingState
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import {
|
||||
conversationsStore,
|
||||
@@ -34,9 +34,11 @@
|
||||
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import { onMount } from 'svelte';
|
||||
import { fade, fly, slide } from 'svelte/transition';
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import { Trash2, AlertTriangle, RefreshCw } from '@lucide/svelte';
|
||||
import ChatScreenDragOverlay from './ChatScreenDragOverlay.svelte';
|
||||
import { page } from '$app/state';
|
||||
import { setProcessingInfoContext } from '$lib/contexts';
|
||||
|
||||
let { showCenteredEmpty = false } = $props();
|
||||
|
||||
@@ -79,6 +81,18 @@
|
||||
|
||||
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
|
||||
|
||||
let showProcessingInfo = $derived(
|
||||
isCurrentConversationLoading ||
|
||||
(config().keepStatsVisible && !!page.params.id) ||
|
||||
activeProcessingState() !== null
|
||||
);
|
||||
|
||||
setProcessingInfoContext({
|
||||
get showProcessingInfo() {
|
||||
return showProcessingInfo;
|
||||
}
|
||||
});
|
||||
|
||||
let isRouter = $derived(isRouterMode());
|
||||
|
||||
let conversationModel = $derived(
|
||||
@@ -208,20 +222,13 @@
|
||||
processFiles(files);
|
||||
}
|
||||
|
||||
function handleKeydown(event: KeyboardEvent) {
|
||||
const isCtrlOrCmd = event.ctrlKey || event.metaKey;
|
||||
|
||||
if (
|
||||
isCtrlOrCmd &&
|
||||
event.shiftKey &&
|
||||
(event.key === KeyboardKey.D_LOWER || event.key === KeyboardKey.D_UPPER)
|
||||
) {
|
||||
event.preventDefault();
|
||||
const { handleKeydown } = useKeyboardShortcuts({
|
||||
deleteActiveConversation: () => {
|
||||
if (activeConversation()) {
|
||||
showDeleteDialog = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
|
||||
if (draft.message || draft.files.length > 0) {
|
||||
@@ -342,9 +349,9 @@
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} />
|
||||
|
||||
<ChatScreenHeader />
|
||||
|
||||
{#if !isEmpty}
|
||||
{#if isServerLoading}
|
||||
<ServerLoadingSplash />
|
||||
{:else}
|
||||
<div
|
||||
bind:this={chatScrollContainer}
|
||||
aria-label="Chat interface with file drop zone"
|
||||
@@ -356,26 +363,42 @@
|
||||
onscroll={handleScroll}
|
||||
role="main"
|
||||
>
|
||||
<div class="flex flex-col">
|
||||
<ChatMessages
|
||||
class="mb-16 md:mb-24"
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
autoScroll.scrollToBottom();
|
||||
}}
|
||||
/>
|
||||
<div class="flex grow flex-col pt-14">
|
||||
{#if !isEmpty}
|
||||
<ChatMessages
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
autoScroll.scrollToBottom();
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<div
|
||||
class="pointer-events-none sticky right-0 bottom-4 left-0 mt-auto"
|
||||
in:slide={{ duration: 150, axis: 'y' }}
|
||||
class="pointer-events-none {isEmpty
|
||||
? 'absolute bottom-[calc(50dvh-7rem)]'
|
||||
: 'sticky bottom-4'} right-4 left-4 mt-auto pt-16 transition-all duration-200"
|
||||
>
|
||||
<ChatScreenProcessingInfo />
|
||||
{#if isEmpty}
|
||||
<div class="mb-8 px-4 text-center" use:fadeInView={{ duration: 300 }}>
|
||||
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
|
||||
|
||||
<p class="text-muted-foreground md:text-lg">
|
||||
{serverStore.props?.modalities?.audio
|
||||
? 'Record audio, type a message '
|
||||
: 'Type a message'} or upload files to get started
|
||||
</p>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if page.params.id}
|
||||
<ChatScreenProcessingInfo />
|
||||
{/if}
|
||||
|
||||
{#if hasPropsError}
|
||||
<div
|
||||
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
|
||||
in:fly={{ y: 10, duration: 250 }}
|
||||
use:fadeInView={{ y: 10, duration: 250 }}
|
||||
>
|
||||
<Alert.Root variant="destructive">
|
||||
<AlertTriangle class="h-4 w-4" />
|
||||
@@ -412,69 +435,6 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{:else if isServerLoading}
|
||||
<!-- Server Loading State -->
|
||||
<ServerLoadingSplash />
|
||||
{:else}
|
||||
<div
|
||||
aria-label="Welcome screen with file drop zone"
|
||||
class="flex h-full items-center justify-center"
|
||||
ondragenter={handleDragEnter}
|
||||
ondragleave={handleDragLeave}
|
||||
ondragover={handleDragOver}
|
||||
ondrop={handleDrop}
|
||||
role="main"
|
||||
>
|
||||
<div class="w-full max-w-[48rem] px-4">
|
||||
<div class="mb-10 text-center" in:fade={{ duration: 300 }}>
|
||||
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">llama.cpp</h1>
|
||||
|
||||
<p class="text-muted-foreground md:text-lg">
|
||||
{serverStore.props?.modalities?.audio
|
||||
? 'Record audio, type a message '
|
||||
: 'Type a message'} or upload files to get started
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{#if hasPropsError}
|
||||
<div class="mb-4" in:fly={{ y: 10, duration: 250 }}>
|
||||
<Alert.Root variant="destructive">
|
||||
<AlertTriangle class="h-4 w-4" />
|
||||
|
||||
<Alert.Title class="flex items-center justify-between">
|
||||
<span>Server unavailable</span>
|
||||
|
||||
<button
|
||||
onclick={() => serverStore.fetch()}
|
||||
disabled={isServerLoading}
|
||||
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
|
||||
>
|
||||
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
|
||||
{isServerLoading ? 'Retrying...' : 'Retry'}
|
||||
</button>
|
||||
</Alert.Title>
|
||||
|
||||
<Alert.Description>{serverError()}</Alert.Description>
|
||||
</Alert.Root>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div in:fly={{ y: 10, duration: 250, delay: hasPropsError ? 0 : 300 }}>
|
||||
<ChatScreenForm
|
||||
disabled={hasPropsError}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
showHelperText
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- File Upload Error Alert Dialog -->
|
||||
@@ -575,21 +535,3 @@
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
|
||||
/>
|
||||
|
||||
<style>
|
||||
.conversation-chat-form {
|
||||
position: relative;
|
||||
|
||||
&::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
z-index: -1;
|
||||
left: 0;
|
||||
right: 0;
|
||||
width: 100%;
|
||||
height: 2.375rem;
|
||||
background-color: var(--background);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user