mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 01:57:43 +02:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ecbf01d441 | |||
| 1025fd2c09 | |||
| c7358ddf64 | |||
| d284baf1b5 | |||
| bd90fc74c3 | |||
| ce38a4db47 | |||
| 4fdbc1e4db | |||
| 7b7ae857f6 | |||
| 84b0a98319 | |||
| b45ef2702c | |||
| f3dd7b8e68 | |||
| eed25bc6b0 | |||
| b33df266d0 | |||
| 3bcc990997 | |||
| d4964a7c66 | |||
| 50e8962f79 | |||
| f6b533d898 | |||
| 72d3b1898a | |||
| ebf5725870 | |||
| 0cd7032ca4 |
@@ -21,7 +21,8 @@ on:
|
||||
'**/*.m',
|
||||
'**/*.metal',
|
||||
'**/*.comp',
|
||||
'**/*.glsl'
|
||||
'**/*.glsl',
|
||||
'**/*.wgsl'
|
||||
]
|
||||
|
||||
pull_request:
|
||||
@@ -42,7 +43,8 @@ on:
|
||||
'**/*.m',
|
||||
'**/*.metal',
|
||||
'**/*.comp',
|
||||
'**/*.glsl'
|
||||
'**/*.glsl',
|
||||
'**/*.wgsl'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
@@ -1371,7 +1373,7 @@ jobs:
|
||||
id: update_presets
|
||||
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||
run: |
|
||||
cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
- name: Build
|
||||
id: ndk_build
|
||||
|
||||
@@ -28,16 +28,17 @@ jobs:
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
console.log("Latest release:", releases[0].tag_name);
|
||||
return releases[0].tag_name;
|
||||
const { tag_name: version, assets: assets } = releases.find(({assets}) => assets.find(asset => asset.name.includes('win-vulkan')));
|
||||
const { browser_download_url: asset_url } = assets.find(asset => asset.name.includes('win-vulkan'));
|
||||
console.log("Latest release:", version);
|
||||
core.setOutput('VERSION', version);
|
||||
core.setOutput('ASSETURL', asset_url);
|
||||
|
||||
- name: Update manifest
|
||||
env:
|
||||
VERSION: ${{ steps.find_latest_release.outputs.result }}
|
||||
run: |
|
||||
echo "Updating manifest..."
|
||||
komac update --version ${{ env.VERSION }} \
|
||||
--urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \
|
||||
komac update --version ${{ steps.find_latest_release.outputs.VERSION }} \
|
||||
--urls "${{ steps.find_latest_release.outputs.ASSETURL }}" \
|
||||
--token ${{ secrets.WINGET_GITHUB_TOKEN }} \
|
||||
--submit \
|
||||
ggml.llamacpp
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
/common/jinja/ @ngxson @CISC @aldehir
|
||||
/common/llguidance.* @ggerganov
|
||||
/common/log.* @ggerganov
|
||||
/common/ngram-map.* @srogmann
|
||||
/common/peg-parser.* @aldehir
|
||||
/common/sampling.* @ggerganov
|
||||
/common/speculative.* @ggerganov
|
||||
|
||||
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
|
||||
log.h
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
ngram-map.cpp
|
||||
ngram-map.h
|
||||
peg-parser.cpp
|
||||
peg-parser.h
|
||||
preset.cpp
|
||||
|
||||
+75
-14
@@ -6,6 +6,7 @@
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
#include "preset.h"
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
@@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
params.mmproj = res.mmproj;
|
||||
}
|
||||
// only download mmproj if the current example is using it
|
||||
for (auto & ex : mmproj_examples) {
|
||||
for (const auto & ex : mmproj_examples) {
|
||||
if (ctx_arg.ex == ex) {
|
||||
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
|
||||
break;
|
||||
}
|
||||
}
|
||||
common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
}
|
||||
|
||||
// model is required (except for server)
|
||||
@@ -1216,16 +1217,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-lcs", "--lookup-cache-static"}, "FNAME",
|
||||
"path to static lookup cache to use for lookup decoding (not updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_static = value;
|
||||
params.speculative.lookup_cache_static = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-lcd", "--lookup-cache-dynamic"}, "FNAME",
|
||||
"path to dynamic lookup cache to use for lookup decoding (updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_dynamic = value;
|
||||
params.speculative.lookup_cache_dynamic = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-c", "--ctx-size"}, "N",
|
||||
string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
|
||||
@@ -1300,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, bool value) {
|
||||
params.kv_unified = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED}));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
|
||||
add_opt(common_arg(
|
||||
{"--context-shift"},
|
||||
{"--no-context-shift"},
|
||||
@@ -2563,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
|
||||
"Same as --hf-repo, but for the draft model (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.model.hf_repo = value;
|
||||
params.speculative.mparams_dft.hf_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HFD_REPO"));
|
||||
add_opt(common_arg(
|
||||
@@ -3384,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-md", "--model-draft"}, "FNAME",
|
||||
"draft model for speculative decoding (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.model.path = value;
|
||||
params.speculative.mparams_dft.path = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT"));
|
||||
add_opt(common_arg(
|
||||
@@ -3394,6 +3395,66 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.replacements.push_back({ tgt, dft });
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]",
|
||||
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
|
||||
common_speculative_type_to_str(params.speculative.type).c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (value == "none") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
} else if (value == "ngram-cache") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
|
||||
} else if (value == "ngram-simple") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
|
||||
} else if (value == "ngram-map-k") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
|
||||
} else if (value == "ngram-map-k4v") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
|
||||
} else {
|
||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-n"}, "N",
|
||||
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_n = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-m"}, "N",
|
||||
string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_m = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-check-rate"}, "N",
|
||||
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram check rate must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_check_rate = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-min-hits"}, "N",
|
||||
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram min hits must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_min_hits = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
|
||||
string_format(
|
||||
@@ -3620,8 +3681,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
@@ -3636,8 +3697,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
|
||||
+161
-11
@@ -771,10 +771,12 @@ static std::string apply(
|
||||
|
||||
nlohmann::ordered_json inp = nlohmann::ordered_json{
|
||||
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
|
||||
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
|
||||
{"bos_token", tmpl.bos_token()},
|
||||
{"eos_token", tmpl.eos_token()},
|
||||
};
|
||||
if (tools_override.has_value() || !inputs.tools.empty()) {
|
||||
inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools;
|
||||
}
|
||||
if (inputs.extra_context.is_object()) {
|
||||
// TODO: do we need to merge, or replacing is fine?
|
||||
for (const auto & [k, v] : inputs.extra_context.items()) {
|
||||
@@ -790,9 +792,6 @@ static std::string apply(
|
||||
if (inputs.add_generation_prompt) {
|
||||
inp["add_generation_prompt"] = true;
|
||||
}
|
||||
if (inp["tools"].is_null()) {
|
||||
inp["tools"] = json::array();
|
||||
}
|
||||
|
||||
jinja::global_from_json(ctx, inp, inputs.mark_input);
|
||||
|
||||
@@ -2219,12 +2218,11 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
const std::optional<json> tools_override = json();
|
||||
const std::optional<json> additional_context = json {
|
||||
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, tools_override, additional_context);
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override =*/ std::nullopt, additional_context);
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
@@ -2573,20 +2571,165 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
|
||||
static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// TODO: Reasoning effort
|
||||
json additional_context = {};
|
||||
// Copy `reasoning_content` to `reasoning`
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["reasoning"] = msg.at("reasoning_content");
|
||||
adjusted_message.erase("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context);
|
||||
data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN;
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = true;
|
||||
|
||||
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
|
||||
// Check if we need to replace the flush token with end token during inference and without generation prompt.
|
||||
if (inputs.is_inference && !inputs.add_generation_prompt) {
|
||||
static constexpr std::string_view return_token = "<|flush|>";
|
||||
static constexpr std::string_view end_token = "<|end|>";
|
||||
if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) {
|
||||
prompt.replace(pos, return_token.length(), end_token);
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
"<|think|>",
|
||||
"<|content|>",
|
||||
"<|begin|>",
|
||||
"<|end|>",
|
||||
"<|tool_calls|>",
|
||||
"<|tool_call:begin|>",
|
||||
"<|tool_call:end|>",
|
||||
"<|tool_call:name|>",
|
||||
"<|tool_call:args|>",
|
||||
};
|
||||
|
||||
// TODO: Tool calling
|
||||
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
auto lit_think = p.atomic(p.literal("<|think|>"));
|
||||
auto lit_assistant_begin = p.atomic(p.literal("<|begin|>assistant"));
|
||||
auto lit_content = p.atomic(p.literal("<|content|>"));
|
||||
auto lit_end = p.atomic(p.literal("<|end|>"));
|
||||
auto parser_until_end = p.until("<|end|>");
|
||||
|
||||
// reasoning <- "<|think|>" (!"<|end|>" .)*
|
||||
auto parser_reasoning = p.rule("reasoning", lit_think + p.reasoning(parser_until_end));
|
||||
|
||||
// content <- "<|content|>" (!"<|end|>" .)*
|
||||
auto parser_content = p.rule("content", lit_content + p.content(parser_until_end));
|
||||
|
||||
// wrap_choice(items) <- item-choice wrapped*
|
||||
// item-choice <- items[0] / ... / items[n]
|
||||
// wrapped <- "<|end|><|begin|>assistant" item-choice
|
||||
auto wrap_choice = [&](const std::vector<common_peg_parser> & items) {
|
||||
auto choice = p.choice(items);
|
||||
return choice + p.zero_or_more(lit_end + lit_assistant_begin + choice);
|
||||
};
|
||||
|
||||
// wrap_seq(items) <- item[0] "<|end|><|begin|>assistant" item[1] ...
|
||||
auto wrap_seq = [&](const std::vector<common_peg_parser> & items) {
|
||||
auto seq = p.sequence();
|
||||
for (auto i = 0u; i < items.size(); i++) {
|
||||
if (i == 0) {
|
||||
seq += items[i];
|
||||
continue;
|
||||
}
|
||||
seq += lit_end + lit_assistant_begin + items[i];
|
||||
}
|
||||
return seq;
|
||||
};
|
||||
|
||||
// Response format parser
|
||||
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
|
||||
auto parser_response_format = lit_content + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
|
||||
return p.choice({
|
||||
wrap_seq({parser_reasoning, parser_response_format}),
|
||||
wrap_seq({parser_response_format})
|
||||
});
|
||||
}
|
||||
|
||||
auto lit_tool_call_begin = p.literal("<|tool_call:begin|>");
|
||||
auto lit_tool_call_name = p.literal("<|tool_call:name|>");
|
||||
auto lit_tool_call_args = p.literal("<|tool_call:args|>");
|
||||
auto lit_tool_call_end = p.literal("<|tool_call:end|>");
|
||||
|
||||
// Tool call parser
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
auto parser_tool_call = p.choice();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
const auto & schema = function.at("parameters");
|
||||
|
||||
// tool(name, schema) <- name "<|tool_call:args|>" schema
|
||||
parser_tool_call |= p.rule("tool-" + name,
|
||||
p.atomic(p.tool_name(p.literal(name)) + lit_tool_call_args)
|
||||
+ p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
|
||||
});
|
||||
|
||||
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
|
||||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
|
||||
// tool-calls <- "<|tool_calls|>" tool-call+
|
||||
// tool-call <- "<|tool_call:begin|> call-id "<|tool_call:name|>" &([^<]+ "<|tool_call:args|>") tool-choice "<|tool_call:end|>"
|
||||
// call-id <- [a-zA-Z0-9_-]+
|
||||
// tool-choice <- tool(t[0].name, t[0].schema) / ... / tool(t[n].name, t[n].schema)
|
||||
auto parser_tool_calls = p.trigger_rule("tool-calls",
|
||||
p.atomic(p.literal("<|tool_calls|>"))
|
||||
+ p.repeat(
|
||||
p.tool_open(
|
||||
lit_tool_call_begin
|
||||
+ p.tool_id(p.chars("[a-zA-Z0-9_-]", 1, -1))
|
||||
+ lit_tool_call_name
|
||||
+ p.peek(p.chars("[^<]", 1, -1) + lit_tool_call_args))
|
||||
+ parser_tool_call
|
||||
+ p.tool_close(lit_tool_call_end),
|
||||
/* min = */ 1,
|
||||
/* max = */ max_calls));
|
||||
|
||||
if (min_calls == 1) {
|
||||
// If required, then try any combination of the reasoning, content, and tool call
|
||||
return p.choice({
|
||||
wrap_seq({parser_reasoning, parser_content, parser_tool_calls}),
|
||||
wrap_seq({parser_reasoning, parser_tool_calls}),
|
||||
wrap_seq({parser_content, parser_tool_calls}),
|
||||
wrap_seq({parser_tool_calls})
|
||||
});
|
||||
}
|
||||
|
||||
return wrap_choice({parser_reasoning, parser_content, parser_tool_calls});
|
||||
}
|
||||
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
return wrap_choice({parser_reasoning, parser_content});
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls|>"}
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
@@ -3043,6 +3186,13 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_apriel_1_5(tmpl, params);
|
||||
}
|
||||
|
||||
// Solar Open
|
||||
if (src.find("<|tool_response:begin|>") != std::string::npos &&
|
||||
src.find("<|tool_response:name|>") != std::string::npos &&
|
||||
src.find("<|tool_response:result|>") != std::string::npos) {
|
||||
return common_chat_params_init_solar_open(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
|
||||
+4
-5
@@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) :
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
|
||||
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
|
||||
params.tensor_split,
|
||||
params.tensor_buft_overrides.data(),
|
||||
params.fit_params_target.data(),
|
||||
params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
@@ -1208,10 +1211,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
|
||||
void common_init_result::free_context() {
|
||||
pimpl->context.reset();
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
|
||||
|
||||
+46
-18
@@ -164,6 +164,16 @@ enum common_params_sampling_config : uint64_t {
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
||||
};
|
||||
|
||||
enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
|
||||
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
@@ -243,16 +253,35 @@ struct common_params_model {
|
||||
};
|
||||
|
||||
struct common_params_speculative {
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
// general-purpose speculative decoding parameters
|
||||
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
// ngram-based speculative decoding
|
||||
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
|
||||
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
|
||||
// draft-model speculative decoding
|
||||
|
||||
struct common_params_model mparams_dft;
|
||||
|
||||
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
|
||||
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
@@ -260,7 +289,14 @@ struct common_params_speculative {
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
|
||||
struct common_params_model model;
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
|
||||
bool has_dft() const {
|
||||
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
@@ -378,8 +414,6 @@ struct common_params {
|
||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
||||
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
||||
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
|
||||
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
||||
|
||||
// llama-debug specific options
|
||||
@@ -575,10 +609,6 @@ struct common_params {
|
||||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
|
||||
bool has_speculative() const {
|
||||
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@@ -714,8 +744,6 @@ struct common_init_result {
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & lora();
|
||||
|
||||
void free_context();
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
|
||||
@@ -1028,6 +1028,16 @@ const func_builtins & value_none_t::get_builtins() const {
|
||||
{"safe", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"strip", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"items", empty_value_fn<value_array>},
|
||||
{"map", empty_value_fn<value_array>},
|
||||
{"reject", empty_value_fn<value_array>},
|
||||
{"rejectattr", empty_value_fn<value_array>},
|
||||
{"select", empty_value_fn<value_array>},
|
||||
{"selectattr", empty_value_fn<value_array>},
|
||||
{"unique", empty_value_fn<value_array>},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
||||
@@ -192,12 +192,12 @@ void common_ngram_cache_draft(
|
||||
break;
|
||||
}
|
||||
|
||||
LOG(" - draft candidate: token=%d\n", drafted_token);
|
||||
LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
|
||||
draft.push_back(drafted_token);
|
||||
}
|
||||
}
|
||||
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
|
||||
std::ofstream file_out(filename, std::ios::binary);
|
||||
for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
|
||||
const common_ngram ngram = item.first;
|
||||
@@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
|
||||
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename) {
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename) {
|
||||
std::ifstream hashmap_file(filename, std::ios::binary);
|
||||
if (!hashmap_file) {
|
||||
throw std::ifstream::failure("Unable to open file " + filename);
|
||||
|
||||
@@ -88,12 +88,12 @@ void common_ngram_cache_draft(
|
||||
// Save an ngram cache to a file.
|
||||
// ngram_cache: the ngram cache to save.
|
||||
// filename: the path under which to save the ngram cache.
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);
|
||||
|
||||
// Load an ngram cache saved with common_ngram_cache_save.
|
||||
// filename: the path from which to load the ngram cache.
|
||||
// returns: an ngram cache containing the information saved to filename.
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename);
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename);
|
||||
|
||||
// Merge two ngram caches.
|
||||
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
|
||||
|
||||
@@ -0,0 +1,367 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "ngram-map.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
/**
|
||||
* Perform speculative generation using the model's own token history.
|
||||
* Searches for a matching pattern in the token history and returns draft tokens.
|
||||
*
|
||||
* @param state Current state of this implementation
|
||||
* @param tokens Token history to search in
|
||||
* @param sampled Last sampled token
|
||||
* @return Vector of draft tokens, empty if no matching pattern is found
|
||||
*/
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled) {
|
||||
|
||||
// Simple implementation of self-speculative decoding without a draft model.
|
||||
//
|
||||
const size_t cur_len = tokens.size();
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (state.idx_last_check + state.config.check_rate > cur_len) {
|
||||
llama_tokens draft_tokens;
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
|
||||
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
|
||||
|
||||
// vector for tokens we want to verify.
|
||||
// return empty vector if there is no match.
|
||||
llama_tokens draft_tokens;
|
||||
|
||||
// We need at least n_draft_min + n_draft_max + 1 tokens.
|
||||
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
// pattern search
|
||||
llama_tokens pattern;
|
||||
pattern.reserve(n_draft_min);
|
||||
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
|
||||
pattern.push_back(tokens[j]);
|
||||
}
|
||||
pattern.push_back(sampled); // add the last token to the pattern
|
||||
|
||||
// We do a search in the token history.
|
||||
state.idx_last_check = cur_len;
|
||||
|
||||
size_t match_pos = 0; // we ignore position 0, position 0 == no match
|
||||
// search backwards, but skip the current match (we are currently there)
|
||||
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < pattern.size(); ++k) {
|
||||
if (tokens[j + k] != pattern[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos == 0) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
const size_t copy_max = std::min(
|
||||
n_draft_max,
|
||||
cur_len - (match_pos + n_draft_min)
|
||||
);
|
||||
if (copy_max < n_draft_min) {
|
||||
return draft_tokens;
|
||||
}
|
||||
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
|
||||
__func__, cur_len,
|
||||
match_pos, pattern.size(), copy_max);
|
||||
|
||||
draft_tokens.reserve(copy_max);
|
||||
for (size_t j = 0; j < copy_max; ++j) {
|
||||
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
|
||||
}
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of counted values of a ngram map value.
|
||||
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
|
||||
|
||||
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
|
||||
|
||||
void common_ngram_map_draft(common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft) {
|
||||
// reset last key and value.
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = 0;
|
||||
map.last_draft_value_idx = 0;
|
||||
|
||||
const size_t cur_len = inp.size();
|
||||
const uint16_t n = map.size_key;
|
||||
const uint16_t m = map.size_value;
|
||||
if (cur_len < static_cast<size_t>(2 * n + m)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (map.idx_last_check + map.check_rate > cur_len) {
|
||||
return;
|
||||
}
|
||||
map.idx_last_check = cur_len;
|
||||
|
||||
// search pattern, the key n-gram
|
||||
std::vector<llama_token> key_tokens;
|
||||
key_tokens.reserve(n);
|
||||
for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
|
||||
key_tokens.push_back(inp[j]);
|
||||
}
|
||||
key_tokens.push_back(sampled);
|
||||
|
||||
// search for the key in the map
|
||||
size_t match_pos = 0;
|
||||
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[j + k] != key_tokens[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos > 0) {
|
||||
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
|
||||
cur_len, n, m, key_tokens.size(), sampled, match_pos);
|
||||
}
|
||||
|
||||
if (match_pos == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We have a match, now we look for the statistics of the key.
|
||||
size_t key_offset = map.keys.size(); // offset in the map
|
||||
// We iterate through the std::vector<common_ngram_map_key> map->keys.
|
||||
for (size_t i = 0; i < map.keys.size(); ++i) {
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < n; ++j) {
|
||||
if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
key_offset = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (key_offset == map.keys.size()) {
|
||||
// We create a new key-entry, it will get offset key_offset.
|
||||
common_ngram_map_key new_key;
|
||||
new_key.key_idx = match_pos;
|
||||
new_key.stat_idx = 0;
|
||||
new_key.key_num = 0;
|
||||
for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
|
||||
new_key.values[i].value_num = 0;
|
||||
new_key.values[i].n_accepted = m;
|
||||
}
|
||||
map.keys.push_back(new_key);
|
||||
}
|
||||
|
||||
// our key n-gram:
|
||||
common_ngram_map_key & curr_key = map.keys[key_offset];
|
||||
|
||||
// update number of key hits
|
||||
curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
|
||||
if (map.key_only) {
|
||||
// simple mode:
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
// We work with value values[0] only.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = 0; // value 0 is used for simple mode
|
||||
return;
|
||||
}
|
||||
|
||||
if (curr_key.key_num < map.min_hits) {
|
||||
// not enough hits to consider this a good draft
|
||||
LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
|
||||
key_offset, curr_key.key_num, map.min_hits);
|
||||
return;
|
||||
}
|
||||
|
||||
// complex mode: examine the different m-grams after this key n-gram.
|
||||
//
|
||||
|
||||
// determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
|
||||
for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
|
||||
// begins the key n-gram at index i?
|
||||
bool match_key = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[i + k] != key_tokens[k]) {
|
||||
match_key = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do we haven a existing value m-gram or a new one after the key at index i?
|
||||
size_t idx_begin_value_key = i + n;
|
||||
int idx_value = -1;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
size_t idx_begin_value_v = curr_key.values[v].value_idx;
|
||||
if (idx_begin_value_v == 0) {
|
||||
// We found an empty value slot => we found a new value m-gram after the key n-gram.
|
||||
curr_key.values[v].value_idx = idx_begin_value_key;
|
||||
curr_key.values[v].value_num = 0;
|
||||
curr_key.values[v].n_accepted = m;
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < m; ++j) {
|
||||
if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
// We found an existing value m-gram after the key n-gram.
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx_value >= 0) {
|
||||
// We found a value m-gram of the key n-gram.
|
||||
curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
}
|
||||
}
|
||||
// the statistics are updated up to match_pos.
|
||||
curr_key.stat_idx = match_pos;
|
||||
|
||||
// Do we have a value we could use for the draft?
|
||||
uint16_t max_occur = 0;
|
||||
int slot_max = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
if (curr_occur > max_occur) {
|
||||
max_occur = curr_occur;
|
||||
slot_max = v;
|
||||
}
|
||||
}
|
||||
// What is sum of the other occurences?
|
||||
uint32_t sum_occur = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (v == slot_max) {
|
||||
continue;
|
||||
}
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
sum_occur += curr_occur;
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
|
||||
key_offset,
|
||||
max_occur, sum_occur, slot_max,
|
||||
curr_key.values[0].value_idx, curr_key.values[0].value_num,
|
||||
curr_key.values[1].value_idx, curr_key.values[1].value_num,
|
||||
curr_key.values[2].value_idx, curr_key.values[2].value_num,
|
||||
curr_key.values[3].value_idx, curr_key.values[3].value_num
|
||||
);
|
||||
// Print the tokens of the four values (if idx != 0), use LOG_INF
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (curr_key.values[v].value_idx != 0) {
|
||||
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
|
||||
// The most frequent value is not much more frequent than the other values.
|
||||
// We do not use the draft.
|
||||
return;
|
||||
}
|
||||
|
||||
// We use the most frequent value values[slot_max] for the draft.
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, slot_max,
|
||||
curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = true;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = slot_max; // value used for draft generation.
|
||||
}
|
||||
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
|
||||
if (!map.last_draft_created) {
|
||||
return;
|
||||
}
|
||||
|
||||
// find the key and its chosen value.
|
||||
const size_t key_idx = map.last_draft_key_idx;
|
||||
const size_t val_idx = map.last_draft_value_idx;
|
||||
|
||||
// find key corresponding to key_idx.
|
||||
common_ngram_map_key & curr_key = map.keys[key_idx];
|
||||
// find value corresponding to val_idx.
|
||||
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
|
||||
|
||||
// update the value statistics
|
||||
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
|
||||
n_accepted, curr_value.n_accepted);
|
||||
curr_value.n_accepted = n_accepted;
|
||||
}
|
||||
|
||||
// Helper functions.
|
||||
//
|
||||
|
||||
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
|
||||
std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
|
||||
std::ostringstream oss;
|
||||
oss << '[';
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (i > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << inp[start + i];
|
||||
}
|
||||
oss << ']';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
#pragma once
|
||||
//
|
||||
// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
|
||||
//
|
||||
// These structures are used to do a lookup of n-grams followed by m-grams in token history.
|
||||
//
|
||||
// There are two algorithms implemented:
|
||||
// 1. ngram_simple: lookup of n-grams followed by m-grams in token history.
|
||||
// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
|
||||
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
|
||||
//
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
// config of n-gram simple.
|
||||
struct common_ngram_simple_config {
|
||||
uint16_t size_ngram; // size of n-grams to lookup in self-mode
|
||||
uint16_t size_mgram; // size of m-grams to draft in self-mode
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
};
|
||||
|
||||
// current state (and config) of n-gram simple.
|
||||
struct common_ngram_simple_state {
|
||||
common_ngram_simple_config config;
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history (mutable)
|
||||
|
||||
common_ngram_simple_state(const common_ngram_simple_config & config)
|
||||
: config(config) {}
|
||||
};
|
||||
|
||||
// Searches for a n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// state: the ngram simple state to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled);
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of m-gram values stored for each key n-gram.
|
||||
#define COMMON_NGRAM_MAX_VALUES 4
|
||||
|
||||
// statistics of a m-gram after a known n-gram
|
||||
struct common_ngram_map_value {
|
||||
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
|
||||
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
|
||||
};
|
||||
|
||||
// statistics of a n-gram
|
||||
struct common_ngram_map_key {
|
||||
size_t key_idx; // index of key n-gram in token-history
|
||||
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
|
||||
|
||||
uint16_t key_num; // number of occurences of this key n-gram in token-history
|
||||
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
|
||||
};
|
||||
|
||||
// map from n-grams to following m-grams in token-history
|
||||
struct common_ngram_map {
|
||||
uint16_t size_key; // size of key n-grams
|
||||
uint16_t size_value; // size of value m-grams
|
||||
|
||||
bool key_only; // true if only key n-grams are used, no values.
|
||||
|
||||
// first draft: vector only, no map.
|
||||
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
uint16_t min_hits; // minimum number of key hits to consider a draft
|
||||
|
||||
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
|
||||
uint16_t check_rate, uint16_t min_hits)
|
||||
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
|
||||
check_rate(check_rate), min_hits(min_hits) {}
|
||||
|
||||
bool last_draft_created = false; // true if a draft was created at last call.
|
||||
size_t last_draft_key_idx = 0; // index of last key used for draft generation.
|
||||
uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history
|
||||
};
|
||||
|
||||
|
||||
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// map: the ngram map to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
void common_ngram_map_draft(
|
||||
common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft);
|
||||
|
||||
// Update the statistics of a value after a draft was processed.
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);
|
||||
+766
-246
File diff suppressed because it is too large
Load Diff
+23
-21
@@ -5,31 +5,33 @@
|
||||
|
||||
struct common_speculative;
|
||||
|
||||
struct common_speculative_params {
|
||||
int n_draft = 16; // max drafted tokens
|
||||
int n_reuse = 256;
|
||||
// comma separated list of all types
|
||||
std::string common_speculative_type_name_str();
|
||||
|
||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
};
|
||||
// convert string to type
|
||||
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_tgt,
|
||||
struct llama_context * ctx_dft
|
||||
);
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec);
|
||||
common_speculative * common_speculative_init(
|
||||
const common_params_speculative & params,
|
||||
llama_context * ctx_tgt);
|
||||
|
||||
bool common_speculative_are_compatible(
|
||||
const struct llama_context * ctx_tgt,
|
||||
const struct llama_context * ctx_dft);
|
||||
void common_speculative_free(common_speculative * spec);
|
||||
|
||||
void common_speculative_add_replacement_tgt_dft(
|
||||
struct common_speculative * spec,
|
||||
const char *source, const char *dest);
|
||||
// optionally call once at the beginning of a new generation
|
||||
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
|
||||
|
||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
||||
llama_tokens common_speculative_gen_draft(
|
||||
struct common_speculative * spec,
|
||||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
llama_tokens common_speculative_draft(
|
||||
common_speculative * spec,
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
@@ -8912,13 +8912,16 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||
name.endswith("block_sparse_moe.input_linear.weight")
|
||||
or "shared_mlp" in name
|
||||
):
|
||||
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
|
||||
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
|
||||
return
|
||||
|
||||
# Determine whether this is a mamba layer or an attention layer
|
||||
if bid in self._ssm_layers:
|
||||
return Mamba2Model.modify_tensors(self, data_torch, name, bid)
|
||||
yield from Mamba2Model.modify_tensors(self, data_torch, name, bid)
|
||||
return
|
||||
elif bid in self._attn_layers:
|
||||
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
|
||||
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
|
||||
return
|
||||
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
|
||||
@@ -35,9 +35,9 @@ The following releases are verified and recommended:
|
||||
|
||||
|Commit ID|Tag|Release|Verified Platform| Update date|
|
||||
|-|-|-|-|-|
|
||||
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |ArcB580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|
||||
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|
||||
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
|
||||
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |Arc B580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|
||||
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|
||||
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
|
||||
|
||||
|
||||
## News
|
||||
@@ -51,7 +51,7 @@ The following releases are verified and recommended:
|
||||
|-|-|-|-|
|
||||
|PVC 1550|39|73|+87%|
|
||||
|Flex 170|39|50|+28%|
|
||||
|Arc770|42|55|+30%|
|
||||
|Arc A770|42|55|+30%|
|
||||
|MTL|13|16|+23%|
|
||||
|ARL-H|14|17|+21%|
|
||||
|
||||
@@ -62,7 +62,7 @@ The following releases are verified and recommended:
|
||||
- Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs.
|
||||
|
||||
- 2024.5
|
||||
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770.
|
||||
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc A770.
|
||||
- Arch Linux is verified successfully.
|
||||
|
||||
- 2024.4
|
||||
@@ -111,7 +111,8 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
|
||||
|-------------------------------|---------|---------------------------------------|
|
||||
| Intel Data Center Max Series | Support | Max 1550, 1100 |
|
||||
| Intel Data Center Flex Series | Support | Flex 170 |
|
||||
| Intel Arc Series | Support | Arc 770, 730M, Arc A750, B580 |
|
||||
| Intel Arc A-Series | Support | Arc A770, Arc A730M, Arc A750 |
|
||||
| Intel Arc B-Series | Support | Arc B580 |
|
||||
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake |
|
||||
| Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 |
|
||||
|
||||
|
||||
+18
-3
@@ -1,5 +1,10 @@
|
||||
{
|
||||
"version": 4,
|
||||
"version": 5,
|
||||
"cmakeMinimumRequired": {
|
||||
"major": 3,
|
||||
"minor": 28,
|
||||
"patch": 0
|
||||
},
|
||||
"configurePresets": [
|
||||
{
|
||||
"name": "arm64-android-snapdragon",
|
||||
@@ -16,7 +21,9 @@
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "android_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
@@ -31,7 +38,15 @@
|
||||
"name": "arm64-windows-snapdragon",
|
||||
"inherits": [ "base", "arm64-windows-llvm" ],
|
||||
"cacheVariables": {
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "windows_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
@@ -1,6 +1,8 @@
|
||||
# Snapdragon-based Android devices
|
||||
# Snapdragon-based devices
|
||||
|
||||
## How to Build
|
||||
## Setup
|
||||
|
||||
### Android
|
||||
|
||||
The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain).
|
||||
This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc.
|
||||
@@ -12,7 +14,24 @@ This method works on Linux, macOS, and Windows. macOS and Windows users should i
|
||||
[d]/> cd /workspace
|
||||
```
|
||||
|
||||
The rest of the Android build process assumes that you're running inside the toolchain container.
|
||||
Note: The rest of the **Android** build process assumes that you're running inside the toolchain container.
|
||||
|
||||
### Windows On Snapdragon
|
||||
|
||||
Native Windows 11 arm64 builds has the following tools dependencies:
|
||||
- MS Visual Studio 2026 (Community Edition or Pro)
|
||||
- MSVC arm64 standard and runtime libraries
|
||||
- UCRT and Driver Kit
|
||||
- LLVM core libraries and Clang compiler (winget)
|
||||
- CMake, Git, Python (winget)
|
||||
- Hexagon SDK Community Edition 6.4 or later (see windows.md)
|
||||
- OpenCL SDK 2.3 or later (see windows.md)
|
||||
|
||||
Note: The rest of the **Windows** build process assumes that you're running natively in Powershell.
|
||||
Adapt below build commands accordingly.
|
||||
|
||||
## How to Build
|
||||
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
@@ -49,24 +68,26 @@ Preset CMake variables:
|
||||
To generate an installable "package" simply use cmake --install:
|
||||
|
||||
```
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon/llama.cpp
|
||||
-- Install configuration: "Release"
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-cpu.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-opencl.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-hexagon.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v73.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v75.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v79.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v81.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml.so
|
||||
...
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-bench
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-cli
|
||||
...
|
||||
```
|
||||
|
||||
## How to Install
|
||||
|
||||
### Android
|
||||
|
||||
For this step, your device needs to be configured for on-device development.
|
||||
Please see https://developer.android.com/studio/debug/dev-options for details.
|
||||
|
||||
@@ -74,10 +95,10 @@ Once ADB is enabled, use `adb push` to install `pkg-snapdragon` on the device.
|
||||
**Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.**
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/
|
||||
pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
|
||||
pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
|
||||
pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
|
||||
~/src/llama.cpp$ adb push pkg-snapdragon/llama.cpp /data/local/tmp/
|
||||
pkg-snapdragon/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
|
||||
pkg-snapdragon/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
|
||||
pkg-snapdragon/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
|
||||
102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s)
|
||||
```
|
||||
|
||||
@@ -92,6 +113,11 @@ At this point, you should also install some models:
|
||||
Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s)
|
||||
```
|
||||
|
||||
### Windows
|
||||
|
||||
All artifacts are already installed in the `pkg-snapdragon` folder.
|
||||
To run, adapt below instructions to use Powershell scrits in `scripts/snapdragon/windows`.
|
||||
|
||||
## How to Run
|
||||
|
||||
The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables.
|
||||
@@ -0,0 +1,161 @@
|
||||
## Overview
|
||||
|
||||
The document covers procedures for installing the latest GPU and NPU drivers, and OpenCL and Hexagon SDKs.
|
||||
|
||||
|
||||
In order to use Hexagon NPU on Snapdragon Windows devices the underlying HTP Ops libraries (e.g libggml-htp-v73.so)
|
||||
must be included in the .cat file digitally signed with a trusted certificate.
|
||||
|
||||
This document covers details on how to generate personal certificate files (.pfx) and how to configure the system
|
||||
to allow for test signatures (aka test-signing).
|
||||
|
||||
## Install the latest Adreno OpenCL SDK
|
||||
|
||||
Either use the trimmed down version (optimized for CI) from
|
||||
|
||||
https://github.com/snapdragon-toolchain/opencl-sdk/releases/download/v2.3.2/adreno-opencl-sdk-v2.3.2-arm64-wos.tar.xz
|
||||
|
||||
Or download the complete official version from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Adreno_OpenCL_SDK?version=2.3.2
|
||||
|
||||
Unzip/untar the archive into
|
||||
```
|
||||
c:\Qualcomm\OpenCL_SDK\2.3.2
|
||||
```
|
||||
|
||||
## Install the latest Hexagon SDK Community Edition
|
||||
|
||||
Either use the trimmed down version (optimized for CI) from
|
||||
|
||||
https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v6.4.0.2/hexagon-sdk-v6.4.0.2-arm64-wos.tar.xz
|
||||
|
||||
Or download the complete official version from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Hexagon_SDK?version=6.4.0.2
|
||||
|
||||
Unzip/untar the archive into
|
||||
```
|
||||
c:\Qualcomm\Hexagon_SDK\6.4.0.2
|
||||
```
|
||||
|
||||
## Install the latest Adreno GPU driver
|
||||
|
||||
Download the driver from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Windows_Graphics_Driver
|
||||
|
||||
After the automated installation and reboot please make sure that the GPU device shows up in the `Device Manager` (under 'Display Adapters`)
|
||||
|
||||
## Install the latest Qualcomm NPU driver
|
||||
|
||||
Download the driver from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Qualcomm_HND
|
||||
|
||||
After the automated installation and reboot please make sure that the Hexagon NPU device shows up in the `Device Manager` (under `Neural Processors`).
|
||||
|
||||
If the device is not available you can try installing all components (`qcnspmcdm8380`, `qcnspmcdm8380_ext`) manually.
|
||||
The components are extracted into
|
||||
```
|
||||
c:\QCDrivers\qcnspmcdm...
|
||||
```
|
||||
|
||||
## Enable NPU driver test signatures
|
||||
|
||||
Please note that the following steps are required only for the Hexagon NPU.
|
||||
Adreno GPU backend does not require test signatures.
|
||||
|
||||
### Enable testsigning
|
||||
|
||||
Use `bcdedit` to enable test-signing
|
||||
```
|
||||
> bcdedit /set TESTSIGNING ON
|
||||
```
|
||||
(Secure Boot may need to be disabled for this to work)
|
||||
|
||||
Make sure test-signing is enabled after reboot
|
||||
```
|
||||
> bcdedit /enum
|
||||
...
|
||||
testsigning Yes
|
||||
...
|
||||
```
|
||||
For additional details see Microsoft guide at
|
||||
|
||||
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/the-testsigning-boot-configuration-option
|
||||
|
||||
### Create personal certificate
|
||||
|
||||
The tools required for this procedure are available as part of Windows SDK and Windows Driver Kit which should be
|
||||
installed as part of the MS Visual Studio.
|
||||
They are typically located at
|
||||
```
|
||||
c:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0
|
||||
```
|
||||
(replace 10.0.26100.0 with correct version).
|
||||
|
||||
To create personal self-signed certificate run the following commands (either from cmd or power-shell):
|
||||
```
|
||||
> cd c:\Users\MyUser
|
||||
> mkdir Certs
|
||||
> cd Certs
|
||||
> makecert -r -pe -ss PrivateCertStore -n CN=GGML.HTP.v1 -eku 1.3.6.1.5.5.7.3.3 -sv ggml-htp-v1.pvk ggml-htp-v1.cer
|
||||
> pvk2pfx.exe -pvk ggml-htp-v1.pvk -spc ggml-htp-v1.cer -pfx ggml-htp-v1.pfx
|
||||
```
|
||||
(replace `MyUser` with your username).
|
||||
|
||||
Add this certificate to `Trusted Root Certification Authorities` and `Trusted Publishers` stores.
|
||||
This can be done using `certlm` Certificate Manager tool.
|
||||
Right click on the certificate store, select `All Tasks -> Import` and follow the prompts to import the certificate from the
|
||||
PFX file you created above.
|
||||
|
||||
For additional details see Microsoft guide at
|
||||
|
||||
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/introduction-to-test-signing
|
||||
|
||||
Make sure to save the PFX file, you will need it for the build procedures.
|
||||
Please note that the same certificate can be used for signing any number of builds.
|
||||
|
||||
## Build Hexagon backend with signed HTP ops libraries
|
||||
|
||||
The overall Hexagon backend build procedure for Windows on Snapdragon is the same as for other platforms.
|
||||
However, additional settings are required for generating and signing HTP Ops libraries.
|
||||
```
|
||||
> $env:OPENCL_SDK_ROOT="C:\Qualcomm\OpenCL_SDK\2.3.2"
|
||||
> $env:HEXAGON_SDK_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2"
|
||||
> $env:HEXAGON_TOOLS_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2\tools\HEXAGON_Tools\19.0.04"
|
||||
> $env:HEXAGON_HTP_CERT="c:\Users\MyUsers\Certs\ggml-htp-v1.pfx"
|
||||
> $env:WINDOWS_SDK_BIN="C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\arm64"
|
||||
|
||||
> cmake --preset arm64-windows-snapdragon -B build-wos
|
||||
...
|
||||
> cmake --install build-wos --prefix pkg-snapdragon
|
||||
```
|
||||
|
||||
Once the build is complete HTP ops libraries will be installed like this
|
||||
```
|
||||
> dir pkg-snapdragon/lib
|
||||
...
|
||||
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v73.so
|
||||
-a---- 1/22/2026 6:01 PM 191752 libggml-htp-v75.so
|
||||
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v79.so
|
||||
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v81.so
|
||||
-a---- 1/22/2026 6:01 PM 4139 libggml-htp.cat
|
||||
```
|
||||
|
||||
The .cat file, the signature and proper certicate installation can be verified with
|
||||
|
||||
```
|
||||
> signtool.exe verify /v /pa .\pkg-snapdragon\lib\libggml-htp.cat
|
||||
Verifying: .\pkg-snapdragon\lib\libggml-htp.cat
|
||||
|
||||
Signature Index: 0 (Primary Signature)
|
||||
Hash of file (sha256): 9820C664DA59D5EAE31DBB664127FCDAEF59CDC31502496BC567544EC2F401CF
|
||||
|
||||
Signing Certificate Chain:
|
||||
Issued to: GGML.HTP.v1
|
||||
...
|
||||
Successfully verified: .\pkg-snapdragon\lib\libggml-htp.cat
|
||||
...
|
||||
```
|
||||
+2
-2
@@ -97,7 +97,7 @@ Legend:
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
@@ -114,7 +114,7 @@ Legend:
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
|
||||
+763
-605
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,120 @@
|
||||
# Speculative Decoding
|
||||
|
||||
llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model.
|
||||
|
||||
[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct.
|
||||
|
||||
## Implementations
|
||||
|
||||
The `llama-server` application supports several implementations of speculative decoding:
|
||||
|
||||
### Draft Model (`draft`)
|
||||
|
||||
A much smaller model (called the _draft model_) generates drafts.
|
||||
A draft model is the most used approach in speculative decoding.
|
||||
|
||||
### n-gram Cache (`ngram-cache`)
|
||||
|
||||
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
|
||||
A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy.
|
||||
|
||||
See:
|
||||
|
||||
- #5479, #6828, #6848
|
||||
|
||||
### n-gram Map (`ngram-simple`, `ngram-map-*`)
|
||||
|
||||
These implementations search the token history for patterns and use matching sequences as draft candidates.
|
||||
They require no additional model but rely on patterns that have already appeared in the generated text.
|
||||
An example to use this approach can be the rewriting of source code by a LLM.
|
||||
|
||||
#### n-gram Map (`ngram-simple`)
|
||||
|
||||
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
|
||||
|
||||
#### n-gram Map Key (`ngram-map-k`)
|
||||
|
||||
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`) before generating drafts.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
#### n-gram Map Key-4-Values (`ngram-map-k4v`)
|
||||
|
||||
This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
**Example:** Server options to be used if there are a lot of longer repetitions.
|
||||
```bash
|
||||
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2
|
||||
```
|
||||
|
||||
|
||||
## Command-Line Options
|
||||
|
||||
If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.
|
||||
|
||||
```
|
||||
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]
|
||||
type of speculative decoding to use when no draft model is provided
|
||||
(default: none)
|
||||
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
|
||||
of lookup n-gram (default: 12)
|
||||
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
|
||||
of draft m-gram (default: 48)
|
||||
--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding
|
||||
(default: 1)
|
||||
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### `--spec-type TYPE`
|
||||
|
||||
Specifies a type of speculative decoding without draft model.
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `ngram-cache` | Use n-gram cache lookup |
|
||||
| `ngram-simple` | Use simple n-gram pattern matching |
|
||||
| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys |
|
||||
| `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) |
|
||||
|
||||
**Example:** Server-instance used to refactor source code.
|
||||
```bash
|
||||
./llama-server [...] --spec-type ngram-simple
|
||||
```
|
||||
|
||||
### `--spec-ngram-size-n N`
|
||||
|
||||
Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
|
||||
The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
|
||||
|
||||
### `--spec-ngram-size-m M`
|
||||
|
||||
Sets the size M of the draft m-gram for n-gram map based speculative decoding.
|
||||
The m-gram size determines how many tokens to draft when a match is found.
|
||||
Larger values can provide more speedup but may reduce acceptance rate.
|
||||
|
||||
### `--spec-ngram-check-rate R`
|
||||
|
||||
This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token).
|
||||
|
||||
### `--spec-ngram-min-hits H`
|
||||
|
||||
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
|
||||
|
||||
## Statistics
|
||||
Each speculative decoding implementation prints statistics.
|
||||
|
||||
```
|
||||
draft acceptance rate = 0.57576 ( 171 accepted / 297 generated)
|
||||
statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73
|
||||
statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
|
||||
```
|
||||
|
||||
- `#calls`: number of calls of this implementations
|
||||
- `#gen drafts`: number of drafts generated by this implementation
|
||||
- `#acc drafts`: number of drafts accepted (partially) by the main model
|
||||
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
|
||||
- `#acc tokens`: number of tokens accepted by the main model
|
||||
|
||||
@@ -32,9 +32,9 @@ int main(int argc, char ** argv){
|
||||
|
||||
common_ngram_cache ngram_cache;
|
||||
common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str());
|
||||
|
||||
common_ngram_cache_save(ngram_cache, params.lookup_cache_static);
|
||||
common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -46,18 +46,18 @@ int main(int argc, char ** argv){
|
||||
{
|
||||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
|
||||
if (!params.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
|
||||
@@ -51,18 +51,18 @@ int main(int argc, char ** argv){
|
||||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
|
||||
|
||||
if (!params.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ int main(int argc, char ** argv){
|
||||
|
||||
// Update dynamic ngram cache with context ngram cache and save it to disk:
|
||||
common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic);
|
||||
|
||||
LOG("\n\n");
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.speculative.model.path.empty()) {
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -34,10 +34,8 @@ int main(int argc, char ** argv) {
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model_tgt = NULL;
|
||||
//llama_model * model_dft = NULL;
|
||||
|
||||
llama_context * ctx_tgt = NULL;
|
||||
llama_context * ctx_dft = NULL;
|
||||
|
||||
// load the target model
|
||||
auto llama_init_tgt = common_init_from_params(params);
|
||||
@@ -48,26 +46,38 @@ int main(int argc, char ** argv) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
|
||||
|
||||
// load the draft model
|
||||
params.devices = params.speculative.devices;
|
||||
params.model = params.speculative.model;
|
||||
params.n_ctx = params.speculative.n_ctx;
|
||||
params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
|
||||
params.n_gpu_layers = params.speculative.n_gpu_layers;
|
||||
llama_model_ptr model_dft;
|
||||
|
||||
if (params.speculative.cpuparams.n_threads > 0) {
|
||||
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
}
|
||||
// TODO: simplify this logic
|
||||
{
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
auto params_dft = params;
|
||||
|
||||
auto llama_init_dft = common_init_from_params(params);
|
||||
params_dft.n_parallel = 1;
|
||||
params_dft.n_ctx = params_spec.n_ctx;
|
||||
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
|
||||
params_dft.devices = params_spec.devices;
|
||||
params_dft.model = params_spec.mparams_dft;
|
||||
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
|
||||
|
||||
//model_dft = llama_init_dft->model();
|
||||
ctx_dft = llama_init_dft->context();
|
||||
if (params_spec.cpuparams.n_threads > 0) {
|
||||
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
}
|
||||
|
||||
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
|
||||
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
|
||||
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
|
||||
auto mparams_dft = common_model_params_to_llama(params_dft);
|
||||
|
||||
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
|
||||
if (model_dft == nullptr) {
|
||||
LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.speculative.model_dft = model_dft.get();
|
||||
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
|
||||
}
|
||||
|
||||
// Tokenize the prompt
|
||||
@@ -92,12 +102,6 @@ int main(int argc, char ** argv) {
|
||||
LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
|
||||
}
|
||||
|
||||
// how many tokens to draft each time
|
||||
int n_draft = params.speculative.n_max;
|
||||
int n_draft_min = params.speculative.n_min;
|
||||
|
||||
float p_min = params.speculative.p_min;
|
||||
|
||||
int n_predict = 0;
|
||||
int n_drafted = 0;
|
||||
int n_accept = 0;
|
||||
@@ -127,15 +131,11 @@ int main(int argc, char ** argv) {
|
||||
int n_past = inp.size() - 1;
|
||||
|
||||
// init the speculator
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft;
|
||||
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
|
||||
params_spec.p_min = p_min;
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
|
||||
for (auto &pair : params.speculative.replacements) {
|
||||
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
|
||||
}
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
|
||||
|
||||
common_speculative_begin(spec, prompt_tgt);
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
||||
|
||||
@@ -151,7 +151,7 @@ int main(int argc, char ** argv) {
|
||||
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
|
||||
// from a cache or lookup tables.
|
||||
//
|
||||
llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
|
||||
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
|
||||
|
||||
@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
|
||||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||
{
|
||||
// do not waste time on small drafts
|
||||
if (draft.size() < (size_t) n_draft_min) {
|
||||
if (draft.size() < (size_t) params_spec.n_min) {
|
||||
draft.clear();
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
||||
|
||||
LOG_INF("\n");
|
||||
LOG_INF("n_draft = %d\n", n_draft);
|
||||
LOG_INF("n_draft = %d\n", params_spec.n_max);
|
||||
LOG_INF("n_predict = %d\n", n_predict);
|
||||
LOG_INF("n_drafted = %d\n", n_drafted);
|
||||
LOG_INF("n_accept = %d\n", n_accept);
|
||||
@@ -249,8 +249,6 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("\n");
|
||||
LOG_INF("draft:\n\n");
|
||||
|
||||
llama_perf_context_print(ctx_dft);
|
||||
|
||||
LOG_INF("\n");
|
||||
LOG_INF("target:\n\n");
|
||||
common_perf_print(ctx_tgt, smpl);
|
||||
|
||||
@@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.speculative.model.path.empty()) {
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// load the draft model
|
||||
params.devices = params.speculative.devices;
|
||||
params.model = params.speculative.model;
|
||||
params.model = params.speculative.mparams_dft;
|
||||
params.n_gpu_layers = params.speculative.n_gpu_layers;
|
||||
if (params.speculative.cpuparams.n_threads > 0) {
|
||||
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
|
||||
@@ -222,6 +222,7 @@ if (GGML_SCHED_NO_REALLOC)
|
||||
endif()
|
||||
|
||||
add_library(ggml
|
||||
ggml-backend-dl.cpp
|
||||
ggml-backend-reg.cpp)
|
||||
add_library(ggml::ggml ALIAS ggml)
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
#include "ggml-backend-dl.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
dl_handle * dl_load_library(const fs::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
void * p = (void *) GetProcAddress(handle, name);
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
dl_handle * dl_load_library(const fs::path & path) {
|
||||
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
return handle;
|
||||
}
|
||||
|
||||
void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <winevt.h>
|
||||
#else
|
||||
# include <dlfcn.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include <filesystem>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(HMODULE handle) {
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
using dl_handle = void;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(void * handle) {
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
dl_handle * dl_load_library(const fs::path & path);
|
||||
void * dl_get_sym(dl_handle * handle, const char * name);
|
||||
const char * dl_error();
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-backend-dl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
@@ -98,72 +99,6 @@ static std::string path_str(const fs::path & path) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(HMODULE handle) {
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static dl_handle * dl_load_library(const fs::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
void * p = (void *) GetProcAddress(handle, name);
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
static const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
using dl_handle = void;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(void * handle) {
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static void * dl_load_library(const fs::path & path) {
|
||||
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
static const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
struct ggml_backend_reg_entry {
|
||||
ggml_backend_reg_t reg;
|
||||
dl_handle_ptr handle;
|
||||
|
||||
@@ -1122,15 +1122,18 @@ struct ggml_tensor_extra_gpu {
|
||||
#endif
|
||||
|
||||
struct ggml_cuda_graph_node_properties {
|
||||
void * node_address;
|
||||
void * node_data;
|
||||
ggml_op node_op;
|
||||
enum ggml_type node_type;
|
||||
int32_t flags;
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS];
|
||||
void * src_address[GGML_MAX_SRC];
|
||||
void * src_data[GGML_MAX_SRC];
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
};
|
||||
|
||||
static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
|
||||
|
||||
struct ggml_cuda_graph {
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
~ggml_cuda_graph() {
|
||||
@@ -1150,6 +1153,12 @@ struct ggml_cuda_graph {
|
||||
int number_consecutive_updates = 0;
|
||||
std::vector<ggml_cuda_graph_node_properties> props;
|
||||
|
||||
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
|
||||
// they properties also have to match in order to be able to safely reuse a CUDA graph
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18583
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
|
||||
std::vector<ggml_cuda_graph_node_properties> extra;
|
||||
|
||||
void record_update(bool use_graph, bool update_required) {
|
||||
if (use_graph && update_required) {
|
||||
number_consecutive_updates++;
|
||||
|
||||
@@ -310,8 +310,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
}
|
||||
}
|
||||
|
||||
const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
switch (K->ne[0]) {
|
||||
@@ -334,9 +332,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!V_is_K_view) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
|
||||
+269
-92
@@ -70,17 +70,18 @@
|
||||
#include <condition_variable>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <float.h>
|
||||
#include <cfloat>
|
||||
#include <initializer_list>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <stdarg.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
|
||||
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
|
||||
@@ -2916,22 +2917,27 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
|
||||
props->node_address = node->data;
|
||||
memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
|
||||
props->node_data = node->data;
|
||||
props->node_op = node->op;
|
||||
props->node_type = node->type;
|
||||
props->flags = node->flags;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
props->ne[i] = node->ne[i];
|
||||
props->nb[i] = node->nb[i];
|
||||
}
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
|
||||
if (!node->src[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
props->src_data[i] = node->src[i]->data;
|
||||
}
|
||||
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
|
||||
if (node->data != props->node_address &&
|
||||
node->op != GGML_OP_VIEW) {
|
||||
if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2939,6 +2945,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node->type != props->node_type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (node->ne[i] != props->ne[i]) {
|
||||
return false;
|
||||
@@ -2948,12 +2958,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node->src[i] &&
|
||||
node->src[i]->data != props->src_address[i] &&
|
||||
node->op != GGML_OP_VIEW
|
||||
) {
|
||||
return false;
|
||||
if (node->op != GGML_OP_VIEW) {
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (!node->src[i]) {
|
||||
if (props->src_data[i] != nullptr) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->src[i]->data != props->src_data[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2974,7 +2990,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
|
||||
|
||||
bool res = false;
|
||||
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
@@ -2985,15 +3000,20 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
}
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes) {
|
||||
res = true;
|
||||
graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
graph->props.resize(cgraph->n_nodes);
|
||||
}
|
||||
|
||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||
// and store properties to allow this comparison for the next token
|
||||
std::unordered_set<ggml_tensor *> seen_node;
|
||||
std::vector<ggml_tensor *> srcs_extra;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool props_match = true;
|
||||
|
||||
seen_node.insert(cgraph->nodes[i]);
|
||||
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
|
||||
}
|
||||
@@ -3001,17 +3021,31 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
|
||||
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
|
||||
if (src && seen_node.find(src) == seen_node.end()) {
|
||||
srcs_extra.push_back(src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
if (graph->extra.size() != (size_t) srcs_extra.size()) {
|
||||
res = true;
|
||||
graph->extra.resize(srcs_extra.size());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < srcs_extra.size(); ++i) {
|
||||
bool props_match = true;
|
||||
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
|
||||
props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
|
||||
}
|
||||
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
|
||||
ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -3080,63 +3114,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
|
||||
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
|
||||
args.sigmoid = false;
|
||||
args.softmax = false;
|
||||
args.delayed_softmax = false;
|
||||
args.prob_bias = false;
|
||||
args.norm = false;
|
||||
|
||||
const int n_nodes = cgraph->n_nodes;
|
||||
ggml_tensor ** nodes = cgraph->nodes;
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
|
||||
args.softmax = true;
|
||||
}
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_UNARY) {
|
||||
if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
|
||||
return false;
|
||||
}
|
||||
args.sigmoid = true;
|
||||
}
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
|
||||
args.delayed_softmax = true;
|
||||
}
|
||||
|
||||
node_idx++;
|
||||
|
||||
if (args.sigmoid || args.softmax) {
|
||||
// SOFTMAX -> RESHAPE
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
|
||||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
ggml_tensor * probs_reshaped = nodes[node_idx];
|
||||
node_idx++;
|
||||
|
||||
if (node_idx >= n_nodes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// src of bias add is the unreshaped probs (-2 instead of -1)
|
||||
if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
|
||||
args.prob_bias = true;
|
||||
node_idx++;
|
||||
}
|
||||
// RESHAPE/ADD -> ARGSORT
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
} else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
node_idx++;
|
||||
|
||||
// ARGSORT-> VIEW
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
|
||||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// GET_ROWS
|
||||
if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
} else if (args.delayed_softmax) {
|
||||
if (node_idx - 2 < 0) {
|
||||
return false;
|
||||
}
|
||||
ggml_tensor * probs_reshaped = nodes[node_idx - 2];
|
||||
|
||||
// VIEW->ARGSORT
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
|
||||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
|
||||
// GET_ROWS
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
|
||||
nodes[node_idx]->src[0] != probs_reshaped) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
|
||||
static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||
|
||||
for (const ggml_op op : remaining_ops) {
|
||||
if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
return false;
|
||||
}
|
||||
node_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
// At this point we can check for norm + scale. Everything is now at least valid till the norm
|
||||
if (node_idx >= n_nodes) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
|
||||
//check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
|
||||
static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
|
||||
|
||||
args.norm = true;
|
||||
for (const ggml_op op : norm_ops) {
|
||||
if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
|
||||
node_idx++;
|
||||
} else {
|
||||
args.norm = false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// DIV <- CLAMP, RESHAPE
|
||||
if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
|
||||
nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
|
||||
args.norm = false;
|
||||
return true;
|
||||
}
|
||||
node_idx++;
|
||||
|
||||
if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
|
||||
args.norm = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
node_idx++;
|
||||
}
|
||||
|
||||
if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
|
||||
args.scale = true;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
int node_idx,
|
||||
std::initializer_list<enum ggml_op> ops,
|
||||
std::initializer_list<enum ggml_unary_op> unary_ops) {
|
||||
#ifndef NDEBUG
|
||||
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
|
||||
GGML_ASSERT(unary_ops.size() == num_unary);
|
||||
#endif
|
||||
|
||||
//TODO: remove special case once ggml_can_fuse can handle empty nodes
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
|
||||
|
||||
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
|
||||
const std::initializer_list<enum ggml_op> & list2) {
|
||||
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
|
||||
};
|
||||
|
||||
if (is_equal(topk_moe_ops_with_norm, ops) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
|
||||
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
||||
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
||||
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
|
||||
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
|
||||
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
|
||||
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
|
||||
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
|
||||
|
||||
@@ -3398,35 +3535,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
// start of fusion operations
|
||||
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
||||
if (!disable_fusion) {
|
||||
ggml_cuda_topk_moe_args args;
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i + 9];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
|
||||
ggml_tensor * clamp = cgraph->nodes[i + 7];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
|
||||
/*delayed softmax*/ false, clamp);
|
||||
i += 9;
|
||||
continue;
|
||||
}
|
||||
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
|
||||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
|
||||
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i + 4];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
|
||||
/*delayed softmax*/ false);
|
||||
i += 4;
|
||||
continue;
|
||||
}
|
||||
std::vector<ggml_op> ops;
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i,
|
||||
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i + 5];
|
||||
ggml_tensor * ids = cgraph->nodes[i + 1];
|
||||
if (can_fuse) {
|
||||
const ggml_tensor * logits = node->src[0];
|
||||
ggml_tensor * weights = nullptr;
|
||||
ggml_tensor * ids = nullptr;
|
||||
const ggml_tensor * bias = nullptr;
|
||||
const ggml_tensor * clamp = nullptr;
|
||||
const ggml_tensor * scale = nullptr;
|
||||
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
|
||||
/*delayed_softmax*/ true);
|
||||
i += 5;
|
||||
continue;
|
||||
if (!args.delayed_softmax) {
|
||||
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
|
||||
int out_nodes[2]; // nodes which can't be elided
|
||||
|
||||
if (args.prob_bias) {
|
||||
bias = cgraph->nodes[i + 2]->src[1];
|
||||
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS });
|
||||
out_nodes[0] = i + 4;
|
||||
ids = cgraph->nodes[i + 4];
|
||||
} else {
|
||||
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS });
|
||||
out_nodes[0] = i + 3;
|
||||
ids = cgraph->nodes[i + 3];
|
||||
}
|
||||
|
||||
if (args.norm) {
|
||||
ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
|
||||
GGML_OP_DIV, GGML_OP_RESHAPE });
|
||||
clamp = cgraph->nodes[i + ops.size() - 3];
|
||||
}
|
||||
if (args.scale) {
|
||||
ops.insert(ops.end(), { GGML_OP_SCALE });
|
||||
scale = cgraph->nodes[i + ops.size() - 1];
|
||||
}
|
||||
|
||||
weights = cgraph->nodes[i + ops.size() - 1];
|
||||
out_nodes[1] = i + ops.size() - 1;
|
||||
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
}
|
||||
} else if (!args.norm && !args.prob_bias) {
|
||||
//special case gpt-oss, no norm, no bias.
|
||||
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
|
||||
GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
|
||||
weights = cgraph->nodes[i + 5];
|
||||
ids = cgraph->nodes[i + 1];
|
||||
const ggml_tensor * softmax = cgraph->nodes[i + 4];
|
||||
|
||||
int out_nodes[2] = { i + 1, i + 5 };
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
|
||||
|
||||
@@ -333,7 +333,33 @@ namespace ggml_cuda_mma {
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
return ne * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 64;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return ne * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -391,7 +417,22 @@ namespace ggml_cuda_mma {
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
|
||||
}
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
@@ -945,6 +986,32 @@ namespace ggml_cuda_mma {
|
||||
#endif // AMPERE_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <data_layout dl_ab, data_layout dl_d>
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
|
||||
#ifdef AMD_MFMA_AVAILABLE
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
||||
#if defined(CDNA3)
|
||||
using floatx2_t = __attribute__((ext_vector_type(2))) float;
|
||||
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
|
||||
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
#elif defined(CDNA2) || defined(CDNA1)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(CDNA3)
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
@@ -1054,6 +1121,13 @@ namespace ggml_cuda_mma {
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // RDNA4
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
||||
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
|
||||
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1081,11 +1155,31 @@ namespace ggml_cuda_mma {
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // RDNA4
|
||||
#endif // defined(RDNA4)
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
||||
#if defined(CDNA3) || defined(CDNA2)
|
||||
using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
|
||||
const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
|
||||
const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
#elif defined(CDNA1)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
|
||||
const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
|
||||
const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMPERE_MMA_AVAILABLE
|
||||
#endif // defined(CDNA3) || defined(CDNA2)
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <data_layout dl_d, data_layout dl_ab>
|
||||
|
||||
+30
-10
@@ -2,6 +2,13 @@
|
||||
#include "mmf.cuh"
|
||||
#include "mmid.cuh"
|
||||
|
||||
static __forceinline__ int mmf_get_rows_per_block(const int cc) {
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return MMF_ROWS_PER_BLOCK_CDNA;
|
||||
} else {
|
||||
return MMF_ROWS_PER_BLOCK;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
@@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||
ids_info_ptr = &ids_info;
|
||||
}
|
||||
|
||||
const int device = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
const int rows_per_block = mmf_get_rows_per_block(cc);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
constexpr int vals_per_T = 1;
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
mul_mat_f_switch_rows_per_block<float>(
|
||||
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half2 * src0_d = (const half2 *) src0->data;
|
||||
constexpr int vals_per_T = 2;
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
mul_mat_f_switch_rows_per_block<half2>(
|
||||
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
|
||||
constexpr int vals_per_T = 2;
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
mul_mat_f_switch_rows_per_block<nv_bfloat162>(
|
||||
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
@@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
|
||||
if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
} else {
|
||||
if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
|
||||
return false;
|
||||
} else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
|
||||
//TODO: truse CDNA2 as CDNA1, tune the perf when CDNA2 is available.
|
||||
return false;
|
||||
} else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
|
||||
return false;
|
||||
} else if (src1_ncols > 16) {
|
||||
return false;
|
||||
}
|
||||
@@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
return ampere_mma_available(cc);
|
||||
return ampere_mma_available(cc) || amd_mfma_available(cc);
|
||||
case GGML_TYPE_F16:
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
|
||||
case GGML_TYPE_BF16:
|
||||
return ampere_mma_available(cc) || amd_wmma_available(cc);
|
||||
return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
+158
-85
@@ -7,6 +7,31 @@
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
#define MMF_ROWS_PER_BLOCK 32
|
||||
#define MMF_ROWS_PER_BLOCK_CDNA 64
|
||||
|
||||
static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return 512;
|
||||
} else {
|
||||
return 256;
|
||||
}
|
||||
}
|
||||
|
||||
static __forceinline__ int mmf_get_padding(int cc) {
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return 2;
|
||||
} else {
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr __device__ int mmf_get_padding() {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
return 2;
|
||||
#else
|
||||
return 4;
|
||||
#endif // defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
struct mmf_ids_data {
|
||||
const int32_t * ids_src_compact = nullptr;
|
||||
@@ -29,23 +54,25 @@ static __global__ void mul_mat_f(
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr bool is_tf32 = std::is_same_v<T, float>;
|
||||
constexpr int tile_B_I = is_tf32 ? 8 : 16;
|
||||
constexpr int tile_C_J = is_tf32 ? 8 : 16;
|
||||
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
|
||||
typedef tile<16, 8, T, ab_layout> tile_A;
|
||||
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
|
||||
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
|
||||
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
#else
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
|
||||
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
|
||||
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
|
||||
#else
|
||||
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
@@ -57,7 +84,7 @@ static __global__ void mul_mat_f(
|
||||
}
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int tile_k_padded = warp_size + mmf_get_padding();
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||
|
||||
@@ -198,7 +225,7 @@ static __global__ void mul_mat_f(
|
||||
}
|
||||
|
||||
float * buf_iw = (float *) compute_base;
|
||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
@@ -228,27 +255,34 @@ static __global__ void mul_mat_f(
|
||||
return;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
||||
float sum[rows_per_block/warp_size] = {0.0f};
|
||||
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
|
||||
const int i = i0 + i1*warp_size + threadIdx.x;
|
||||
|
||||
sum += buf_iw[j*kiw + i];
|
||||
sum[i1] += buf_iw[j*kiw + i];
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (!has_ids) {
|
||||
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
|
||||
dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
|
||||
}
|
||||
} else {
|
||||
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
|
||||
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
|
||||
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
|
||||
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
}
|
||||
#endif //VOLTA_MMA_AVAILABLE
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
@@ -256,7 +290,7 @@ static __global__ void mul_mat_f(
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
//This kernel is for larger batch sizes of mul_mat_id
|
||||
@@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids(
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr bool is_tf32 = std::is_same_v<T, float>;
|
||||
constexpr int tile_B_I = is_tf32 ? 8 : 16;
|
||||
constexpr int tile_C_J = is_tf32 ? 8 : 16;
|
||||
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
|
||||
typedef tile<16, 8, T, ab_layout> tile_A;
|
||||
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
|
||||
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
|
||||
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
#else
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
|
||||
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
|
||||
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
|
||||
#else
|
||||
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
@@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids(
|
||||
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int tile_k_padded = warp_size + mmf_get_padding();
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||
|
||||
@@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids(
|
||||
}
|
||||
|
||||
float * buf_iw = (float *) compute_base;
|
||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
@@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids(
|
||||
return;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
||||
float sum[rows_per_block/warp_size] = {0.0f};
|
||||
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
|
||||
const int i = i0 + i1*warp_size + threadIdx.x;
|
||||
|
||||
sum += buf_iw[j*kiw + i];
|
||||
sum[i1] += buf_iw[j * kiw + i];
|
||||
}
|
||||
}
|
||||
|
||||
const int global_j = col_base + j;
|
||||
@@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids(
|
||||
const int token = (int) qrm.x;
|
||||
if (token < ncols_dst_total) {
|
||||
const int slot = (int) qrm.y;
|
||||
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
|
||||
dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
}
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template<typename T, int cols_per_block, int nwarps>
|
||||
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
|
||||
static inline void mul_mat_f_switch_ids(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
|
||||
@@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids(
|
||||
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
|
||||
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
|
||||
|
||||
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
|
||||
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
@@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids(
|
||||
dim3 block_nums_ids = block_nums;
|
||||
block_nums_ids.y *= col_tiles;
|
||||
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} else {
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int cols_per_block>
|
||||
template <typename T, int rows_per_block, int cols_per_block>
|
||||
void mul_mat_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||
@@ -605,7 +645,7 @@ void mul_mat_f_cuda(
|
||||
|
||||
int64_t nwarps_best = 1;
|
||||
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
|
||||
int64_t max_block_size = 256;
|
||||
int64_t max_block_size = mmf_get_max_block_size(cc);
|
||||
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
|
||||
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
|
||||
if (niter < niter_best) {
|
||||
@@ -614,10 +654,9 @@ void mul_mat_f_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
||||
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
|
||||
const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
|
||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||
@@ -628,56 +667,56 @@ void mul_mat_f_cuda(
|
||||
|
||||
switch (nwarps_best) {
|
||||
case 1: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
||||
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
@@ -691,7 +730,7 @@ void mul_mat_f_cuda(
|
||||
GGML_UNUSED_VARS(nchannels_y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, int rows_per_block>
|
||||
static void mul_mat_f_switch_cols_per_block(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||
@@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
|
||||
switch (ncols_case) {
|
||||
case 1: {
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 9: {
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 10: {
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 11: {
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 12: {
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 13: {
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 14: {
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 15: {
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 16: {
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
@@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
|
||||
template void mul_mat_f_cuda<T, ncols_dst>( \
|
||||
template <typename T>
|
||||
static void mul_mat_f_switch_rows_per_block(
|
||||
const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t stride_col_id, const int stride_row_id,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
switch (rows_per_block) {
|
||||
case MMF_ROWS_PER_BLOCK: {
|
||||
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
|
||||
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case MMF_ROWS_PER_BLOCK_CDNA: {
|
||||
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
|
||||
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
|
||||
template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
|
||||
const T * x, const float * y, const int32_t * ids, float * dst, \
|
||||
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
||||
const int64_t stride_col_id, const int64_t stride_row_id, \
|
||||
@@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data);
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#if !defined(GGML_USE_MUSA)
|
||||
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
||||
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
|
||||
|
||||
#define DECL_MMF_CASE(ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
||||
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
|
||||
|
||||
DECL_MMF_CASE_EXTERN(1);
|
||||
DECL_MMF_CASE_EXTERN(2);
|
||||
|
||||
+191
-139
@@ -5,6 +5,13 @@
|
||||
#include <cmath>
|
||||
#include <initializer_list>
|
||||
|
||||
// Kernel config struct - passed by value to CUDA kernel
|
||||
struct topk_moe_config {
|
||||
bool use_sigmoid;
|
||||
bool with_norm;
|
||||
bool delayed_softmax;
|
||||
};
|
||||
|
||||
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
|
||||
template <int experts_per_thread, bool use_limit>
|
||||
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
|
||||
@@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
|
||||
}
|
||||
}
|
||||
|
||||
template <int experts_per_thread, bool use_limit>
|
||||
__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const int idx = lane + i * WARP_SIZE;
|
||||
const bool active = !use_limit || (idx < limit);
|
||||
vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
This kernel does the following:
|
||||
1. optionally softmax over the logits per token [n_experts, n_tokens]
|
||||
@@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
|
||||
|
||||
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
|
||||
*/
|
||||
template <int n_experts, bool with_norm, bool delayed_softmax = false>
|
||||
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
const int n_rows,
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
template <int n_experts, bool has_bias>
|
||||
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
float * bias,
|
||||
const int n_rows,
|
||||
const int n_expert_used,
|
||||
const float clamp_val,
|
||||
const float scale_val,
|
||||
const topk_moe_config config) {
|
||||
const int row = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
if (row >= n_rows) {
|
||||
return;
|
||||
@@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
|
||||
float wt[experts_per_thread];
|
||||
|
||||
// Initialize all slots to -INFINITY
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
wt[i] = -INFINITY;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const int expert = i + threadIdx.x;
|
||||
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
|
||||
}
|
||||
|
||||
if constexpr (!delayed_softmax) {
|
||||
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
|
||||
if (!config.delayed_softmax) {
|
||||
if (config.use_sigmoid) {
|
||||
sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
|
||||
} else {
|
||||
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
|
||||
}
|
||||
}
|
||||
|
||||
// selection_wt is only needed when bias is present (selection uses wt + bias)
|
||||
// when no bias, we use wt directly for both selection and weight values
|
||||
float selection_wt[has_bias ? experts_per_thread : 1];
|
||||
|
||||
if constexpr (has_bias) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
selection_wt[i] = -INFINITY;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const int expert = i + threadIdx.x;
|
||||
selection_wt[i / WARP_SIZE] =
|
||||
(n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
//at this point, each thread holds either a portion of the softmax distribution
|
||||
@@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
float max_val = wt[0];
|
||||
int max_expert = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const int expert = threadIdx.x + i * WARP_SIZE;
|
||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
|
||||
max_val = wt[i];
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
if constexpr (has_bias) {
|
||||
float max_val_s = selection_wt[0];
|
||||
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
|
||||
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
|
||||
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
|
||||
if (val > max_val || (val == max_val && expert < max_expert)) {
|
||||
max_val = val;
|
||||
max_expert = expert;
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const int expert = threadIdx.x + i * WARP_SIZE;
|
||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
|
||||
max_val = wt[i];
|
||||
max_val_s = selection_wt[i];
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
|
||||
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
|
||||
const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
|
||||
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
|
||||
if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
|
||||
max_val = val;
|
||||
max_val_s = val_s;
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
selection_wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const int expert = threadIdx.x + i * WARP_SIZE;
|
||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
|
||||
max_val = wt[i];
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
|
||||
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
|
||||
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
|
||||
if (val > max_val || (val == max_val && expert < max_expert)) {
|
||||
max_val = val;
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
|
||||
ids[k] = max_expert;
|
||||
if constexpr (with_norm) {
|
||||
if (config.with_norm) {
|
||||
wt_sum += max_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (with_norm) {
|
||||
if (config.with_norm) {
|
||||
wt_sum = warp_reduce_sum(wt_sum);
|
||||
wt_sum = max(wt_sum, clamp_val);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
@@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (delayed_softmax) {
|
||||
if (config.delayed_softmax) {
|
||||
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
|
||||
}
|
||||
|
||||
@@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const int idx = i * WARP_SIZE + threadIdx.x;
|
||||
if (idx < n_expert_used) {
|
||||
weights[idx] = output_weights[i];
|
||||
weights[idx] = output_weights[i] * scale_val;
|
||||
}
|
||||
}
|
||||
|
||||
if (!with_norm) {
|
||||
GGML_UNUSED(clamp_val);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool with_norm, bool delayed_softmax = false>
|
||||
template<bool has_bias>
|
||||
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
float * bias,
|
||||
const int n_rows,
|
||||
const int n_expert,
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
|
||||
const float clamp_val,
|
||||
const float scale_val,
|
||||
const topk_moe_config config) {
|
||||
GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
|
||||
"delayed softmax is not supported with weight normalization");
|
||||
const int rows_per_block = 4;
|
||||
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
|
||||
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
||||
@@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
|
||||
switch (n_expert) {
|
||||
case 1:
|
||||
topk_moe_cuda<1, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 2:
|
||||
topk_moe_cuda<2, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 4:
|
||||
topk_moe_cuda<4, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 8:
|
||||
topk_moe_cuda<8, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 16:
|
||||
topk_moe_cuda<16, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 32:
|
||||
topk_moe_cuda<32, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 64:
|
||||
topk_moe_cuda<64, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 128:
|
||||
topk_moe_cuda<128, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 256:
|
||||
topk_moe_cuda<256, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 512:
|
||||
topk_moe_cuda<512, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
case 576:
|
||||
topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
|
||||
clamp_val, scale_val, config);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "fatal error");
|
||||
@@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax,
|
||||
ggml_tensor * clamp) {
|
||||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const ggml_tensor * clamp,
|
||||
const ggml_tensor * scale,
|
||||
const ggml_tensor * bias,
|
||||
const ggml_cuda_topk_moe_args & args) {
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
@@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const float * logits_d = (const float *) logits->data;
|
||||
float * weights_d = (float *) weights->data;
|
||||
int32_t * ids_d = (int32_t *) ids->data;
|
||||
float * bias_d = bias ? (float *) bias->data : nullptr;
|
||||
|
||||
float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
|
||||
|
||||
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
|
||||
|
||||
const int n_expert_used = weights->ne[1];
|
||||
|
||||
const bool with_norm = clamp != nullptr;
|
||||
|
||||
float clamp_val = -INFINITY;
|
||||
if (with_norm) {
|
||||
if (clamp) {
|
||||
clamp_val = ggml_get_op_params_f32(clamp, 0);
|
||||
}
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
|
||||
if (clamp) {
|
||||
clamp_val = ggml_get_op_params_f32(clamp, 0);
|
||||
}
|
||||
|
||||
topk_moe_config config;
|
||||
config.use_sigmoid = args.sigmoid;
|
||||
config.with_norm = with_norm;
|
||||
config.delayed_softmax = args.delayed_softmax;
|
||||
|
||||
if (bias) {
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
|
||||
scale_val, config);
|
||||
} else {
|
||||
GGML_ASSERT(clamp == nullptr);
|
||||
if (delayed_softmax) {
|
||||
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
} else {
|
||||
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
}
|
||||
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
|
||||
scale_val, config);
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
|
||||
const ggml_tensor * weights,
|
||||
const ggml_tensor * get_rows,
|
||||
const ggml_tensor * argsort,
|
||||
const ggml_tensor * clamp,
|
||||
int n_expert) {
|
||||
ggml_tensor * probs = get_rows->src[0];
|
||||
if (probs->op != GGML_OP_RESHAPE) {
|
||||
return false;
|
||||
}
|
||||
probs = probs->src[0];
|
||||
ggml_tensor * selection_probs = argsort->src[0];
|
||||
|
||||
if (probs != selection_probs) {
|
||||
const ggml_tensor * logits,
|
||||
const ggml_tensor * ids) {
|
||||
const int n_expert = ids->nb[1] / ids->nb[0];
|
||||
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
|
||||
return false;
|
||||
}
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
|
||||
|
||||
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
|
||||
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (scale != 1.0f || max_bias != 0.0f) {
|
||||
return false;
|
||||
}
|
||||
if (gating_op->op == GGML_OP_SOFT_MAX) {
|
||||
const ggml_tensor * softmax = gating_op;
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
// don't fuse when masks or sinks are present
|
||||
if (softmax->src[1] || softmax->src[2]) {
|
||||
return false;
|
||||
}
|
||||
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
|
||||
|
||||
// n_expert must be a power of 2
|
||||
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (clamp) {
|
||||
if (clamp->op != GGML_OP_CLAMP) {
|
||||
if (!ggml_is_contiguous(softmax->src[0])) {
|
||||
return false;
|
||||
}
|
||||
float max_val = ggml_get_op_params_f32(clamp, 1);
|
||||
|
||||
if (max_val != INFINITY) {
|
||||
if (scale != 1.0f || max_bias != 0.0f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// don't fuse when masks or sinks are present
|
||||
if (softmax->src[1] || softmax->src[2]) {
|
||||
return false;
|
||||
}
|
||||
} else if (gating_op->op == GGML_OP_UNARY) {
|
||||
ggml_unary_op op = ggml_get_unary_op(gating_op);
|
||||
|
||||
if (op != GGML_UNARY_OP_SIGMOID) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
|
||||
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
|
||||
GGML_OP_RESHAPE };
|
||||
|
||||
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
|
||||
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||
|
||||
GGML_ASSERT(!norm || !delayed_softmax);
|
||||
|
||||
if (delayed_softmax) {
|
||||
return delayed_softmax_ops;
|
||||
}
|
||||
|
||||
if (norm) {
|
||||
return norm_ops;
|
||||
}
|
||||
|
||||
return no_norm_ops;
|
||||
}
|
||||
|
||||
@@ -3,19 +3,25 @@
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax = false,
|
||||
ggml_tensor * weight_clamp = nullptr);
|
||||
struct ggml_cuda_topk_moe_args {
|
||||
bool sigmoid{};
|
||||
bool softmax{};
|
||||
bool delayed_softmax{};
|
||||
bool prob_bias{};
|
||||
bool norm{};
|
||||
bool scale{};
|
||||
};
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
|
||||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const ggml_tensor * clamp,
|
||||
const ggml_tensor * scale,
|
||||
const ggml_tensor * bias,
|
||||
const ggml_cuda_topk_moe_args & args);
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
|
||||
const ggml_tensor * weights,
|
||||
const ggml_tensor * get_rows,
|
||||
const ggml_tensor * argsort,
|
||||
const ggml_tensor * clamp,
|
||||
int n_expert);
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
|
||||
const ggml_tensor * logits,
|
||||
const ggml_tensor * ids);
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
|
||||
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||
|
||||
if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
|
||||
message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
|
||||
endif()
|
||||
|
||||
message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")
|
||||
|
||||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include(ExternalProject)
|
||||
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
|
||||
|
||||
add_library(htp_iface OBJECT
|
||||
@@ -25,56 +35,71 @@ else()
|
||||
target_link_options(htp_iface PUBLIC -ldl)
|
||||
endif()
|
||||
|
||||
link_custom_library(htp_iface cdsprpc)
|
||||
link_custom_library(htp_iface rpcmem)
|
||||
|
||||
set(TARGET_NAME ggml-hexagon)
|
||||
ggml_add_backend_library(${TARGET_NAME}
|
||||
ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h)
|
||||
ggml-hexagon.cpp
|
||||
htp-drv.cpp
|
||||
htp-drv.h
|
||||
libdl.h
|
||||
../../include/ggml-hexagon.h)
|
||||
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
# Build HTP bits
|
||||
set(HTP_CMAKE_ARGS
|
||||
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
|
||||
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
|
||||
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
# Build HTP skels
|
||||
set(HTP_SKELS)
|
||||
function(build_htp_skel V)
|
||||
ExternalProject_Add(htp-${V}
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so
|
||||
CMAKE_ARGS
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
|
||||
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
|
||||
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
|
||||
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
|
||||
-DDSP_VERSION=${V}
|
||||
-DPREBUILT_LIB_DIR="toolv19_${V}")
|
||||
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
|
||||
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
ExternalProject_Add(htp-v68
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68")
|
||||
|
||||
ExternalProject_Add(htp-v69
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69")
|
||||
|
||||
ExternalProject_Add(htp-v73
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
|
||||
|
||||
ExternalProject_Add(htp-v75
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75")
|
||||
|
||||
ExternalProject_Add(htp-v79
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79")
|
||||
|
||||
ExternalProject_Add(htp-v81
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81")
|
||||
build_htp_skel(v68)
|
||||
build_htp_skel(v69)
|
||||
build_htp_skel(v73)
|
||||
build_htp_skel(v75)
|
||||
build_htp_skel(v79)
|
||||
build_htp_skel(v81)
|
||||
|
||||
# Install Hexagon skels required at runtime
|
||||
install(FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so
|
||||
TYPE LIB)
|
||||
install(FILES ${HTP_SKELS} TYPE LIB)
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT)
|
||||
file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/arm64" WINSDK_BIN0_ARM64)
|
||||
file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/x86" WINSDK_BIN0_X86)
|
||||
file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/arm64" WINSDK_BIN1_ARM64)
|
||||
file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/x86" WINSDK_BIN1_X86)
|
||||
|
||||
set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86})
|
||||
|
||||
find_program(INF2CAT NAMES inf2cat.exe PATHS ${WINSDK_PATHS} REQUIRED)
|
||||
find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED)
|
||||
|
||||
message(STATUS "hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels")
|
||||
|
||||
set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat)
|
||||
add_custom_target(libggml-htp-cat
|
||||
BYPRODUCTS ${LIBGGML_HTP_CAT}
|
||||
DEPENDS libggml-htp.inf ${HTP_SKELS}
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR}
|
||||
COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64
|
||||
COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT}
|
||||
COMMENT "generating and signing libggml-htp.cat file"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
add_dependencies(${TARGET_NAME} libggml-htp-cat)
|
||||
install(FILES ${LIBGGML_HTP_CAT} TYPE LIB)
|
||||
endif()
|
||||
|
||||
@@ -14,9 +14,6 @@
|
||||
|
||||
#ifdef _WIN32
|
||||
# include <sal.h>
|
||||
# ifndef _WINDOWS
|
||||
# define _WINDOWS
|
||||
# endif
|
||||
#else
|
||||
# include <semaphore.h>
|
||||
# include <unistd.h>
|
||||
@@ -25,8 +22,6 @@
|
||||
#pragma clang diagnostic ignored "-Wnested-anon-types"
|
||||
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
|
||||
|
||||
#include "htp-utils.h"
|
||||
|
||||
#include <AEEStdErr.h>
|
||||
#include <dspqueue.h>
|
||||
#include <rpcmem.h>
|
||||
@@ -40,6 +35,7 @@
|
||||
#include "op-desc.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp_iface.h"
|
||||
#include "htp-drv.h"
|
||||
|
||||
static size_t opt_ndev = 1;
|
||||
static size_t opt_nhvx = 0; // use all
|
||||
@@ -150,9 +146,9 @@ void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_
|
||||
0, // flags - the framework will autoset this
|
||||
n_bufs, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
sizeof(req), // Message length
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000 // Timeout
|
||||
DSPQUEUE_TIMEOUT // Timeout
|
||||
);
|
||||
|
||||
if (err != 0) {
|
||||
@@ -182,13 +178,13 @@ void ggml_hexagon_session::flush() {
|
||||
|
||||
// Read response packet from queue
|
||||
int err = dspqueue_read(q, &flags,
|
||||
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||
&n_bufs, // Number of buffer references
|
||||
bufs, // Buffer references
|
||||
sizeof(rsp), // Max message length
|
||||
&rsp_size, // Message length
|
||||
(uint8_t *) &rsp,
|
||||
1000000); // Timeout
|
||||
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||
&n_bufs, // Number of buffer references
|
||||
bufs, // Buffer references
|
||||
sizeof(rsp), // Max message length
|
||||
&rsp_size, // Message length
|
||||
(uint8_t *) &rsp, // Message
|
||||
DSPQUEUE_TIMEOUT); // Timeout
|
||||
|
||||
if (err == AEE_EEXPIRED) {
|
||||
// TODO: might need to bail out if the HTP is stuck on something
|
||||
@@ -269,13 +265,7 @@ struct ggml_backend_hexagon_buffer_context {
|
||||
ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
|
||||
size += 4 * 1024; // extra page for padding
|
||||
|
||||
if (rpcmem_alloc2) {
|
||||
this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
|
||||
} else {
|
||||
GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str());
|
||||
this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
|
||||
}
|
||||
|
||||
this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
|
||||
if (!this->base) {
|
||||
GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
|
||||
throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
|
||||
@@ -2461,12 +2451,12 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
|
||||
}
|
||||
|
||||
static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
|
||||
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type));
|
||||
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
|
||||
}
|
||||
|
||||
static inline bool is_compute_op(ggml_tensor *node)
|
||||
{
|
||||
return !(ggml_op_is_empty(node->op) || ggml_is_empty(node));
|
||||
return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
|
||||
}
|
||||
|
||||
// scan the graph and figure out last compute op index
|
||||
@@ -2488,7 +2478,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
|
||||
const int last = last_compute_op(graph);
|
||||
|
||||
const struct ggml_tensor * prev_quant_op = nullptr; // prev executed op with quantizer
|
||||
const struct ggml_tensor * prev_op = nullptr; // prev executed op
|
||||
|
||||
for (int i = 0; i < graph->n_nodes; ++i) {
|
||||
ggml_tensor * node = graph->nodes[i];
|
||||
@@ -2497,17 +2487,15 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t flags = 0;
|
||||
|
||||
// skip quantizer if src1 is reused
|
||||
if (op_reuse_src1(node, prev_quant_op)) {
|
||||
if (op_reuse_src1(node, prev_op)) {
|
||||
flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
|
||||
}
|
||||
|
||||
prev_op = node;
|
||||
|
||||
// ask for early notification for the last Op
|
||||
if (i == last) {
|
||||
flags |= HTP_OPFLAGS_EARLY_WAKEUP;
|
||||
@@ -2520,7 +2508,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
} else {
|
||||
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
|
||||
}
|
||||
prev_quant_op = node;
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
if (ggml_is_quantized(node->src[0]->type)) {
|
||||
@@ -2528,7 +2515,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
} else {
|
||||
ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
|
||||
}
|
||||
prev_quant_op = node;
|
||||
break;
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
@@ -2670,7 +2656,7 @@ static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<no
|
||||
}
|
||||
|
||||
// that many nodes forward to search for stackable nodes that can reuse VTCM
|
||||
constexpr int N_FORWARD = 8;
|
||||
constexpr int N_FORWARD = 16;
|
||||
|
||||
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
|
||||
if (used[i1]) {
|
||||
@@ -3056,10 +3042,12 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
if (opt_arch < 75) {
|
||||
opt_ndev = 1;
|
||||
GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
|
||||
|
||||
@@ -3156,6 +3144,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
opt_arch = strtoul(str_arch, NULL, 0);
|
||||
}
|
||||
|
||||
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;
|
||||
|
||||
reg->context = new ggml_hexagon_registry(reg);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
|
||||
@@ -3180,6 +3170,11 @@ ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (!initialized) {
|
||||
auto nErr = htpdrv_init();
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ggml_hexagon_init(®);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
// sample drv interface
|
||||
|
||||
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
#pragma clang diagnostic ignored "-Wsign-compare"
|
||||
|
||||
#include <filesystem>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <winevt.h>
|
||||
#else
|
||||
# include <dlfcn.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include "ggml-impl.h"
|
||||
#include "htp-drv.h"
|
||||
#include "libdl.h"
|
||||
|
||||
#include <domain.h>
|
||||
|
||||
//
|
||||
// Driver API types
|
||||
//
|
||||
|
||||
typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size);
|
||||
typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size);
|
||||
typedef void (*rpcmem_free_pfn_t)(void * po);
|
||||
typedef int (*rpcmem_to_fd_pfn_t)(void * po);
|
||||
|
||||
typedef AEEResult (*dspqueue_create_pfn_t)(int domain,
|
||||
uint32_t flags,
|
||||
uint32_t req_queue_size,
|
||||
uint32_t resp_queue_size,
|
||||
dspqueue_callback_t packet_callback,
|
||||
dspqueue_callback_t error_callback,
|
||||
void * callback_context,
|
||||
dspqueue_t * queue);
|
||||
typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue);
|
||||
typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id);
|
||||
typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags,
|
||||
uint32_t num_buffers,
|
||||
struct dspqueue_buffer *buffers,
|
||||
uint32_t message_length,
|
||||
const uint8_t *message,
|
||||
uint32_t timeout_us);
|
||||
typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags,
|
||||
uint32_t max_buffers, uint32_t *num_buffers,
|
||||
struct dspqueue_buffer *buffers,
|
||||
uint32_t max_message_length,
|
||||
uint32_t *message_length, uint8_t *message,
|
||||
uint32_t timeout_us);
|
||||
|
||||
typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags);
|
||||
typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length);
|
||||
|
||||
typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph);
|
||||
typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra);
|
||||
typedef int (*remote_handle64_close_pfn_t)(remote_handle h);
|
||||
typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen);
|
||||
typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen);
|
||||
typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen);
|
||||
|
||||
//
|
||||
// Driver API pfns
|
||||
//
|
||||
|
||||
rpcmem_alloc_pfn_t rpcmem_alloc_pfn = nullptr;
|
||||
rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr;
|
||||
rpcmem_free_pfn_t rpcmem_free_pfn = nullptr;
|
||||
rpcmem_to_fd_pfn_t rpcmem_to_fd_pfn = nullptr;
|
||||
|
||||
fastrpc_mmap_pfn_t fastrpc_mmap_pfn = nullptr;
|
||||
fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr;
|
||||
|
||||
dspqueue_create_pfn_t dspqueue_create_pfn = nullptr;
|
||||
dspqueue_close_pfn_t dspqueue_close_pfn = nullptr;
|
||||
dspqueue_export_pfn_t dspqueue_export_pfn = nullptr;
|
||||
dspqueue_write_pfn_t dspqueue_write_pfn = nullptr;
|
||||
dspqueue_read_pfn_t dspqueue_read_pfn = nullptr;
|
||||
|
||||
remote_handle64_open_pfn_t remote_handle64_open_pfn = nullptr;
|
||||
remote_handle64_invoke_pfn_t remote_handle64_invoke_pfn = nullptr;
|
||||
remote_handle64_close_pfn_t remote_handle64_close_pfn = nullptr;
|
||||
remote_handle_control_pfn_t remote_handle_control_pfn = nullptr;
|
||||
remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr;
|
||||
remote_session_control_pfn_t remote_session_control_pfn = nullptr;
|
||||
|
||||
//
|
||||
// Driver API
|
||||
//
|
||||
|
||||
void * rpcmem_alloc(int heapid, uint32_t flags, int size) {
|
||||
return rpcmem_alloc_pfn(heapid, flags, size);
|
||||
}
|
||||
|
||||
void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) {
|
||||
if (rpcmem_alloc2_pfn) {
|
||||
return rpcmem_alloc2_pfn(heapid, flags, size);
|
||||
} else {
|
||||
GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n");
|
||||
return rpcmem_alloc_pfn(heapid, flags, size);
|
||||
}
|
||||
}
|
||||
|
||||
void rpcmem_free(void * po) {
|
||||
return rpcmem_free_pfn(po);
|
||||
}
|
||||
|
||||
int rpcmem_to_fd(void * po) {
|
||||
return rpcmem_to_fd_pfn(po);
|
||||
}
|
||||
|
||||
HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) {
|
||||
return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags);
|
||||
}
|
||||
|
||||
HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) {
|
||||
return fastrpc_munmap_pfn(domain, fd, addr, length);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_create(int domain,
|
||||
uint32_t flags,
|
||||
uint32_t req_queue_size,
|
||||
uint32_t resp_queue_size,
|
||||
dspqueue_callback_t packet_callback,
|
||||
dspqueue_callback_t error_callback,
|
||||
void * callback_context,
|
||||
dspqueue_t * queue) {
|
||||
return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback,
|
||||
callback_context, queue);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_close(dspqueue_t queue) {
|
||||
return dspqueue_close_pfn(queue);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) {
|
||||
return dspqueue_export_pfn(queue, queue_id);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_write(dspqueue_t queue,
|
||||
uint32_t flags,
|
||||
uint32_t num_buffers,
|
||||
struct dspqueue_buffer * buffers,
|
||||
uint32_t message_length,
|
||||
const uint8_t * message,
|
||||
uint32_t timeout_us) {
|
||||
return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us);
|
||||
}
|
||||
|
||||
AEEResult dspqueue_read(dspqueue_t queue,
|
||||
uint32_t * flags,
|
||||
uint32_t max_buffers,
|
||||
uint32_t * num_buffers,
|
||||
struct dspqueue_buffer * buffers,
|
||||
uint32_t max_message_length,
|
||||
uint32_t * message_length,
|
||||
uint8_t * message,
|
||||
uint32_t timeout_us) {
|
||||
return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length,
|
||||
message, timeout_us);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) {
|
||||
return remote_handle64_open_pfn(name, ph);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) {
|
||||
return remote_handle64_invoke_pfn(h, dwScalars, pra);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle64_close(remote_handle64 h) {
|
||||
return remote_handle64_close_pfn(h);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) {
|
||||
return remote_handle_control_pfn(req, data, datalen);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) {
|
||||
return remote_handle64_control_pfn(h, req, data, datalen);
|
||||
}
|
||||
|
||||
HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) {
|
||||
return remote_session_control_pfn(req, data, datalen);
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
static std::string wstr_to_str(std::wstring_view wstr) {
|
||||
std::string result;
|
||||
if (wstr.empty()) {
|
||||
return result;
|
||||
}
|
||||
auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
|
||||
wstr.data(), (int) wstr.size(),
|
||||
nullptr, 0, nullptr, nullptr);
|
||||
if (bytes_needed == 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
|
||||
throw std::runtime_error("Invalid wstring input");
|
||||
}
|
||||
|
||||
result.resize(bytes_needed, '\0');
|
||||
int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
|
||||
wstr.data(), (int) wstr.size(),
|
||||
result.data(), bytes_needed,
|
||||
nullptr, nullptr);
|
||||
if (bytes_written == 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
|
||||
throw std::runtime_error("Wstring conversion failed");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string get_driver_path() {
|
||||
std::wstring serviceName = L"qcnspmcdm";
|
||||
std::string result;
|
||||
|
||||
// Get a handle to the SCM database.
|
||||
SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ);
|
||||
if (nullptr == schSCManager) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Get a handle to the service.
|
||||
SC_HANDLE schService = OpenServiceW(schSCManager, // SCM database
|
||||
serviceName.c_str(), // name of service
|
||||
SERVICE_QUERY_CONFIG); // need query config access
|
||||
|
||||
if (nullptr == schService) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError());
|
||||
CloseServiceHandle(schSCManager);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Store the size of buffer used as an output.
|
||||
DWORD bufferSize;
|
||||
if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) &&
|
||||
(GetLastError() != ERROR_INSUFFICIENT_BUFFER)) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
|
||||
CloseServiceHandle(schService);
|
||||
CloseServiceHandle(schSCManager);
|
||||
return result;
|
||||
}
|
||||
// Get the configuration of the service.
|
||||
LPQUERY_SERVICE_CONFIGW serviceConfig =
|
||||
static_cast<LPQUERY_SERVICE_CONFIGW>(LocalAlloc(LMEM_FIXED, bufferSize));
|
||||
if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) {
|
||||
fprintf(stderr, "ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
|
||||
LocalFree(serviceConfig);
|
||||
CloseServiceHandle(schService);
|
||||
CloseServiceHandle(schSCManager);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Read the driver file path get its parent directory
|
||||
std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName);
|
||||
driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\"));
|
||||
|
||||
// Clean up resources
|
||||
LocalFree(serviceConfig);
|
||||
CloseServiceHandle(schService);
|
||||
CloseServiceHandle(schSCManager);
|
||||
|
||||
// Driver path would contain invalid path string, like:
|
||||
// \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37
|
||||
// "\SystemRoot" should be replace with a correct one (e.g. C:\Windows)
|
||||
const std::wstring systemRootPlaceholder = L"\\SystemRoot";
|
||||
if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) {
|
||||
GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n");
|
||||
return result;
|
||||
}
|
||||
|
||||
// Replace \SystemRoot with an absolute path from system ENV windir
|
||||
const std::wstring systemRootEnv = L"windir";
|
||||
|
||||
// Query the number of wide charactors this variable requires
|
||||
DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0);
|
||||
if (numWords == 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n");
|
||||
return result;
|
||||
}
|
||||
|
||||
// Query the actual system root name from environment variable
|
||||
std::vector<wchar_t> systemRoot(numWords + 1);
|
||||
numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1);
|
||||
if (numWords == 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n");
|
||||
return result;
|
||||
}
|
||||
driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data()));
|
||||
|
||||
return wstr_to_str(driverPath);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
int htpdrv_init() {
|
||||
static dl_handle_ptr lib_cdsp_rpc_handle = nullptr;
|
||||
static bool initialized = false;
|
||||
#ifdef _WIN32
|
||||
std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll";
|
||||
#else
|
||||
std::string drv_path = "libcdsprpc.so";
|
||||
#endif
|
||||
if (initialized) {
|
||||
GGML_LOG_INFO("ggml-hex: Driver already loaded\n");
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str());
|
||||
|
||||
fs::path path{ drv_path.c_str() };
|
||||
dl_handle_ptr handle { dl_load_library(path) };
|
||||
if (!handle) {
|
||||
GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error());
|
||||
return AEE_EUNABLETOLOAD;
|
||||
}
|
||||
|
||||
#define dlsym(drv, type, pfn, symbol, ignore) \
|
||||
do { \
|
||||
pfn = (type) dl_get_sym(drv, #symbol); \
|
||||
if (!ignore && nullptr == pfn) { \
|
||||
GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol); \
|
||||
return AEE_EUNABLETOLOAD; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false);
|
||||
dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true);
|
||||
dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false);
|
||||
dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false);
|
||||
dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false);
|
||||
dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false);
|
||||
dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false);
|
||||
dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false);
|
||||
dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false);
|
||||
dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false);
|
||||
dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false);
|
||||
dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false);
|
||||
dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false);
|
||||
dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false);
|
||||
dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false);
|
||||
dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false);
|
||||
dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false);
|
||||
|
||||
lib_cdsp_rpc_handle = std::move(handle);
|
||||
initialized = true;
|
||||
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
domain * get_domain(int domain_id) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return &supported_domains[i];
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
int get_hex_arch_ver(int domain, int * arch) {
|
||||
if (!remote_handle_control_pfn) {
|
||||
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
struct remote_dsp_capability arch_ver;
|
||||
arch_ver.domain = (uint32_t) domain;
|
||||
arch_ver.attribute_ID = ARCH_VER;
|
||||
arch_ver.capability = (uint32_t) 0;
|
||||
|
||||
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
|
||||
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
if (err != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
|
||||
switch (arch_ver.capability & 0xff) {
|
||||
case 0x68:
|
||||
*arch = 68;
|
||||
return 0;
|
||||
case 0x69:
|
||||
*arch = 69;
|
||||
return 0;
|
||||
case 0x73:
|
||||
*arch = 73;
|
||||
return 0;
|
||||
case 0x75:
|
||||
*arch = 75;
|
||||
return 0;
|
||||
case 0x79:
|
||||
*arch = 79;
|
||||
return 0;
|
||||
case 0x81:
|
||||
*arch = 81;
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
# pragma clang diagnostic ignored "-Wignored-attributes"
|
||||
#endif
|
||||
|
||||
#include <AEEStdErr.h>
|
||||
#include <rpcmem.h>
|
||||
#include <remote.h>
|
||||
#include <dspqueue.h>
|
||||
|
||||
#if defined(_WIN32) && !defined(__MINGW32__)
|
||||
# ifdef GGML_BACKEND_BUILD
|
||||
# define HTPDRV_API __declspec(dllexport) extern
|
||||
# else
|
||||
# define HTPDRV_API __declspec(dllimport) extern
|
||||
# endif
|
||||
#else
|
||||
# define HTPDRV_API __attribute__ ((visibility ("default"))) extern
|
||||
#endif
|
||||
|
||||
/* Offset to differentiate HLOS and Hexagon error codes.
|
||||
Stores the value of AEE_EOFFSET for Hexagon. */
|
||||
#ifndef DSP_OFFSET
|
||||
# define DSP_OFFSET 0x80000400
|
||||
#endif
|
||||
|
||||
/* Errno for connection reset by peer. */
|
||||
#ifndef ECONNRESET
|
||||
# ifdef __hexagon__
|
||||
# define ECONNRESET 104
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Abstraction of different OS specific sleep APIs.
|
||||
SLEEP accepts input in seconds. */
|
||||
#ifndef SLEEP
|
||||
# ifdef __hexagon__
|
||||
# define SLEEP(x) \
|
||||
{ /* Do nothing for simulator. */ \
|
||||
}
|
||||
# else
|
||||
# ifdef _WIN32
|
||||
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
|
||||
# else
|
||||
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
|
||||
# endif
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Include windows specific header files. */
|
||||
#ifdef _WIN32
|
||||
# include <windows.h>
|
||||
# include <sysinfoapi.h>
|
||||
# define _CRT_SECURE_NO_WARNINGS 1
|
||||
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
|
||||
#endif
|
||||
|
||||
/* Includes and defines for all HLOS except windows */
|
||||
#if !defined(__hexagon__) && !defined(_WIN32)
|
||||
# include "unistd.h"
|
||||
|
||||
# include <sys/time.h>
|
||||
#endif
|
||||
|
||||
/* Includes and defines for Hexagon and all HLOS except Windows. */
|
||||
#if !defined(_WIN32)
|
||||
/* Weak reference to remote symbol for compilation. */
|
||||
# pragma weak remote_session_control
|
||||
# pragma weak remote_handle_control
|
||||
# pragma weak remote_handle64_control
|
||||
# pragma weak fastrpc_mmap
|
||||
# pragma weak fastrpc_munmap
|
||||
# pragma weak rpcmem_alloc2
|
||||
#endif
|
||||
|
||||
#if !defined(_WIN32)
|
||||
# pragma weak remote_system_request
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
# define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE
|
||||
#else
|
||||
# define DSPQUEUE_TIMEOUT 1000000
|
||||
#endif
|
||||
|
||||
/**
|
||||
* htpdrv_init API: driver interface entry point
|
||||
*
|
||||
* @return Return AEE error codes as defined in Hexagon SDK.
|
||||
*/
|
||||
HTPDRV_API int htpdrv_init(void);
|
||||
|
||||
/**
|
||||
* get_domain API: get domain struct from domain value.
|
||||
*
|
||||
* @param[in] domain value of a domain
|
||||
* @return Returns domain struct of the domain if it is supported or else
|
||||
* returns NULL.
|
||||
*
|
||||
*/
|
||||
HTPDRV_API domain * get_domain(int domain_id);
|
||||
|
||||
/**
|
||||
* get_hex_arch_ver API: query the Hexagon processor architecture version information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] Arch version (73, 75, ...)
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
HTPDRV_API int get_hex_arch_ver(int domain, int * arch);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,454 +0,0 @@
|
||||
|
||||
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
#pragma clang diagnostic ignored "-Wsign-compare"
|
||||
|
||||
#define GGML_COMMON_IMPL_C
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-common.h"
|
||||
#include "ggml-hexagon.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include "htp-utils.h"
|
||||
|
||||
#include <domain.h>
|
||||
#include <remote.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
domain * get_domain(int domain_id) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return &supported_domains[i];
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
bool is_valid_domain_id(int domain_id, int compute_only) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
if (compute_only) {
|
||||
return is_CDSP(domain_id);
|
||||
}
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
int ss_info = 0;
|
||||
if (domain_type != NULL) {
|
||||
if (strcmp(domain_type, "LPASS") == 0) {
|
||||
ss_info = FASTRPC_LPASS;
|
||||
} else if (strcmp(domain_type, "HPASS") == 0) {
|
||||
ss_info = FASTRPC_HPASS;
|
||||
} else {
|
||||
ss_info = FASTRPC_NSP;
|
||||
}
|
||||
}
|
||||
system_req_payload req = { 0 };
|
||||
req.id = FASTRPC_GET_DOMAINS;
|
||||
req.sys.domains = NULL;
|
||||
fastrpc_domain * domain = NULL;
|
||||
if (ss_info != 0) {
|
||||
req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info);
|
||||
} else {
|
||||
req.sys.flags = 0;
|
||||
}
|
||||
#ifdef _WIN32
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
goto bail;
|
||||
#endif
|
||||
if (remote_system_request) {
|
||||
nErr = remote_system_request(&req);
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
// Allocate memory for domain-info array
|
||||
req.sys.max_domains = req.sys.num_domains;
|
||||
if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) {
|
||||
nErr = AEE_ENOMEMORY;
|
||||
GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains");
|
||||
goto bail;
|
||||
}
|
||||
|
||||
nErr = remote_system_request(&req);
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
|
||||
for (int i = 0; i < req.sys.num_domains; i++) {
|
||||
// Verify that only requested type domains were returned
|
||||
domain = &req.sys.domains[i];
|
||||
if (domain->type != ss_info && domain_type != NULL) {
|
||||
nErr = -1;
|
||||
GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n");
|
||||
goto bail;
|
||||
}
|
||||
}
|
||||
*domains_info = req.sys.domains;
|
||||
*num_domains = req.sys.num_domains;
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
goto bail;
|
||||
}
|
||||
bail:
|
||||
if (nErr && !req.sys.domains) {
|
||||
free(req.sys.domains);
|
||||
}
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) {
|
||||
int err = 0;
|
||||
remote_rpc_effective_domain_id_t sess = { 0 };
|
||||
|
||||
sess.domain_name = domain_name;
|
||||
sess.domain_name_len = strlen(domain_name);
|
||||
sess.session_id = session_id;
|
||||
|
||||
err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess));
|
||||
if (err) {
|
||||
GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name,
|
||||
session_id);
|
||||
return err;
|
||||
}
|
||||
|
||||
*effec_domain_id = sess.effective_domain_id;
|
||||
return err;
|
||||
}
|
||||
|
||||
int get_dsp_support(int * domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID
|
||||
|
||||
if (remote_handle_control) {
|
||||
struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 };
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
goto bail;
|
||||
}
|
||||
|
||||
if (dsp_capability_domain.capability == 0) {
|
||||
dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support.
|
||||
dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT;
|
||||
dsp_capability_domain.capability = 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if (dsp_capability_domain.capability) {
|
||||
*domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID
|
||||
}
|
||||
}
|
||||
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (attr == VTCM_PAGE || attr == VTCM_COUNT) {
|
||||
} else {
|
||||
nErr = AEE_EBADPARM;
|
||||
GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n");
|
||||
goto bail;
|
||||
}
|
||||
if (remote_handle_control) {
|
||||
if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for VTCM information
|
||||
* Since the ADSP does not have a dedicated VTCM, we expect the output to be 0
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_vtcm_dsp;
|
||||
dsp_capability_vtcm_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_vtcm_dsp.attribute_ID = attr;
|
||||
dsp_capability_vtcm_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_vtcm_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("Unsupported domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
bool is_unsignedpd_supported(int domain_id) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
if (remote_handle_control) {
|
||||
struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 };
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n");
|
||||
return false;
|
||||
}
|
||||
if (nErr) {
|
||||
GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr);
|
||||
return false;
|
||||
}
|
||||
if (dsp_capability_domain.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n");
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool get_unsignedpd_support(void) {
|
||||
return is_unsignedpd_supported(CDSP_DOMAIN_ID);
|
||||
}
|
||||
|
||||
bool is_async_fastrpc_supported(int domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for ASYNC_FASTRPC_SUPPORT information
|
||||
* Async fastrpc is supported only on CDSP
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_async_support;
|
||||
dsp_capability_async_support.domain = (uint32_t) domain;
|
||||
dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT;
|
||||
dsp_capability_async_support.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (dsp_capability_async_support.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_status_notification_supported(int domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
|
||||
if (remote_handle_control) {
|
||||
/*
|
||||
* Query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||
* DSP User PD status notification Support
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_status_notification_support;
|
||||
dsp_capability_status_notification_support.domain = (uint32_t) domain;
|
||||
dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT;
|
||||
dsp_capability_status_notification_support.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (dsp_capability_status_notification_support.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return false;
|
||||
}
|
||||
|
||||
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) {
|
||||
nErr = AEE_EBADPARM;
|
||||
GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n");
|
||||
goto bail;
|
||||
}
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for HMX SUPPORT information
|
||||
* HMX is supported on CDSP only
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_hmx_dsp;
|
||||
dsp_capability_hmx_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_hmx_dsp.attribute_ID = attr;
|
||||
dsp_capability_hmx_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_hmx_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_hex_arch_ver(int domain, int * arch) {
|
||||
if (!remote_handle_control) {
|
||||
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
struct remote_dsp_capability arch_ver;
|
||||
arch_ver.domain = (uint32_t) domain;
|
||||
arch_ver.attribute_ID = ARCH_VER;
|
||||
arch_ver.capability = (uint32_t) 0;
|
||||
|
||||
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
|
||||
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
if (err != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
|
||||
switch (arch_ver.capability & 0xff) {
|
||||
case 0x68:
|
||||
*arch = 68;
|
||||
return 0;
|
||||
case 0x69:
|
||||
*arch = 69;
|
||||
return 0;
|
||||
case 0x73:
|
||||
*arch = 73;
|
||||
return 0;
|
||||
case 0x75:
|
||||
*arch = 75;
|
||||
return 0;
|
||||
case 0x79:
|
||||
*arch = 79;
|
||||
return 0;
|
||||
case 0x81:
|
||||
*arch = 81;
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for HVX SUPPORT information
|
||||
* HVX is supported on CDSP only
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_hvx_dsp;
|
||||
dsp_capability_hvx_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_hvx_dsp.attribute_ID = attr;
|
||||
dsp_capability_hvx_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_hvx_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
@@ -1,221 +0,0 @@
|
||||
#ifndef HTP_UTILS_H
|
||||
#define HTP_UTILS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include <AEEStdErr.h>
|
||||
#include <inttypes.h>
|
||||
#include <remote.h>
|
||||
#include <rpcmem.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
/* Offset to differentiate HLOS and Hexagon error codes.
|
||||
Stores the value of AEE_EOFFSET for Hexagon. */
|
||||
#ifndef DSP_OFFSET
|
||||
# define DSP_OFFSET 0x80000400
|
||||
#endif
|
||||
|
||||
/* Errno for connection reset by peer. */
|
||||
#ifndef ECONNRESET
|
||||
# ifdef __hexagon__
|
||||
# define ECONNRESET 104
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Abstraction of different OS specific sleep APIs.
|
||||
SLEEP accepts input in seconds. */
|
||||
#ifndef SLEEP
|
||||
# ifdef __hexagon__
|
||||
# define SLEEP(x) \
|
||||
{ /* Do nothing for simulator. */ \
|
||||
}
|
||||
# else
|
||||
# ifdef _WINDOWS
|
||||
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
|
||||
# else
|
||||
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
|
||||
# endif
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Include windows specific header files. */
|
||||
#ifdef _WINDOWS
|
||||
# include <sysinfoapi.h>
|
||||
# include <windows.h>
|
||||
# define _CRT_SECURE_NO_WARNINGS 1
|
||||
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
|
||||
/* Including this file for custom implementation of getopt function. */
|
||||
# include "getopt_custom.h"
|
||||
#endif
|
||||
|
||||
/* Includes and defines for all HLOS except windows */
|
||||
#if !defined(__hexagon__) && !defined(_WINDOWS)
|
||||
# include "unistd.h"
|
||||
|
||||
# include <sys/time.h>
|
||||
#endif
|
||||
|
||||
/* Includes and defines for Hexagon and all HLOS except Windows. */
|
||||
#if !defined(_WINDOWS)
|
||||
/* Weak reference to remote symbol for compilation. */
|
||||
# pragma weak remote_session_control
|
||||
# pragma weak remote_handle_control
|
||||
# pragma weak remote_handle64_control
|
||||
# pragma weak fastrpc_mmap
|
||||
# pragma weak fastrpc_munmap
|
||||
# pragma weak rpcmem_alloc2
|
||||
#endif
|
||||
|
||||
#if !defined(_WINDOWS)
|
||||
# pragma weak remote_system_request
|
||||
#endif
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query DSP support.
|
||||
*
|
||||
* @param[out] domain pointer to supported domain.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*/
|
||||
int get_dsp_support(int * domain);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query VTCM information.
|
||||
*
|
||||
* @param[in] domain value of domain in the queried.
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to the queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*/
|
||||
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain.
|
||||
*
|
||||
* @return true if unsigned pd is supported.
|
||||
* false if unsigned pd is not supported, capability query failed.
|
||||
*/
|
||||
|
||||
bool get_unsignedpd_support(void);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query unsigned pd support.
|
||||
*
|
||||
* @param[in] domain value of domain in the queried.
|
||||
* @return true if unsigned pd is supported.
|
||||
* false if unsigned pd is not supported, capability query failed.
|
||||
*/
|
||||
|
||||
bool is_unsignedpd_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* is_valid_domain_id API: query a domain id is valid.
|
||||
*
|
||||
* @param[in] domain value of domain in the queried.
|
||||
* @param[in] compute_only value of domain is only compared with CDSP domains supported by the target when enabled.
|
||||
* @return true if value of domain is valid.
|
||||
* false if value of domain is not valid.
|
||||
*/
|
||||
|
||||
bool is_valid_domain_id(int domain_id, int compute_only);
|
||||
|
||||
/**
|
||||
* get_domain API: get domain struct from domain value.
|
||||
*
|
||||
* @param[in] domain value of a domain
|
||||
* @return Returns domain struct of the domain if it is supported or else
|
||||
* returns NULL.
|
||||
*
|
||||
*/
|
||||
|
||||
domain * get_domain(int domain_id);
|
||||
|
||||
/**
|
||||
* get_domains_info API: get information for all the domains available on the device
|
||||
*
|
||||
* @param[in] domain_type pointer to domain type
|
||||
* @param[in] num_domains pointer to number of domains
|
||||
* @param[in] domains_info pointer to save discovered domains information.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
* It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application.
|
||||
*
|
||||
*/
|
||||
|
||||
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info);
|
||||
|
||||
/**
|
||||
* get_effective_domain_id API: get effective domain id for given session id
|
||||
*
|
||||
* @param[in] domain_name pointer to domain name
|
||||
* @param[in] session_id
|
||||
* @param[in] effec_domain_id pointer to save obtained effective domain id.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
|
||||
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id);
|
||||
|
||||
/**
|
||||
* is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @return Returns true or false stating support of Async FastRPC
|
||||
*
|
||||
*/
|
||||
|
||||
bool is_async_fastrpc_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @return Returns true or false stating status notification support information
|
||||
*
|
||||
*/
|
||||
bool is_status_notification_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* get_hmx_support_info API: query the DSP for HMX SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to the queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
/**
|
||||
* get_hex_arch_ver API: query the Hexagon processor architecture version information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] Arch version (73, 75, ...)
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hex_arch_ver(int domain, int * arch);
|
||||
|
||||
/**
|
||||
* get_hvx_support_info API: query the DSP for HVX SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to the queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif //DSP_CAPABILITIES_UTILS_H
|
||||
@@ -0,0 +1,79 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <winevt.h>
|
||||
#else
|
||||
# include <dlfcn.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include <filesystem>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(HMODULE handle) {
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static inline dl_handle * dl_load_library(const fs::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
void * p = (void *) GetProcAddress(handle, name);
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
static inline const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
using dl_handle = void;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(void * handle) {
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static inline dl_handle * dl_load_library(const fs::path & path) {
|
||||
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
return handle;
|
||||
}
|
||||
|
||||
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
static inline const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,38 @@
|
||||
[Version]
|
||||
Signature = "$WINDOWS NT$"
|
||||
Class = ComputeAccelerator
|
||||
ClassGuid = {F01A9D53-3FF6-48D2-9F97-C8A7004BE10C}
|
||||
Provider = %GGML%
|
||||
DriverVer = 01/01/2026,1.0.0.0
|
||||
CatalogFile = libggml-htp.cat
|
||||
PnpLockDown = 1
|
||||
|
||||
[DestinationDirs]
|
||||
Drivers_Dir = 6
|
||||
|
||||
[SourceDisksNames]
|
||||
1 = %DiskId%
|
||||
|
||||
[SourceDisksFiles]
|
||||
libggml-htp-v68.so = 1
|
||||
libggml-htp-v69.so = 1
|
||||
libggml-htp-v73.so = 1
|
||||
libggml-htp-v75.so = 1
|
||||
libggml-htp-v81.so = 1
|
||||
|
||||
[ControlFlags]
|
||||
ExcludeFromSelect = *
|
||||
|
||||
[DefaultInstall.NTarm64]
|
||||
CopyFiles=Drivers_Dir
|
||||
|
||||
[Drivers_Dir]
|
||||
libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
|
||||
[Strings]
|
||||
GGML = 'GGML'
|
||||
DiskId = 'GGML HTP library'
|
||||
@@ -62,6 +62,8 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmf*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
|
||||
if (GGML_CUDA_FA_ALL_QUANTS)
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
#include <sycl/half_type.hpp>
|
||||
#include <syclcompat/math.hpp>
|
||||
#include <map>
|
||||
|
||||
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
||||
|
||||
@@ -123,6 +123,15 @@ static __dpct_inline__ T op_log(T x) {
|
||||
return sycl::log(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_softplus(T x) {
|
||||
const float xf = (float) x;
|
||||
const float ax = sycl::fabs(xf);
|
||||
const float m = sycl::fmax(xf, 0.0f);
|
||||
const float y = m + sycl::log1p(sycl::exp(-ax));
|
||||
return (T) y;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_neg(T x) {
|
||||
return -x;
|
||||
@@ -695,6 +704,12 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_softplus(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_neg(x);
|
||||
@@ -1101,6 +1116,11 @@ void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_log(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_softplus(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_neg(ctx, dst);
|
||||
|
||||
@@ -61,6 +61,8 @@ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -2263,6 +2263,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
|
||||
}
|
||||
|
||||
static void tri_f32_sycl(
|
||||
const float * src,
|
||||
float * dst,
|
||||
const int64_t ne0,
|
||||
const int64_t ne1,
|
||||
const int64_t ne2,
|
||||
const int64_t ne3,
|
||||
const ggml_tri_type ttype,
|
||||
dpct::queue_ptr main_stream
|
||||
) {
|
||||
const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
|
||||
|
||||
main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
|
||||
const int64_t idx = (int64_t) tid[0];
|
||||
|
||||
const int64_t i0 = idx % ne0;
|
||||
const int64_t t1 = idx / ne0;
|
||||
const int64_t i1 = t1 % ne1;
|
||||
|
||||
bool keep = false;
|
||||
switch (ttype) {
|
||||
case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
|
||||
case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break;
|
||||
case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
|
||||
default: keep = false; break;
|
||||
}
|
||||
|
||||
dst[idx] = keep ? src[idx] : 0.0f;
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
GGML_ASSERT(src0);
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const float * src0_dd = static_cast<const float *>(src0->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
const int64_t ne0 = src0->ne[0];
|
||||
const int64_t ne1 = src0->ne[1];
|
||||
const int64_t ne2 = src0->ne[2];
|
||||
const int64_t ne3 = src0->ne[3];
|
||||
|
||||
tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
|
||||
}
|
||||
|
||||
|
||||
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
@@ -3786,6 +3845,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_EXP:
|
||||
ggml_sycl_exp(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
ggml_sycl_softplus(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SGN:
|
||||
ggml_sycl_sgn(ctx, dst);
|
||||
break;
|
||||
@@ -3912,6 +3974,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_TRANSPOSE:
|
||||
GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
ggml_sycl_op_tri(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
ggml_sycl_diag_mask_inf(ctx, dst);
|
||||
break;
|
||||
@@ -4404,6 +4469,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
return true;
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
@@ -4606,18 +4672,23 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
|
||||
#endif
|
||||
case GGML_OP_NORM:
|
||||
return true;
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SCALE:
|
||||
return true;
|
||||
case GGML_OP_CONT:
|
||||
return op->src[0]->type != GGML_TYPE_BF16;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
return src0 &&
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
ggml_is_contiguous(src0);
|
||||
}
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
return true;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
|
||||
@@ -251,7 +251,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
||||
const float eps, queue_ptr stream, int device) {
|
||||
|
||||
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
@@ -334,7 +333,6 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
|
||||
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
|
||||
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
||||
@@ -374,7 +372,6 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
||||
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
|
||||
@@ -3162,17 +3162,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
// For scalar, use 128 (arbitrary)
|
||||
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
|
||||
const uint32_t D = (hsk|hsv);
|
||||
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
||||
? scalar_flash_attention_workgroup_size
|
||||
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
|
||||
|
||||
uint32_t wg_size;
|
||||
switch (path) {
|
||||
case FA_COOPMAT2:
|
||||
wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
break;
|
||||
case FA_COOPMAT1:
|
||||
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
|
||||
break;
|
||||
default:
|
||||
wg_size = scalar_flash_attention_workgroup_size;
|
||||
break;
|
||||
}
|
||||
|
||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||
const uint32_t D_lsb = D ^ (D & (D-1));
|
||||
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
|
||||
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
|
||||
// Nvidia prefers shared memory use to load large tiles of K
|
||||
// AMD prefers loading K directly from global memory
|
||||
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
|
||||
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
|
||||
};
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||
@@ -3187,15 +3201,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
@@ -8344,41 +8358,49 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
|
||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||
const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
|
||||
const uint32_t MatBr = 16, MatBc = 16;
|
||||
|
||||
const uint32_t row_split = Bc / MatBc;
|
||||
|
||||
const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
|
||||
|
||||
const uint32_t acctype = f32acc ? 4 : 2;
|
||||
const uint32_t f16vec4 = 8;
|
||||
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
||||
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
|
||||
|
||||
const uint32_t qstride = hsk_pad / 4 + 2;
|
||||
const uint32_t Qf = Br * qstride * f16vec4;
|
||||
|
||||
const uint32_t psh_stride = Br / 4 + 2;
|
||||
const uint32_t Psh = Bc * psh_stride * f16vec4;
|
||||
|
||||
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
|
||||
const uint32_t sfsh = Bc * sfshstride * acctype;
|
||||
|
||||
const uint32_t kshstride = hsk_pad / 4 + 2;
|
||||
const uint32_t ksh = Bc * kshstride * f16vec4;
|
||||
const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
|
||||
const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
|
||||
const uint32_t vsh_stride = MatBc / 4 * row_split;
|
||||
const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
|
||||
|
||||
const uint32_t slope = Br * sizeof(float);
|
||||
const uint32_t slope = Br * acctype;
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
||||
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
@@ -8442,7 +8464,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
||||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
||||
|
||||
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
|
||||
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type);
|
||||
|
||||
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
||||
path = FA_SCALAR;
|
||||
|
||||
@@ -8,6 +8,8 @@ layout (constant_id = 3) const uint32_t HSK = 32;
|
||||
layout (constant_id = 4) const uint32_t HSV = 32;
|
||||
layout (constant_id = 5) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 6) const uint32_t D_split = 16;
|
||||
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
|
||||
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
|
||||
|
||||
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
|
||||
const uint32_t HSK_pad = (HSK + 15) & ~15;
|
||||
@@ -74,6 +76,10 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#ifndef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 1
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
#undef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 4
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
@@ -14,12 +15,13 @@
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
|
||||
const uint32_t HSK_per_thread = HSK / D_split;
|
||||
const uint32_t HSV_per_thread = HSV / D_split;
|
||||
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
|
||||
const uint32_t MatBr = 16;
|
||||
const uint32_t MatBc = 16;
|
||||
|
||||
const uint32_t row_split = 4;
|
||||
const uint32_t row_split = Bc / MatBc;
|
||||
const uint32_t rows_per_thread = Br / row_split;
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
||||
|
||||
@@ -40,24 +42,24 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
|
||||
const uint32_t MatBr = 16;
|
||||
const uint32_t MatBc = 16;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
||||
shared float tmpsh[row_split];
|
||||
|
||||
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
|
||||
const uint psh_stride = Br / 4 + 2;
|
||||
shared f16vec4 Psh[Bc * psh_stride];
|
||||
|
||||
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
|
||||
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
|
||||
shared ACC_TYPE sfsh[Bc * sfshstride];
|
||||
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
|
||||
shared ACC_TYPEV4 sfsh[Bc * sfshstride];
|
||||
|
||||
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 ksh[Bc * kshstride];
|
||||
const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
|
||||
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
|
||||
const uint vsh_stride = v_cols;
|
||||
shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
|
||||
|
||||
shared float slope[Br];
|
||||
shared ACC_TYPE slope[Br];
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
@@ -69,9 +71,9 @@ void main() {
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
|
||||
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
||||
const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup;
|
||||
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
|
||||
const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
|
||||
|
||||
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||
|
||||
@@ -102,9 +104,9 @@ void main() {
|
||||
}
|
||||
barrier();
|
||||
|
||||
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
ACC_TYPEV4 Of[rows_per_thread][d_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
|
||||
Of[r][d] = ACC_TYPEV4(0.0);
|
||||
}
|
||||
}
|
||||
@@ -125,13 +127,11 @@ void main() {
|
||||
uint r = tid;
|
||||
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
|
||||
}
|
||||
barrier();
|
||||
} else {
|
||||
if (tid < Br) {
|
||||
uint r = tid;
|
||||
slope[r] = 1.0;
|
||||
slope[r] = ACC_TYPE(1.0);
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
@@ -149,19 +149,45 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
float mask_cache[Bc * Br / WorkGroupSize];
|
||||
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) / (Br / 4);
|
||||
uint32_t r = (idx + tid) % (Br / 4);
|
||||
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV)) {
|
||||
f16vec4 m;
|
||||
if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
|
||||
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
|
||||
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
|
||||
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
|
||||
data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
|
||||
max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
|
||||
} else if (i * Br + r * 4 + 2 < p.nem1) {
|
||||
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
|
||||
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
|
||||
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
|
||||
0.0);
|
||||
max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
|
||||
} else if (i * Br + r * 4 + 1 < p.nem1) {
|
||||
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
|
||||
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
|
||||
0.0,
|
||||
0.0);
|
||||
max_mask = max(max(max_mask, float(m[0])), float(m[1]));
|
||||
} else if (i * Br + r * 4 < p.nem1) {
|
||||
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
|
||||
0.0,
|
||||
0.0,
|
||||
0.0);
|
||||
max_mask = max(max_mask, float(m[0]));
|
||||
} else {
|
||||
m = f16vec4(0.0);
|
||||
}
|
||||
mask_cache[idx / WorkGroupSize] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -180,26 +206,28 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
if (c < Bc && d < HSK / 4) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
if (K_LOAD_SHMEM != 0) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
if (c < Bc && d < HSK / 4) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||
#else
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
ksh[c * kshstride + d] = K_Tf;
|
||||
ksh[c * kshstride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
barrier();
|
||||
|
||||
// K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
|
||||
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||
@@ -208,11 +236,55 @@ void main() {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
|
||||
for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
|
||||
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
|
||||
if (K_LOAD_SHMEM == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
if (KV_bounds_check || d * 16 + 16 > HSK) {
|
||||
#endif
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t col_vec = (idx + tid) % (MatBr / 4);
|
||||
uint32_t row = (idx + tid) / (MatBr / 4);
|
||||
if (idx + tid < Bc * MatBr / 4) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||
#else
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
|
||||
#endif
|
||||
}
|
||||
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
ksh[row * kshstride + col_vec] = K_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
#if BLOCK_SIZE == 1
|
||||
}
|
||||
#endif
|
||||
|
||||
#if BLOCK_SIZE == 1
|
||||
if (KV_bounds_check || d * 16 + 16 > HSK)
|
||||
#endif
|
||||
{
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride;
|
||||
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
#if BLOCK_SIZE == 1
|
||||
else {
|
||||
const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
|
||||
coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
|
||||
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
|
||||
}
|
||||
@@ -222,26 +294,26 @@ void main() {
|
||||
barrier();
|
||||
|
||||
if (p.logit_softcap != 0.0f) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) / Br;
|
||||
uint32_t r = (idx + tid) % Br;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) / (Br / 4);
|
||||
uint32_t r = (idx + tid) % (Br / 4);
|
||||
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
|
||||
sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float f = mask_cache[idx / WorkGroupSize];
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) / (Br / 4);
|
||||
uint32_t r = (idx + tid) % (Br / 4);
|
||||
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
// Mask nem1 bounds check is handled when loading masks
|
||||
ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]);
|
||||
ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]);
|
||||
sfsh[c * sfshstride + r] += slopes * masks;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -250,51 +322,145 @@ void main() {
|
||||
|
||||
float eMf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint r_vec = tile_row(r) / 4;
|
||||
const uint r_comp = tile_row(r) % 4;
|
||||
|
||||
float rowmaxf = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
|
||||
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
|
||||
}
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// Compute max across the row
|
||||
rowmaxf = subgroupMax(rowmaxf);
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// P = e^(S - M)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf, Moldf);
|
||||
eMf[r] = exp(Moldf - Mf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lf[r] = eMf[r]*Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
float Pf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
||||
Lf[r] += Pf[r];
|
||||
Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local];
|
||||
}
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
}
|
||||
|
||||
// Calculate and store Pf in Psh
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
const uint col = c * cols_per_iter + col_tid;
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
|
||||
const uint row = tile_row(r);
|
||||
if (KV_bounds_check && j * Bc + col >= KV) {
|
||||
Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
|
||||
} else {
|
||||
const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
|
||||
const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
|
||||
[[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
|
||||
Lf[r + vec_idx] += Pf[vec_idx];
|
||||
}
|
||||
Psh[col * psh_stride + row / 4] = Pf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
|
||||
|
||||
// Each subgroup handles HSV/4 columns
|
||||
[[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
|
||||
const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
|
||||
|
||||
SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
// Preload V tiles for [Bc, 16 * num subgroups]
|
||||
const uint v_rows = Bc;
|
||||
const uint v_total = v_rows * v_cols;
|
||||
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
|
||||
|
||||
#if BLOCK_SIZE == 1
|
||||
// For f16, only preload if not aligned
|
||||
if (KV_bounds_check) {
|
||||
#endif
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
|
||||
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
|
||||
const uint idx = i * gl_WorkGroupSize.x + tid;
|
||||
const uint row = idx / v_cols;
|
||||
const uint col = idx % v_cols;
|
||||
|
||||
const uint v_row = j * Bc + row;
|
||||
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
|
||||
|
||||
const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
|
||||
const uint ib = coord / BLOCK_SIZE;
|
||||
const uint iqs = coord % BLOCK_SIZE;
|
||||
|
||||
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
|
||||
#else
|
||||
ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
|
||||
#endif
|
||||
} else {
|
||||
ksh[row * vsh_stride + col] = f16vec4(0.0f);
|
||||
}
|
||||
}
|
||||
#if BLOCK_SIZE == 1
|
||||
}
|
||||
#endif
|
||||
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
|
||||
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
#if BLOCK_SIZE == 1
|
||||
if (!KV_bounds_check) {
|
||||
// F16 values can be loaded directly from global memory
|
||||
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
|
||||
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
|
||||
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
|
||||
coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
|
||||
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
|
||||
}
|
||||
|
||||
// Store SfMat to sfsh and load into Of
|
||||
const uint osh_stride = row_split * MatBc / 4;
|
||||
const uint o_offset = gl_SubgroupID * MatBc / 4;
|
||||
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
barrier();
|
||||
|
||||
const uint hsv_per_tile = row_split * MatBc;
|
||||
const uint hsv_base = hsv_tile * hsv_per_tile;
|
||||
const uint d_values_per_tile = hsv_per_tile / 4;
|
||||
|
||||
const uint d_start = hsv_tile * d_values_per_tile;
|
||||
const uint d_end = min(d_start + d_values_per_tile, HSV / 4);
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint row = tile_row(r);
|
||||
|
||||
[[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) {
|
||||
const uint d = d_local * threads_per_rowgroup + col_tid;
|
||||
const uint hsv_col = 4 * d;
|
||||
|
||||
if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
|
||||
const uint local_hsv = (hsv_col - hsv_base) / 4;
|
||||
Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -302,69 +468,8 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
// prevent race on tmpsh
|
||||
barrier();
|
||||
|
||||
// reduce across threads
|
||||
|
||||
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
FLOAT_TYPE M = Mf[r];
|
||||
tmpsh[tid] = M;
|
||||
// Compute max across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||
M = max(M, tmpsh[tid ^ s]);
|
||||
barrier();
|
||||
tmpsh[tid] = M;
|
||||
barrier();
|
||||
}
|
||||
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Moldf[r] = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf[r], Moldf[r]);
|
||||
eMf[r] = exp(Moldf[r] - Mf[r]);
|
||||
|
||||
Lf[r] = eMf[r]*Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
FLOAT_TYPE L = Lf[r];
|
||||
tmpsh[tid] = L;
|
||||
// Compute sum across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||
L += tmpsh[tid ^ s];
|
||||
barrier();
|
||||
tmpsh[tid] = L;
|
||||
barrier();
|
||||
}
|
||||
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
||||
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||
Of[r][d] += tmpshv4[tid ^ s];
|
||||
barrier();
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
barrier();
|
||||
}
|
||||
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
|
||||
barrier();
|
||||
}
|
||||
Lf[r] = subgroupAdd(Lf[r]);
|
||||
}
|
||||
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
@@ -375,9 +480,12 @@ void main() {
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d = d0 + col_tid;
|
||||
if (d >= HSV/4) break;
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -404,8 +512,9 @@ void main() {
|
||||
if (sink > Mf[r]) {
|
||||
ms = exp(Mf[r] - sink);
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
Of[r][d] *= ACC_TYPE(ms);
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
Of[r][d_local] *= ACC_TYPE(ms);
|
||||
}
|
||||
} else {
|
||||
vs = exp(sink - Mf[r]);
|
||||
@@ -420,11 +529,12 @@ void main() {
|
||||
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
|
||||
Of[r][d_local] *= ACC_TYPE(Lfrcp[r]);
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -434,9 +544,12 @@ void main() {
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d = d0 + col_tid;
|
||||
if (d >= HSV / 4) break;
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -444,9 +557,12 @@ void main() {
|
||||
} else {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (i * Br + tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d = d0 + col_tid;
|
||||
if (d >= HSV / 4) break;
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
|
||||
return max(elem0, elem1);
|
||||
}
|
||||
|
||||
#if defined(BLOCK_SIZE)
|
||||
#if BLOCK_SIZE > 1
|
||||
#define DECODEFUNC , DEQUANTFUNC
|
||||
#else
|
||||
#define DECODEFUNC
|
||||
@@ -85,7 +85,7 @@ void main() {
|
||||
|
||||
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
|
||||
|
||||
#if defined(BLOCK_SIZE)
|
||||
#if BLOCK_SIZE > 1
|
||||
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
|
||||
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
|
||||
#endif
|
||||
@@ -98,7 +98,7 @@ void main() {
|
||||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||
{
|
||||
q_stride &= ~7;
|
||||
#if !defined(BLOCK_SIZE)
|
||||
#if BLOCK_SIZE == 1
|
||||
k_stride &= ~7;
|
||||
v_stride &= ~7;
|
||||
#endif
|
||||
|
||||
@@ -114,7 +114,7 @@ struct Params {
|
||||
#define PARAMS_BINDING 4
|
||||
#endif
|
||||
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<f32>;
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||
|
||||
// Just a very small float value.
|
||||
@@ -160,14 +160,21 @@ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {
|
||||
return v;
|
||||
}
|
||||
|
||||
fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32) -> vec4<f32> {
|
||||
return (*buf)[scalar_index >> 2u];
|
||||
}
|
||||
|
||||
fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {
|
||||
return (*buf)[scalar_index >> 2u];
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||
|
||||
// initialize row max for online softmax
|
||||
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
|
||||
@@ -231,9 +238,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
|
||||
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
|
||||
// clear inter_shmem to ensure zero-initialized accumulators
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
inter_shmem[elem_idx] = 0.0;
|
||||
}
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
inter_shmem[elem_idx] = 0.0;
|
||||
}
|
||||
|
||||
// load k tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
@@ -309,48 +316,77 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
|
||||
// accumulate q block * k block into registers across the entire KV tile
|
||||
// TODO: this loop seems to be the current largest bottleneck
|
||||
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
|
||||
let inter_offset = kv_block * SG_MAT_N;
|
||||
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<
|
||||
subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
|
||||
// this bracket exists to scope the lifetime of variables, reducing register pressure
|
||||
{
|
||||
#ifdef KV_DIRECT
|
||||
let k_block_row = kv_tile + kv_block * SG_MAT_N;
|
||||
let k_global_offset = k_head_offset + k_block_row * params.stride_k1;
|
||||
let k_block_row = kv_tile + subgroup_id * SG_MAT_N;
|
||||
var k_global_offset = k_head_offset + k_block_row * params.stride_k1;
|
||||
#else
|
||||
let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK;
|
||||
var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK;
|
||||
#endif
|
||||
for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) {
|
||||
// load q submatrix from shared memory
|
||||
var q_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
|
||||
&q_shmem,
|
||||
head_dim_block,
|
||||
false,
|
||||
HEAD_DIM_QK
|
||||
);
|
||||
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
|
||||
let inter_offset = kv_block * SG_MAT_N;
|
||||
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
|
||||
|
||||
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
|
||||
|
||||
// load k submatrix from device or shared memory
|
||||
#ifdef KV_DIRECT
|
||||
var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
&K,
|
||||
k_global_offset + head_dim_block,
|
||||
true,
|
||||
params.stride_k1
|
||||
);
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
|
||||
#else
|
||||
var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
&kv_shmem,
|
||||
k_block_offset + head_dim_block,
|
||||
true,
|
||||
HEAD_DIM_QK
|
||||
);
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc);
|
||||
|
||||
var t: u32 = 1u;
|
||||
for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
|
||||
let h0 = t * SG_MAT_K;
|
||||
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
|
||||
#ifdef KV_DIRECT
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
|
||||
#else
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
q_cur = q0;
|
||||
k_cur = k0;
|
||||
|
||||
let h1 = (t + 1u) * SG_MAT_K;
|
||||
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
|
||||
#ifdef KV_DIRECT
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
|
||||
#else
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
q_cur = q1g;
|
||||
k_cur = k1g;
|
||||
}
|
||||
|
||||
// handle odd tail
|
||||
if (t < HEAD_DIM_QK / SG_MAT_K) {
|
||||
let h = t * SG_MAT_K;
|
||||
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
|
||||
#ifdef KV_DIRECT
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
|
||||
#else
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
q_cur = qn;
|
||||
k_cur = kn;
|
||||
}
|
||||
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
|
||||
#ifdef KV_DIRECT
|
||||
k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1;
|
||||
#else
|
||||
k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK;
|
||||
#endif
|
||||
subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
|
||||
}
|
||||
|
||||
// store acc to shared memory for softmax (S matrix from paper)
|
||||
subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
|
||||
}
|
||||
|
||||
|
||||
#ifdef MASK
|
||||
// load mask tile into shared memory for this KV block
|
||||
// TODO: optimize and skip if mask is -INF for the entire tile
|
||||
@@ -495,7 +531,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
false,
|
||||
HEAD_DIM_V
|
||||
);
|
||||
|
||||
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
|
||||
let p_offset = kv_block * SG_MAT_N;
|
||||
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
|
||||
@@ -527,11 +562,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
// O += P * V
|
||||
o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
|
||||
}
|
||||
|
||||
// store O back to shared memory
|
||||
subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
@@ -566,26 +599,38 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
o_shmem[idx] = f16(val);
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
#endif
|
||||
|
||||
// write output back to global memory
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
|
||||
let exp_sum = exp_sum_shmem[q_tile_row];
|
||||
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0);
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) { break; }
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx];
|
||||
let scaled = f32(o_val) * scale;
|
||||
dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled;
|
||||
}
|
||||
let exp_sum = exp_sum_shmem[q_tile_row];
|
||||
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
|
||||
|
||||
let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride;
|
||||
|
||||
for (var elem_base = sg_inv_id * 4u;
|
||||
elem_base < HEAD_DIM_V;
|
||||
elem_base += subgroup_size * 4u) {
|
||||
|
||||
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
|
||||
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
|
||||
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
|
||||
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
|
||||
|
||||
let v = vec4<f32>(
|
||||
f32(o_shmem[i0]) * scale,
|
||||
f32(o_shmem[i1]) * scale,
|
||||
f32(o_shmem[i2]) * scale,
|
||||
f32(o_shmem[i3]) * scale
|
||||
);
|
||||
|
||||
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
|
||||
dst[dst_vec_index] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "zendnnl.hpp"
|
||||
|
||||
#include <cstring>
|
||||
@@ -122,8 +121,8 @@ static void ggml_zendnn_compute_forward_mul_mat(
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
ggml_type const vec_dot_type = ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
|
||||
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(vec_dot_type)->from_float;
|
||||
ggml_type const vec_dot_type = src0->type;
|
||||
ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;
|
||||
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
GGML_ASSERT(ne1 == ne11);
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
{#- ======== Template Parameters ======== #}
|
||||
{%- set add_generation_prompt = add_generation_prompt if add_generation_prompt is defined else true %}
|
||||
{%- set default_system_prompt = default_system_prompt if default_system_prompt is defined else true %}
|
||||
{%- set reasoning_effort = reasoning_effort if reasoning_effort is defined else "high" %}
|
||||
{%- set think_render_option = think_render_option if think_render_option is defined else "lastthink" %}
|
||||
|
||||
{#- ======== System Block State ======== #}
|
||||
{%- set sys_ns = namespace(is_first_block=true) -%}
|
||||
|
||||
{#- ======== Find last user message index ======== #}
|
||||
{%- set last_user_idx = namespace(value=-1) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message.role == 'user' -%}
|
||||
{%- set last_user_idx.value = loop.index0 -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- ======== System messages renderers ======== #}
|
||||
{%- macro render_system_message(user_system_messages) %}
|
||||
{%- if default_system_prompt %}
|
||||
{%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
|
||||
{%- set sys_ns.is_first_block = false %}
|
||||
{{- "## Provider System Prompt\n\nYou are Solar Open 100B, a large language model trained by Upstage AI, a Korean startup. Your knowledge cutoff is 2025-07. The current date is " + strftime_now("%Y-%m-%d") + "." }}
|
||||
{%- endif -%}
|
||||
{%- if user_system_messages %}
|
||||
{%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
|
||||
{%- set sys_ns.is_first_block = false %}
|
||||
{{- "## System Prompt" }}
|
||||
{%- for system_message in user_system_messages %}
|
||||
{{- "\n\n" }}
|
||||
{{- system_message }}
|
||||
{%- endfor %}
|
||||
{%- endif -%}
|
||||
{%- endmacro %}
|
||||
|
||||
{%- macro render_tool_instruction(tools) %}
|
||||
{%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
|
||||
{%- set sys_ns.is_first_block = false %}
|
||||
{{- "## Tools\n\n### Tool Call Instruction" }}
|
||||
{{- "\nYou may invoke one or more tools to assist with the user's query. Available tools are provided in JSON Schema format: <|tools:begin|><|tool:begin|><tools-json-object><|tool:end|>...<|tools:end|>\n" }}
|
||||
{{- "\n### Available Tools\n" }}
|
||||
{{- "<|tools:begin|>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "<|tool:begin|>" }}
|
||||
{{- tool.function | tojson }}
|
||||
{{- "<|tool:end|>" }}
|
||||
{%- endfor %}
|
||||
{{- "<|tools:end|>\n" }}
|
||||
{{- "\n### Tool Call Format\n" }}
|
||||
{{- "For each tool call, return a JSON object with the following structure, enclosed within <|tool_call:begin|> and <|tool_call:end|> tags: \n<|tool_call:begin|><tool-call-id><|tool_call:name|><tool-name><|tool_call:args|><args-json-object><|tool_call:end|>\n" }}
|
||||
{{- "- The <tool-call-id> must be a randomly generated string consisting of 10 lowercase letters (a-z) and/or digits (0-9) (e.g., a1b2c3d4e5)\n" }}
|
||||
{{- "\n### Tool Response Format\n" }}
|
||||
{{- "Each tool is responded by `tool` with the following structure:\n<|tool_response:id|><tool-call-id><|tool_response:name|><tool-name><|tool_response:result|><results><|tool_response:end|>\n" }}
|
||||
{{- "- Ensure the <tool-call-id> matches the corresponding tool call" -}}
|
||||
{%- endmacro %}
|
||||
|
||||
{%- macro render_json_response_format_instruction(response_format) %}
|
||||
{%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
|
||||
{%- set sys_ns.is_first_block = false %}
|
||||
{{- "## Output Format Constraint" }}
|
||||
{{- "\n\nYour final response should follow the JSON schema: \n[Start of schema]" }}
|
||||
{{- response_format }}
|
||||
{{- "\n[End of schema]\nPlease ensure your answers adhere to this format and do not contain any unnecessary text." }}
|
||||
{%- endmacro %}
|
||||
|
||||
{%- macro get_tool_name(messages, tool_call_id) %}
|
||||
{%- for msg in messages -%}
|
||||
{%- if msg.role == 'assistant' and msg.tool_calls -%}
|
||||
{%- for tool_call in msg.tool_calls -%}
|
||||
{%- if tool_call.id == tool_call_id -%}
|
||||
{{- tool_call.function.name }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro %}
|
||||
|
||||
{%- macro render_tool_arguments(tool_arguments) %}
|
||||
{%- if tool_arguments is mapping -%}
|
||||
{{- tool_arguments | tojson }}
|
||||
{%- else -%}
|
||||
{{- tool_arguments }}
|
||||
{%- endif -%}
|
||||
{%- endmacro %}
|
||||
|
||||
{#- ======== Render system message ======== #}
|
||||
{%- set ns = namespace(system_messages=[]) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message.role == 'system' -%}
|
||||
{%- set ns.system_messages = ns.system_messages + [message.content] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if ns.system_messages or default_system_prompt or tools or response_format -%}
|
||||
{{- "<|begin|>system<|content|>" }}
|
||||
{{- render_system_message(ns.system_messages) }}
|
||||
{%- if tools -%}
|
||||
{{- render_tool_instruction(tools) }}
|
||||
{%- endif %}
|
||||
{%- if response_format -%}
|
||||
{{- render_json_response_format_instruction(response_format) }}
|
||||
{%- endif %}
|
||||
{{- "<|end|>" }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- ======== Render main messages ======== #}
|
||||
{%- for message in messages -%}
|
||||
{%- if message.role == 'user' -%}
|
||||
{{- "<|begin|>user<|content|>" + message.content + "<|end|>" }}
|
||||
{%- elif message.role == 'tool' -%}
|
||||
{%- set prev_is_tool = loop.index0 > 0 and messages[loop.index0 - 1].role == 'tool' -%}
|
||||
{%- set next_is_tool = loop.index0 < (messages | length - 1) and messages[loop.index0 + 1].role == 'tool' -%}
|
||||
{%- if not prev_is_tool -%}
|
||||
{{- "<|begin|>tool<|tool_response|>" }}
|
||||
{%- endif -%}
|
||||
{{- "<|tool_response:begin|>" + message.tool_call_id + "<|tool_response:name|>" }}
|
||||
{{- get_tool_name(messages, message.tool_call_id) }}
|
||||
{{- "<|tool_response:result|>" }}
|
||||
{{- message.content }}
|
||||
{{- "<|tool_response:end|>" }}
|
||||
{%- if not next_is_tool -%}
|
||||
{{- "<|end|>" }}
|
||||
{%- endif -%}
|
||||
{%- elif message.role == 'assistant' -%}
|
||||
{#- ======== Assistant Thinking ======== #}
|
||||
{%- if think_render_option == "all" -%}
|
||||
{%- if message.reasoning -%}
|
||||
{{- "<|begin|>assistant<|think|>" + message.reasoning + "<|end|>" }}
|
||||
{%- endif -%}
|
||||
{%- elif think_render_option == "lastthink" -%}
|
||||
{%- if message.reasoning and loop.index0 > last_user_idx.value -%}
|
||||
{{- "<|begin|>assistant<|think|>" + message.reasoning + "<|end|>" }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
|
||||
{#- ======== Assistant Messages ======== #}
|
||||
{%- if message.tool_calls -%}
|
||||
{{- "<|begin|>assistant<|tool_calls|>" }}
|
||||
{%- for tool_call in message.tool_calls -%}
|
||||
{{- "<|tool_call:begin|>" + tool_call.id +"<|tool_call:name|>" + tool_call.function.name + "<|tool_call:args|>" }}
|
||||
{{- render_tool_arguments(tool_call.function.arguments) }}
|
||||
{{- "<|tool_call:end|>" }}
|
||||
{%- endfor -%}
|
||||
{{- "<|calls|>" }}
|
||||
{%- else -%}
|
||||
{{- "<|begin|>assistant<|content|>" + message.content + "<|end|>" }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if reasoning_effort in ["low", "minimal"] -%}
|
||||
{{- "<|begin|>assistant<|think|><|end|>" }}
|
||||
{%- endif -%}
|
||||
{{- "<|begin|>assistant" }}
|
||||
{%- endif -%}
|
||||
@@ -0,0 +1,40 @@
|
||||
|
||||
#!/usr/bin/env pwsh
|
||||
|
||||
# Basedir on device
|
||||
$basedir=".\pkg-snapdragon"
|
||||
|
||||
$cli_opts=$args
|
||||
|
||||
$model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
if ($null -ne $env:M) {
|
||||
$model=$env:M
|
||||
}
|
||||
|
||||
$device="HTP0"
|
||||
if ($null -ne $env:D) {
|
||||
$device=$env:D
|
||||
}
|
||||
|
||||
if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:OPMASK) {
|
||||
$env:GGML_HEXAGON_OPMASK=$env:OPMASK
|
||||
}
|
||||
|
||||
if ($null -ne $env:NHVX) {
|
||||
$env:GGML_HEXAGON_NHVX=$env:NHVX
|
||||
}
|
||||
|
||||
if ($null -ne $env:NDEV) {
|
||||
$env:GGML_HEXAGON_NDEV=$env:NDEV
|
||||
}
|
||||
|
||||
$env:ADSP_LIBRARY_PATH="$basedir\lib"
|
||||
|
||||
& "$basedir\bin\llama-bench.exe" `
|
||||
--mmap 0 -m $basedir\..\..\gguf\$model `
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
|
||||
--batch-size 128 -ngl 99 --device $device $cli_opts
|
||||
@@ -0,0 +1,53 @@
|
||||
|
||||
#!/usr/bin/env pwsh
|
||||
|
||||
# Basedir on device
|
||||
$basedir=".\pkg-snapdragon"
|
||||
|
||||
$cli_opts=$args
|
||||
|
||||
$model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
if ($null -ne $env:M) {
|
||||
$model=$env:M
|
||||
}
|
||||
|
||||
$device="HTP0"
|
||||
if ($null -ne $env:D) {
|
||||
$device=$env:D
|
||||
}
|
||||
|
||||
if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
if ($null -ne $env:PROF) {
|
||||
$env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1
|
||||
}
|
||||
|
||||
if ($null -ne $env:OPMASK) {
|
||||
$env:GGML_HEXAGON_OPMASK=$env:OPMASK
|
||||
}
|
||||
|
||||
if ($null -ne $env:NHVX) {
|
||||
$env:GGML_HEXAGON_NHVX=$env:NHVX
|
||||
}
|
||||
|
||||
if ($null -ne $env:NDEV) {
|
||||
$env:GGML_HEXAGON_NDEV=$env:NDEV
|
||||
}
|
||||
|
||||
$env:ADSP_LIBRARY_PATH="$basedir\lib"
|
||||
|
||||
& "$basedir\bin\llama-completion.exe" `
|
||||
--no-mmap -no-cnv -m $basedir\..\..\gguf\$model `
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
|
||||
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on `
|
||||
-ngl 99 --device $device $cli_opts
|
||||
@@ -0,0 +1,56 @@
|
||||
|
||||
#!/usr/bin/env pwsh
|
||||
|
||||
# Basedir on device
|
||||
$basedir=".\pkg-snapdragon"
|
||||
|
||||
if ($args.Count -eq 0) {
|
||||
Write-Host "No arguments provided.Expected the tool and argument to run."
|
||||
exit -1
|
||||
}
|
||||
|
||||
$tool=$args[0]
|
||||
$cli_opts=@()
|
||||
|
||||
if ($args.Count -gt 1) {
|
||||
$cli_opts=$args[1..($args.Count - 1)]
|
||||
$remainingArgs = $args[1..($args.Count - 1)]
|
||||
}
|
||||
|
||||
$device="HTP0"
|
||||
if ($null -ne $env:D) {
|
||||
$device=$env:D
|
||||
}
|
||||
|
||||
if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
if ($null -ne $env:PROF) {
|
||||
$env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1
|
||||
}
|
||||
|
||||
if ($null -ne $env:OPMASK) {
|
||||
$env:GGML_HEXAGON_OPMASK=$env:OPMASK
|
||||
}
|
||||
|
||||
if ($null -ne $env:NHVX) {
|
||||
$env:GGML_HEXAGON_NHVX=$env:NHVX
|
||||
}
|
||||
|
||||
if ($null -ne $env:NDEV) {
|
||||
$env:GGML_HEXAGON_NDEV=$env:NDEV
|
||||
}
|
||||
|
||||
$env:ADSP_LIBRARY_PATH="$basedir\lib"
|
||||
|
||||
& "$basedir\bin\$tool" `
|
||||
$cli_opts
|
||||
@@ -0,0 +1,105 @@
|
||||
# Requires Run as Administrator is NOT strictly necessary for User-scope env vars,
|
||||
# but recommended for creating directories in C:\ root if permissions are restricted.
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
# --- Configuration ---
|
||||
$BaseDir = "C:\Qualcomm"
|
||||
|
||||
# SDK 1: Hexagon
|
||||
$HexagonUrl = "https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v6.4.0.2/hexagon-sdk-v6.4.0.2-arm64-wos.tar.xz"
|
||||
$HexagonParent = Join-Path $BaseDir "Hexagon_SDK"
|
||||
$HexagonSdkVersion = "6.4.0.2"
|
||||
$HexagonToolsVersion = "19.0.04"
|
||||
$HexagonSdkTarget = Join-Path $HexagonParent $HexagonSdkVersion
|
||||
$HexagonToolsTarget = Join-Path $HexagonSdkTarget "\tools\HEXAGON_Tools\$HexagonToolsVersion"
|
||||
|
||||
# SDK 2: OpenCL
|
||||
$OpenCLUrl = "https://github.com/snapdragon-toolchain/opencl-sdk/releases/download/v2.3.2/adreno-opencl-sdk-v2.3.2-arm64-wos.tar.xz"
|
||||
$OpenCLParent = Join-Path $BaseDir "OpenCL_SDK"
|
||||
$OpenCLVersion = "2.3.2"
|
||||
$OpenCLTarget = Join-Path $OpenCLParent $OpenCLVersion
|
||||
|
||||
# --- Helper Function ---
|
||||
function Install-QualcommSDK {
|
||||
param (
|
||||
[string]$Url,
|
||||
[string]$ParentDir,
|
||||
[string]$TargetDir,
|
||||
[string]$Name
|
||||
)
|
||||
|
||||
# 1. Create Parent Directory
|
||||
if (-not (Test-Path -Path $ParentDir)) {
|
||||
Write-Host "Creating directory: $ParentDir" -ForegroundColor Cyan
|
||||
New-Item -Path $ParentDir -ItemType Directory -Force | Out-Null
|
||||
}
|
||||
|
||||
# 2. Check for Specific Version Directory
|
||||
if (Test-Path -Path $TargetDir) {
|
||||
Write-Host "$Name ($TargetDir) already exists. Skipping download." -ForegroundColor Green
|
||||
}
|
||||
else {
|
||||
Write-Host "$Name not found. preparing to download..." -ForegroundColor Yellow
|
||||
|
||||
# Create the target directory to extract into
|
||||
New-Item -Path $TargetDir -ItemType Directory -Force | Out-Null
|
||||
|
||||
# Define temporary archive path
|
||||
$TempFile = Join-Path $ParentDir "temp_sdk.tar.xz"
|
||||
|
||||
try {
|
||||
# Download
|
||||
Write-Host "Downloading from: $Url"
|
||||
Invoke-WebRequest -Uri $Url -OutFile $TempFile
|
||||
|
||||
# Untar
|
||||
# Note: We assume Windows includes tar.exe (Win 10 build 17063+)
|
||||
Write-Host "Extracting archive to $TargetDir..."
|
||||
|
||||
# We use -C to extract contents INTO the target directory created above
|
||||
tar -xJvf $TempFile -C $TargetDir\..
|
||||
|
||||
Write-Host "Extraction complete." -ForegroundColor Green
|
||||
}
|
||||
catch {
|
||||
Write-Error "Failed to download or extract $Name. Error: $_"
|
||||
# Cleanup target dir if failed so script tries again next time
|
||||
Remove-Item -Path $TargetDir -Recurse -Force -ErrorAction SilentlyContinue
|
||||
}
|
||||
finally {
|
||||
# Cleanup Archive
|
||||
if (Test-Path $TempFile) { Remove-Item $TempFile -Force }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# --- Execution ---
|
||||
|
||||
# 1. Ensure Base C:\Qualcomm exists
|
||||
if (-not (Test-Path $BaseDir)) {
|
||||
New-Item -Path $BaseDir -ItemType Directory -Force | Out-Null
|
||||
}
|
||||
|
||||
# 2. Run Install Logic
|
||||
Install-QualcommSDK -Url $HexagonUrl -ParentDir $HexagonParent -TargetDir $HexagonSdkTarget -Name "Hexagon SDK"
|
||||
Install-QualcommSDK -Url $OpenCLUrl -ParentDir $OpenCLParent -TargetDir $OpenCLTarget -Name "OpenCL SDK"
|
||||
|
||||
# --- Environment Variables ---
|
||||
|
||||
Write-Host "`nSetting Environment Variables..." -ForegroundColor Cyan
|
||||
|
||||
# Set OPENCL_SDK_ROOT
|
||||
[System.Environment]::SetEnvironmentVariable('OPENCL_SDK_ROOT', $OpenCLTarget, [System.EnvironmentVariableTarget]::User)
|
||||
$env:OPENCL_SDK_ROOT = $OpenCLTarget # Set for current session as well
|
||||
Write-Host "OPENCL_SDK_ROOT set to: $OpenCLTarget"
|
||||
|
||||
# Set HEXAGON_SDK_ROOT
|
||||
[System.Environment]::SetEnvironmentVariable('HEXAGON_SDK_ROOT', $HexagonSdkTarget, [System.EnvironmentVariableTarget]::User)
|
||||
$env:HEXAGON_SDK_ROOT = $HexagonSdkTarget # Set for current session as well
|
||||
Write-Host "HEXAGON_SDK_ROOT set to: $HexagonSdkTarget"
|
||||
|
||||
# Set HEXAGON_SDK_ROOT
|
||||
[System.Environment]::SetEnvironmentVariable('HEXAGON_TOOLS_ROOT', $HexagonToolsTarget, [System.EnvironmentVariableTarget]::User)
|
||||
$env:HEXAGON_TOOLS_ROOT = $HexagonToolsTarget # Set for current session as well
|
||||
Write-Host "HEXAGON_TOOLS_ROOT set to: $HexagonToolsTarget"
|
||||
@@ -1630,11 +1630,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
|
||||
|
||||
if (!cparams.offload_kqv) {
|
||||
// all nodes between the KV store and the attention output are run on the CPU
|
||||
ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
|
||||
}
|
||||
|
||||
ggml_flash_attn_ext_add_sinks(cur, sinks);
|
||||
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
|
||||
|
||||
|
||||
@@ -54,7 +54,6 @@ std::string DEFAULT_JSON = R"({
|
||||
],
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"tools": [],
|
||||
"add_generation_prompt": true
|
||||
})";
|
||||
|
||||
@@ -481,7 +480,7 @@ int main_automated_tests(void) {
|
||||
/* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
|
||||
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
|
||||
/* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
|
||||
/* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[AVAILABLE_TOOLS] [[/AVAILABLE_TOOLS][INST] You are a helpful assistant\n\nAnother question[/INST]",
|
||||
/* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] You are a helpful assistant\n\nAnother question[/INST]",
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "</s>",
|
||||
},
|
||||
@@ -489,7 +488,7 @@ int main_automated_tests(void) {
|
||||
/* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
|
||||
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
|
||||
/* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
|
||||
/* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[AVAILABLE_TOOLS][[/AVAILABLE_TOOLS][INST]You are a helpful assistant\n\nAnother question[/INST]",
|
||||
/* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]You are a helpful assistant\n\nAnother question[/INST]",
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "</s>",
|
||||
},
|
||||
|
||||
+129
-1
@@ -592,7 +592,7 @@ static void test_peg_parser(common_chat_templates * tmpls, const std::function<v
|
||||
}
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
if (!diff.tool_call_delta.name.empty()) {
|
||||
msg_accum.tool_calls.push_back({diff.tool_call_delta.name, "", ""});
|
||||
msg_accum.tool_calls.push_back({diff.tool_call_delta.name, "", diff.tool_call_delta.id});
|
||||
}
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
msg_accum.tool_calls.back().arguments += diff.tool_call_delta.arguments;
|
||||
@@ -3799,6 +3799,134 @@ static void test_template_output_peg_parsers() {
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
// Solar-Open-100B
|
||||
auto tmpls = read_templates("models/templates/upstage-Solar-Open-100B.jinja");
|
||||
|
||||
// Test basic message
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|content|>Hello, world!\nWhat's up?";
|
||||
t.expect = message_assist;
|
||||
});
|
||||
|
||||
// Test basic message and reasoning
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|think|>I'm\nthinking<|end|><|begin|>assistant<|content|>Hello, world!\nWhat's up?";
|
||||
t.expect = message_assist_thoughts;
|
||||
});
|
||||
|
||||
// Test basic message and reasoning_effort = low
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|content|>Hello, world!\nWhat's up?";
|
||||
t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
|
||||
t.expect = message_assist;
|
||||
});
|
||||
|
||||
// Test tool call
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|tool_calls|>"
|
||||
"<|tool_call:begin|>123456789"
|
||||
"<|tool_call:name|>special_function"
|
||||
"<|tool_call:args|>{\"arg1\":1}"
|
||||
"<|tool_call:end|>";
|
||||
|
||||
t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
|
||||
t.params.tools = {special_function_tool};
|
||||
t.expect = message_assist_call_id;
|
||||
});
|
||||
|
||||
// Test tool call with reasoning
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|think|>I'm\nthinking<|end|>"
|
||||
"<|begin|>assistant<|tool_calls|>"
|
||||
"<|tool_call:begin|>0"
|
||||
"<|tool_call:name|>special_function"
|
||||
"<|tool_call:args|>{\"arg1\":1}"
|
||||
"<|tool_call:end|>";
|
||||
|
||||
t.params.tools = {special_function_tool};
|
||||
t.expect = message_assist_thoughts_call_idx;
|
||||
});
|
||||
|
||||
// Test tool call with reasoning and tool_choice = required
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|think|>I'm\nthinking<|end|>"
|
||||
"<|begin|>assistant<|tool_calls|>"
|
||||
"<|tool_call:begin|>0"
|
||||
"<|tool_call:name|>special_function"
|
||||
"<|tool_call:args|>{\"arg1\":1}"
|
||||
"<|tool_call:end|>";
|
||||
|
||||
t.params.tools = {special_function_tool};
|
||||
t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
t.expect = message_assist_thoughts_call_idx;
|
||||
});
|
||||
|
||||
// Test tool call without reasoning and tool_choice = required
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|tool_calls|>"
|
||||
"<|tool_call:begin|>0"
|
||||
"<|tool_call:name|>special_function"
|
||||
"<|tool_call:args|>{\"arg1\":1}"
|
||||
"<|tool_call:end|>";
|
||||
|
||||
t.params.tools = {special_function_tool};
|
||||
t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
|
||||
t.expect = message_assist_call_idx;
|
||||
});
|
||||
|
||||
// Test parallel tool calls
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|think|>I'm\nthinking<|end|>"
|
||||
"<|begin|>assistant<|tool_calls|>"
|
||||
"<|tool_call:begin|>0"
|
||||
"<|tool_call:name|>special_function"
|
||||
"<|tool_call:args|>{\"arg1\":1}"
|
||||
"<|tool_call:end|>"
|
||||
"<|tool_call:begin|>1"
|
||||
"<|tool_call:name|>special_function_with_opt"
|
||||
"<|tool_call:args|>{\"arg1\": 1, \"arg2\": 2}"
|
||||
"<|tool_call:end|>";
|
||||
|
||||
t.params.parallel_tool_calls = true;
|
||||
t.params.tools = {special_function_tool, special_function_tool_with_optional_param};
|
||||
|
||||
t.expect.reasoning_content = "I'm\nthinking";
|
||||
t.expect.tool_calls = {{
|
||||
/* .name = */ "special_function",
|
||||
/* .arguments = */ R"({"arg1": 1})",
|
||||
/* .id = */ "0",
|
||||
}, {
|
||||
/* .name = */ "special_function_with_opt",
|
||||
/* .arguments = */ R"({"arg1": 1, "arg2": 2})",
|
||||
/* .id = */ "1",
|
||||
}};
|
||||
});
|
||||
|
||||
// Test response format
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|think|>I need to output the invoice details in JSON<|end|>"
|
||||
"<|begin|>assistant<|content|>"
|
||||
R"({"amount": 123.45, "date": "2025-12-03"})";
|
||||
|
||||
t.params.json_schema = invoice_schema;
|
||||
|
||||
t.expect.reasoning_content = "I need to output the invoice details in JSON";
|
||||
t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})";
|
||||
});
|
||||
|
||||
// Test response format no reasoning
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "<|content|>"
|
||||
R"({"amount": 123.45, "date": "2025-12-03"})";
|
||||
|
||||
t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
|
||||
t.params.json_schema = invoice_schema;
|
||||
|
||||
t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})";
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static void test_msg_diffs_compute() {
|
||||
|
||||
Binary file not shown.
@@ -48,11 +48,8 @@ enum server_state {
|
||||
struct server_slot {
|
||||
int id;
|
||||
|
||||
llama_batch batch_spec = {};
|
||||
|
||||
// TODO: change to unique_ptrs for consistency:
|
||||
llama_context * ctx = nullptr;
|
||||
llama_context * ctx_dft = nullptr;
|
||||
|
||||
// multimodal
|
||||
mtmd_context * mctx = nullptr;
|
||||
@@ -259,7 +256,7 @@ struct server_slot {
|
||||
}
|
||||
|
||||
bool can_speculate() const {
|
||||
return ctx_dft;
|
||||
return !!spec;
|
||||
}
|
||||
|
||||
void add_token(const completion_token_output & token) {
|
||||
@@ -295,6 +292,7 @@ struct server_slot {
|
||||
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
|
||||
n_draft_max = 0;
|
||||
}
|
||||
|
||||
return n_draft_max;
|
||||
}
|
||||
|
||||
@@ -397,6 +395,8 @@ struct server_slot {
|
||||
draft_ratio, n_draft_accepted, n_draft_total
|
||||
);
|
||||
}
|
||||
|
||||
common_speculative_print_stats(spec);
|
||||
}
|
||||
|
||||
json to_json(bool only_metrics = false) const {
|
||||
@@ -553,18 +553,13 @@ private:
|
||||
|
||||
// note: keep these alive - they determine the lifetime of the model, context, etc.
|
||||
common_init_result_ptr llama_init;
|
||||
common_init_result_ptr llama_init_dft;
|
||||
|
||||
llama_context * ctx = nullptr;
|
||||
|
||||
bool vocab_dft_compatible = true;
|
||||
|
||||
llama_model * model_dft = nullptr;
|
||||
|
||||
llama_context_params cparams_dft;
|
||||
|
||||
llama_batch batch {};
|
||||
|
||||
llama_model_ptr model_dft;
|
||||
|
||||
bool add_bos_token = true;
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
@@ -597,13 +592,8 @@ private:
|
||||
|
||||
// Clear any sampling context
|
||||
for (server_slot & slot : slots) {
|
||||
llama_free(slot.ctx_dft);
|
||||
slot.ctx_dft = nullptr;
|
||||
|
||||
common_speculative_free(slot.spec);
|
||||
slot.spec = nullptr;
|
||||
|
||||
llama_batch_free(slot.batch_spec);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
@@ -648,44 +638,39 @@ private:
|
||||
|
||||
add_bos_token = llama_vocab_get_add_bos(vocab);
|
||||
|
||||
if (params_base.has_speculative()) {
|
||||
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
|
||||
if (params_base.speculative.has_dft()) {
|
||||
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
|
||||
|
||||
const auto & params_spec = params_base.speculative;
|
||||
|
||||
auto params_dft = params_base;
|
||||
|
||||
params_dft.devices = params_base.speculative.devices;
|
||||
params_dft.model = params_base.speculative.model;
|
||||
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
|
||||
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
|
||||
params_dft.n_parallel = 1;
|
||||
params_dft.cache_type_k = params_base.speculative.cache_type_k;
|
||||
params_dft.cache_type_v = params_base.speculative.cache_type_v;
|
||||
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
|
||||
params_dft.n_batch = llama_n_ctx_seq(ctx);
|
||||
params_dft.devices = params_spec.devices;
|
||||
params_dft.model = params_spec.mparams_dft;
|
||||
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
|
||||
params_dft.cache_type_k = params_spec.cache_type_k;
|
||||
params_dft.cache_type_v = params_spec.cache_type_v;
|
||||
|
||||
params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads;
|
||||
params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
|
||||
params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
|
||||
if (params_spec.cpuparams.n_threads > 0) {
|
||||
params_dft.cpuparams.n_threads = params_spec.cpuparams.n_threads;
|
||||
params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads;
|
||||
}
|
||||
|
||||
llama_init_dft = common_init_from_params(params_dft);
|
||||
params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
|
||||
|
||||
model_dft = llama_init_dft->model();
|
||||
auto mparams_dft = common_model_params_to_llama(params_dft);
|
||||
|
||||
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
|
||||
if (model_dft == nullptr) {
|
||||
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
|
||||
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
|
||||
if (!vocab_dft_compatible) {
|
||||
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
|
||||
}
|
||||
|
||||
const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
|
||||
|
||||
cparams_dft = common_context_params_to_llama(params_dft);
|
||||
cparams_dft.n_batch = n_ctx_dft;
|
||||
|
||||
// the context is not needed - we will create one for each slot
|
||||
llama_init_dft->free_context();
|
||||
params_base.speculative.model_dft = model_dft.get();
|
||||
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
|
||||
}
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
@@ -695,6 +680,7 @@ private:
|
||||
}
|
||||
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
@@ -702,6 +688,7 @@ private:
|
||||
mparams.warmup = params_base.warmup;
|
||||
mparams.image_min_tokens = params_base.image_min_tokens;
|
||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||
|
||||
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
|
||||
if (mctx == nullptr) {
|
||||
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
|
||||
@@ -718,11 +705,6 @@ private:
|
||||
params_base.n_cache_reuse = 0;
|
||||
SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
|
||||
}
|
||||
|
||||
if (params_base.has_speculative()) {
|
||||
SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!llama_memory_can_shift(llama_get_memory(ctx))) {
|
||||
@@ -757,29 +739,24 @@ private:
|
||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||
server_slot slot;
|
||||
|
||||
slot.id = i;
|
||||
slot.ctx = ctx;
|
||||
slot.id = i;
|
||||
slot.ctx = ctx;
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
slot.mctx = mctx;
|
||||
|
||||
slot.mctx = mctx;
|
||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||
|
||||
if (model_dft) {
|
||||
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
|
||||
|
||||
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
|
||||
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
|
||||
if (slot.ctx_dft == nullptr) {
|
||||
SRV_ERR("%s", "failed to create draft context\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
|
||||
if (slot.spec == nullptr) {
|
||||
SRV_ERR("%s", "failed to create speculator\n");
|
||||
return false;
|
||||
}
|
||||
for (auto & pair : params_base.speculative.replacements) {
|
||||
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
|
||||
// try speculative decoding
|
||||
{
|
||||
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
|
||||
if (slot.spec) {
|
||||
if (mctx) {
|
||||
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
||||
return false;
|
||||
}
|
||||
SRV_WRN("%s", "speculative decoding context initialized\n");
|
||||
} else {
|
||||
SRV_WRN("%s", "speculative decoding context not initialized\n");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1059,7 +1036,7 @@ private:
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) {
|
||||
std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) const {
|
||||
std::vector<common_adapter_lora_info> output = params_base.lora_adapters; // copy
|
||||
for (size_t i = 0; i < output.size(); ++i) {
|
||||
auto it = config.find(i);
|
||||
@@ -1162,7 +1139,7 @@ private:
|
||||
backend_sampling &= task.params.sampling.backend_sampling;
|
||||
|
||||
// TODO: speculative decoding requires multiple samples per batch - not supported yet
|
||||
backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0);
|
||||
backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);
|
||||
|
||||
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
|
||||
backend_sampling &= !need_logits;
|
||||
@@ -1179,14 +1156,6 @@ private:
|
||||
slot.smpl.reset();
|
||||
}
|
||||
|
||||
// initialize draft batch
|
||||
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
|
||||
if (slot.ctx_dft) {
|
||||
llama_batch_free(slot.batch_spec);
|
||||
|
||||
slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
|
||||
}
|
||||
|
||||
slot.task = std::make_unique<const server_task>(std::move(task));
|
||||
|
||||
slot.state = slot.task->is_child()
|
||||
@@ -2059,19 +2028,23 @@ private:
|
||||
// generate draft tokens in speculative decoding mode
|
||||
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
||||
// perform the speculative drafting for all sequences at the same time in a single batch
|
||||
int n_draft_max = slot.get_n_draft_max();
|
||||
const int n_draft_max = slot.get_n_draft_max();
|
||||
if (n_draft_max > 0) {
|
||||
if (mctx) {
|
||||
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft_max;
|
||||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
|
||||
params_spec.p_min = slot.task->params.speculative.p_min;
|
||||
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
|
||||
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
|
||||
|
||||
const auto & params_spec = slot.task->params.speculative;
|
||||
|
||||
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
|
||||
|
||||
if (draft.size() > (size_t) n_draft_max) {
|
||||
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
|
||||
draft.resize(n_draft_max);
|
||||
}
|
||||
|
||||
// add the sampled token to the batch
|
||||
slot.i_batch_dft.push_back(batch.n_tokens);
|
||||
@@ -2742,6 +2715,10 @@ private:
|
||||
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
|
||||
if (slot.can_speculate()) {
|
||||
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
|
||||
}
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
@@ -2813,6 +2790,9 @@ private:
|
||||
// update how many tokens out of those tested were accepted
|
||||
slot.n_draft_accepted += ids.size() - 1;
|
||||
|
||||
// inform the speculative decoding about the number of accepted tokens
|
||||
common_speculative_accept(slot.spec, ids.size() - 1);
|
||||
|
||||
// rollback to the state before sampling the draft tokens
|
||||
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "llama.h"
|
||||
#include "chat.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
@@ -76,6 +77,11 @@ json task_params::to_json(bool only_metrics) const {
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
{"speculative.p_min", speculative.p_min},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.ngram_size_n", speculative.ngram_size_n},
|
||||
{"speculative.ngram_size_m", speculative.ngram_size_m},
|
||||
{"speculative.ngram_c_rate", speculative.ngram_check_rate},
|
||||
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
@@ -135,6 +141,11 @@ json task_params::to_json(bool only_metrics) const {
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
{"speculative.p_min", speculative.p_min},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.ngram_size_n", speculative.ngram_size_n},
|
||||
{"speculative.ngram_size_m", speculative.ngram_size_m},
|
||||
{"speculative.ngram_c_rate", speculative.ngram_check_rate},
|
||||
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
@@ -242,6 +253,18 @@ task_params server_task::params_from_json_cmpl(
|
||||
params.speculative.n_min = std::max(params.speculative.n_min, 0);
|
||||
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
||||
|
||||
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
|
||||
|
||||
params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
|
||||
params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
|
||||
params.speculative.ngram_check_rate = json_value(data, "speculative.ngram_c_rate", defaults.speculative.ngram_check_rate);
|
||||
params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);
|
||||
|
||||
params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024);
|
||||
params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024);
|
||||
params.speculative.ngram_check_rate = std::max(std::min(1, (int) params.speculative.ngram_check_rate), 1024);
|
||||
params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024);
|
||||
|
||||
// Use OpenAI API logprobs only if n_probs wasn't provided
|
||||
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
|
||||
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
|
||||
|
||||
Generated
+42
-17
@@ -61,7 +61,7 @@
|
||||
"remark-math": "^6.0.0",
|
||||
"sass": "^1.93.3",
|
||||
"storybook": "^10.0.7",
|
||||
"svelte": "^5.0.0",
|
||||
"svelte": "^5.38.2",
|
||||
"svelte-check": "^4.0.0",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"tailwind-variants": "^3.2.2",
|
||||
@@ -88,6 +88,7 @@
|
||||
"version": "2.3.0",
|
||||
"resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz",
|
||||
"integrity": "sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@jridgewell/gen-mapping": "^0.3.5",
|
||||
@@ -867,6 +868,7 @@
|
||||
"integrity": "sha512-oJrXtQiAXLvT9clCf1K4kxp3eKsQhIaZqxEyowkBcsvZDdZkbWrVmnGknxs5flTD0VGsxrxKgBCZty1EzoiMzA==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@swc/helpers": "^0.5.0"
|
||||
}
|
||||
@@ -898,7 +900,6 @@
|
||||
"version": "2.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/remapping/-/remapping-2.3.5.tgz",
|
||||
"integrity": "sha512-LI9u/+laYG4Ds1TDKSJW2YPrIlcVYOwi2fUC6xB43lueCjgxV4lffOCZCtYFiH6TNOX+tQKXx97T4IKHbhyHEQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@jridgewell/gen-mapping": "^0.3.5",
|
||||
@@ -2031,6 +2032,7 @@
|
||||
"integrity": "sha512-rO+YQhHucy47Vh67z318pALmd6x+K1Kj30Fb4a6oOEw4xn4zCo9KTmkMWs24c4oduEXD/eJu3badlRmsVXzyfA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"ts-dedent": "^2.0.0",
|
||||
"type-fest": "~2.19"
|
||||
@@ -2114,6 +2116,7 @@
|
||||
"integrity": "sha512-Vp3zX/qlwerQmHMP6x0Ry1oY7eKKRcOWGc2P59srOp4zcqyn+etJyQpELgOi4+ZSUgteX8Y387NuwruLgGXLUQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -2153,6 +2156,7 @@
|
||||
"integrity": "sha512-YZs/OSKOQAQCnJvM/P+F1URotNnYNeU3P2s4oIpzm1uFaqUEqRxUB0g5ejMjEb5Gjb9/PiBI5Ktrq4rUUF8UVQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^5.0.0",
|
||||
"debug": "^4.4.1",
|
||||
@@ -2568,6 +2572,7 @@
|
||||
"integrity": "sha512-pemlzrSESWbdAloYml3bAJMEfNh1Z7EduzqPKprCH5S341frlpYnUEW0H72dLxa6IsYr+mPno20GiSm+h9dEdQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/code-frame": "^7.10.4",
|
||||
"@babel/runtime": "^7.12.5",
|
||||
@@ -2735,6 +2740,7 @@
|
||||
"integrity": "sha512-bJFoMATwIGaxxx8VJPeM8TonI8t579oRvgAuT8zFugJsJZgzqv0Fu8Mhp68iecjzG7cnN3mO2dJQ5uUM2EFrgQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -2802,6 +2808,7 @@
|
||||
"integrity": "sha512-kVIaQE9vrN9RLCQMQ3iyRlVJpTiDUY6woHGb30JDkfJErqrQEmtdWH3gV0PBAfGZgQXoqzXOO0T3K6ioApbbAA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@typescript-eslint/scope-manager": "8.37.0",
|
||||
"@typescript-eslint/types": "8.37.0",
|
||||
@@ -3026,6 +3033,7 @@
|
||||
"integrity": "sha512-tJxiPrWmzH8a+w9nLKlQMzAKX/7VjFs50MWgcAj7p9XQ7AQ9/35fByFYptgPELyLw+0aixTnC4pUWV+APcZ/kw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@testing-library/dom": "^10.4.0",
|
||||
"@testing-library/user-event": "^14.6.1",
|
||||
@@ -3129,6 +3137,7 @@
|
||||
"integrity": "sha512-oukfKT9Mk41LreEW09vt45f8wx7DordoWUZMYdY/cyAk7w5TWkTRCNZYF7sX7n2wB7jyGAl74OxgwhPgKaqDMQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@vitest/utils": "3.2.4",
|
||||
"pathe": "^2.0.3",
|
||||
@@ -3186,6 +3195,7 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -3738,8 +3748,7 @@
|
||||
"resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz",
|
||||
"integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/debug": {
|
||||
"version": "4.4.1",
|
||||
@@ -3840,10 +3849,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/devalue": {
|
||||
"version": "5.3.2",
|
||||
"resolved": "https://registry.npmjs.org/devalue/-/devalue-5.3.2.tgz",
|
||||
"integrity": "sha512-UDsjUbpQn9kvm68slnrs+mfxwFkIflOhkanmyabZ8zOYk8SMEIbJ3TK+88g70hSIeytu4y18f0z/hYHMTrXIWw==",
|
||||
"dev": true,
|
||||
"version": "5.6.2",
|
||||
"resolved": "https://registry.npmjs.org/devalue/-/devalue-5.6.2.tgz",
|
||||
"integrity": "sha512-nPRkjWzzDQlsejL1WVifk5rvcFi/y1onBRxjaFMjZeR9mFpqu2gmAZ9xUB9/IEanEP/vBtGeGganC/GO1fmufg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/devlop": {
|
||||
@@ -3973,6 +3981,7 @@
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"esbuild": "bin/esbuild"
|
||||
},
|
||||
@@ -4027,6 +4036,7 @@
|
||||
"integrity": "sha512-QldCVh/ztyKJJZLr4jXNUByx3gR+TDYZCRXEktiZoUR3PGy4qCmSbkxcIle8GEwGpb5JBZazlaJ/CxLidXdEbQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@eslint-community/eslint-utils": "^4.2.0",
|
||||
"@eslint-community/regexpp": "^4.12.1",
|
||||
@@ -6939,6 +6949,7 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"nanoid": "^3.3.11",
|
||||
"picocolors": "^1.1.1",
|
||||
@@ -7072,6 +7083,7 @@
|
||||
"integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -7088,6 +7100,7 @@
|
||||
"integrity": "sha512-pn1ra/0mPObzqoIQn/vUTR3ZZI6UuZ0sHqMK5x2jMLGrs53h0sXhkVuDcrlssHwIMk7FYrMjHBPoUSyyEEDlBQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"peerDependencies": {
|
||||
"prettier": "^3.0.0",
|
||||
"svelte": "^3.2.0 || ^4.0.0-next.0 || ^5.0.0-next.0"
|
||||
@@ -7312,6 +7325,7 @@
|
||||
"integrity": "sha512-FS+XFBNvn3GTAWq26joslQgWNoFu08F4kl0J4CgdNKADkdSGXQyTCnKteIAJy96Br6YbpEU1LSzV5dYtjMkMDg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
@@ -7322,6 +7336,7 @@
|
||||
"integrity": "sha512-Xs1hdnE+DyKgeHJeJznQmYMIBG3TKIHJJT95Q58nHLSrElKlGQqDTR2HQ9fx5CN/Gk6Vh/kupBTDLU11/nDk/g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"scheduler": "^0.26.0"
|
||||
},
|
||||
@@ -7598,6 +7613,7 @@
|
||||
"integrity": "sha512-4iya7Jb76fVpQyLoiVpzUrsjQ12r3dM7fIVz+4NwoYvZOShknRmiv+iu9CClZml5ZLGb0XMcYLutK6w9tgxHDw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@types/estree": "1.0.8"
|
||||
},
|
||||
@@ -7704,6 +7720,7 @@
|
||||
"integrity": "sha512-elOcIZRTM76dvxNAjqYrucTSI0teAF/L2Lv0s6f6b7FOwcwIuA357bIE871580AjHJuSvLIRUosgV+lIWx6Rgg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"chokidar": "^4.0.0",
|
||||
"immutable": "^5.0.2",
|
||||
@@ -7938,6 +7955,7 @@
|
||||
"integrity": "sha512-7smAu0o+kdm378Q2uIddk32pn0UdIbrtTVU+rXRVtTVTCrK/P2cCui2y4JH+Bl3NgEq1bbBQpCAF/HKrDjk2Qw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@storybook/global": "^5.0.0",
|
||||
"@storybook/icons": "^1.6.0",
|
||||
@@ -8079,12 +8097,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/svelte": {
|
||||
"version": "5.36.12",
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.36.12.tgz",
|
||||
"integrity": "sha512-c3mWT+b0yBLl3gPGSHiy4pdSQCsPNTjLC0tVoOhrGJ6PPfCzD/RQpAmAfJtQZ304CAae2ph+L3C4aqds3R3seQ==",
|
||||
"version": "5.48.3",
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.48.3.tgz",
|
||||
"integrity": "sha512-w7QZ398cdNherTdiQ/v3SYLLGOO4948Jgjh04PYqtTYVohmBvbmFwLmo7pp8gp4/1tceRWfSTjHgjtfpCVNJmQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@ampproject/remapping": "^2.3.0",
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
"@types/estree": "^1.0.5",
|
||||
@@ -8092,8 +8111,9 @@
|
||||
"aria-query": "^5.3.1",
|
||||
"axobject-query": "^4.1.0",
|
||||
"clsx": "^2.1.1",
|
||||
"devalue": "^5.6.2",
|
||||
"esm-env": "^1.2.1",
|
||||
"esrap": "^2.1.0",
|
||||
"esrap": "^2.2.1",
|
||||
"is-reference": "^3.0.3",
|
||||
"locate-character": "^3.0.0",
|
||||
"magic-string": "^0.30.11",
|
||||
@@ -8281,9 +8301,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/svelte/node_modules/esrap": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/esrap/-/esrap-2.1.0.tgz",
|
||||
"integrity": "sha512-yzmPNpl7TBbMRC5Lj2JlJZNPml0tzqoqP5B1JXycNUwtqma9AKCO0M2wHrdgsHcy1WRW7S9rJknAMtByg3usgA==",
|
||||
"version": "2.2.2",
|
||||
"resolved": "https://registry.npmjs.org/esrap/-/esrap-2.2.2.tgz",
|
||||
"integrity": "sha512-zA6497ha+qKvoWIK+WM9NAh5ni17sKZKhbS5B3PoYbBvaYHZWoS33zmFybmyqpn07RLUxSmn+RCls2/XF+d0oQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@jridgewell/sourcemap-codec": "^1.4.15"
|
||||
@@ -8326,6 +8346,7 @@
|
||||
"integrity": "sha512-gBXpgUm/3rp1lMZZrM/w7D8GKqshif0zAymAhbCyIt8KMe+0v9DQ7cdYLR4FHH/cKpdTXb+A/tKKU3eolfsI+g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/dcastil"
|
||||
@@ -8356,7 +8377,8 @@
|
||||
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.11.tgz",
|
||||
"integrity": "sha512-2E9TBm6MDD/xKYe+dvJZAmg3yxIEDNRc0jwlNyDg/4Fil2QcSLjFKGVff0lAf1jjeaArlG/M75Ey/EYr/OJtBA==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/tapable": {
|
||||
"version": "2.2.2",
|
||||
@@ -8569,6 +8591,7 @@
|
||||
"integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -8934,6 +8957,7 @@
|
||||
"integrity": "sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.5.0",
|
||||
@@ -9094,6 +9118,7 @@
|
||||
"integrity": "sha512-LUCP5ev3GURDysTWiP47wRRUpLKMOfPh+yKTx3kVIEiu5KOMeqzpnYNsKyOoVrULivR8tLcks4+lga33Whn90A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@types/chai": "^5.2.2",
|
||||
"@vitest/expect": "3.2.4",
|
||||
|
||||
@@ -62,7 +62,7 @@
|
||||
"remark-math": "^6.0.0",
|
||||
"sass": "^1.93.3",
|
||||
"storybook": "^10.0.7",
|
||||
"svelte": "^5.0.0",
|
||||
"svelte": "^5.38.2",
|
||||
"svelte-check": "^4.0.0",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"tailwind-variants": "^3.2.2",
|
||||
|
||||
Reference in New Issue
Block a user