mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 10:07:44 +02:00
Compare commits
54 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e3af5563bd | |||
| 10fcc41290 | |||
| bcf5bda6f5 | |||
| 3eb2be1ca5 | |||
| e41bcce8f0 | |||
| 144a4ce824 | |||
| f549b0007d | |||
| 9a3ea685b9 | |||
| 338074c383 | |||
| 851553ea6b | |||
| 85a7d8677b | |||
| a8ca18b4b8 | |||
| 8284efc35c | |||
| 1c1409e131 | |||
| 7a0e900e36 | |||
| 280d97be96 | |||
| 3479efd112 | |||
| 463bbf20bf | |||
| ad8d36beff | |||
| c053e18a66 | |||
| e1ab084803 | |||
| 5a4ff43e7d | |||
| 10640e31aa | |||
| 80d28f104c | |||
| c55d53acec | |||
| 945501f5ea | |||
| 75cbdd3fce | |||
| 2b9bd9bf4e | |||
| 59fc1ec8e8 | |||
| 75d33b9302 | |||
| 3470a5c891 | |||
| bd562fe4f7 | |||
| bbac6a26b2 | |||
| 73a48c9790 | |||
| f696428ce8 | |||
| 7cce4f8158 | |||
| 8d8862829c | |||
| f77c13b91f | |||
| 3cfa9c3f12 | |||
| 5d195f17bc | |||
| 226f295f4d | |||
| f90b4a8efe | |||
| 8423d01931 | |||
| 5cca2542ac | |||
| 55945d2ef5 | |||
| 0bcb40b48c | |||
| 69e9ff0103 | |||
| 5a91109a5d | |||
| f8f071fadd | |||
| 0bf47a1dbb | |||
| dd62dcfab9 | |||
| d0660f237a | |||
| fe6a9882ac | |||
| 061f0eff02 |
+1
-1
@@ -65,7 +65,7 @@
|
||||
/ggml/src/ggml-impl.h @ggerganov @slaren
|
||||
/ggml/src/ggml-metal/ @ggerganov
|
||||
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
|
||||
/ggml/src/ggml-hexagon/ @max-krasnyansky
|
||||
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
|
||||
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
||||
/ggml/src/ggml-quants.* @ggerganov
|
||||
/ggml/src/ggml-rpc/ @rgerganov
|
||||
|
||||
@@ -84,6 +84,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
||||
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
|
||||
- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
|
||||
- [x] [Jamba](https://huggingface.co/ai21labs)
|
||||
- [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon)
|
||||
- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
|
||||
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
|
||||
|
||||
+2
-2
@@ -3248,7 +3248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
|
||||
add_opt(common_arg(
|
||||
{"--embd-output-format"}, "FORMAT",
|
||||
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix",
|
||||
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.embd_out = value;
|
||||
}
|
||||
@@ -3435,7 +3435,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
|
||||
|
||||
+198
@@ -9,8 +9,11 @@
|
||||
#include <minja/chat-template.hpp>
|
||||
#include <minja/minja.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cctype>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
@@ -640,6 +643,7 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
|
||||
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
||||
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
@@ -986,6 +990,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||
return data;
|
||||
}
|
||||
|
||||
|
||||
// Case-insensitive find
|
||||
static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) {
|
||||
auto it = std::search(
|
||||
haystack.begin() + pos, haystack.end(),
|
||||
needle.begin(), needle.end(),
|
||||
[](char a, char b) { return std::tolower(a) == std::tolower(b); }
|
||||
);
|
||||
return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
const auto is_json_schema_provided = !inputs.json_schema.is_null();
|
||||
const auto is_grammar_provided = !inputs.grammar.empty();
|
||||
const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
|
||||
// the logic requires potentially modifying the messages
|
||||
auto tweaked_messages = inputs.messages;
|
||||
|
||||
auto replace_json_schema_marker = [](json & messages) -> bool {
|
||||
static std::string marker1 = "force json schema.\n";
|
||||
static std::string marker2 = "force json schema.";
|
||||
|
||||
if (messages.empty() || messages.at(0).at("role") != "system") {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string content = messages.at(0).at("content");
|
||||
|
||||
for (const auto & marker : {marker1, marker2}) {
|
||||
const auto pos = ifind_string(content, marker);
|
||||
if (pos != std::string::npos) {
|
||||
content.replace(pos, marker.length(), "");
|
||||
// inject modified content back into the messages
|
||||
messages.at(0).at("content") = content;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
// Lfm2 model does not natively work with json, but can generally understand the tools structure
|
||||
//
|
||||
// Example of the pytorch dialog structure:
|
||||
// <|startoftext|><|im_start|>system
|
||||
// List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|>
|
||||
// <|im_start|>user
|
||||
// What is the current status of candidate ID 12345?<|im_end|>
|
||||
// <|im_start|>assistant
|
||||
// <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|>
|
||||
// <|im_start|>tool
|
||||
// <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|>
|
||||
// <|im_start|>assistant
|
||||
// The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|>
|
||||
//
|
||||
// For the llama server compatibility with json tools semantic,
|
||||
// the client can add "Follow json schema." line into the system message prompt to force the json output.
|
||||
//
|
||||
if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) {
|
||||
// server/utils.hpp prohibits that branch for the custom grammar anyways
|
||||
throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar");
|
||||
} else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) {
|
||||
LOG_INF("%s: Using tools to build a grammar\n", __func__);
|
||||
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function.at("name")},
|
||||
}},
|
||||
{"arguments", function.at("parameters")},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments", "id"})},
|
||||
});
|
||||
});
|
||||
auto schema = json {
|
||||
{"type", "array"},
|
||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
};
|
||||
if (!inputs.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
|
||||
builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\"");
|
||||
});
|
||||
// model has no concept of tool selection mode choice,
|
||||
// if the system prompt rendered correctly it will produce a tool call
|
||||
// the grammar goes inside the tool call body
|
||||
data.grammar_lazy = true;
|
||||
data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}};
|
||||
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
|
||||
data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS;
|
||||
} else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) {
|
||||
LOG_INF("%s: Using tools without json schema or grammar\n", __func__);
|
||||
// output those tokens
|
||||
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
|
||||
} else if (is_json_schema_provided) {
|
||||
LOG_INF("%s: Using provided json schema to build a grammar\n", __func__);
|
||||
data.grammar = json_schema_to_grammar(inputs.json_schema);
|
||||
} else if (is_grammar_provided) {
|
||||
LOG_INF("%s: Using provided grammar\n", __func__);
|
||||
data.grammar = inputs.grammar;
|
||||
} else {
|
||||
LOG_INF("%s: Using content relying on the template\n", __func__);
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
|
||||
LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str());
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
@@ -2499,6 +2623,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
|
||||
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
|
||||
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
|
||||
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
|
||||
|
||||
// Loop through all tool calls
|
||||
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
|
||||
// Consume end marker
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_regex(tool_call_end_regex)) {
|
||||
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
|
||||
}
|
||||
|
||||
// Process each tool call in the array
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
for (const auto & tool_call : tool_calls_data.json) {
|
||||
if (!tool_call.is_object()) {
|
||||
throw common_chat_msg_partial_exception("Tool call must be an object");
|
||||
}
|
||||
|
||||
if (!tool_call.contains("name")) {
|
||||
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
|
||||
}
|
||||
|
||||
std::string function_name = tool_call.at("name");
|
||||
std::string arguments = "{}";
|
||||
|
||||
if (tool_call.contains("arguments")) {
|
||||
if (tool_call.at("arguments").is_object()) {
|
||||
arguments = tool_call.at("arguments").dump();
|
||||
} else if (tool_call.at("arguments").is_string()) {
|
||||
arguments = tool_call.at("arguments");
|
||||
}
|
||||
}
|
||||
|
||||
if (!builder.add_tool_call(function_name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
|
||||
}
|
||||
|
||||
// Consume any trailing whitespace after this tool call
|
||||
builder.consume_spaces();
|
||||
}
|
||||
|
||||
// Consume any remaining content after all tool calls
|
||||
auto remaining = builder.consume_rest();
|
||||
if (!string_strip(remaining).empty()) {
|
||||
builder.add_content(remaining);
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags first - this handles the main reasoning content
|
||||
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
|
||||
@@ -2748,6 +2937,12 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_apertus(tmpl, params);
|
||||
}
|
||||
|
||||
// LFM2 (w/ tools)
|
||||
if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos &&
|
||||
src.find("]<|tool_list_end|>") != std::string::npos) {
|
||||
return common_chat_params_init_lfm2(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
@@ -2926,6 +3121,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
case COMMON_CHAT_FORMAT_APERTUS:
|
||||
common_chat_parse_apertus(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
|
||||
common_chat_parse_lfm2(builder);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||
}
|
||||
|
||||
@@ -116,6 +116,7 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_SEED_OSS,
|
||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||
COMMON_CHAT_FORMAT_APERTUS,
|
||||
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
@@ -601,7 +601,10 @@ private:
|
||||
}
|
||||
|
||||
std::string _resolve_ref(const std::string & ref) {
|
||||
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
|
||||
auto it = ref.find('#');
|
||||
std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
|
||||
static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
|
||||
std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
|
||||
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
|
||||
_refs_being_resolved.insert(ref);
|
||||
json resolved = _refs[ref];
|
||||
@@ -774,11 +777,24 @@ public:
|
||||
std::vector<std::string> tokens = string_split(pointer, "/");
|
||||
for (size_t i = 1; i < tokens.size(); ++i) {
|
||||
std::string sel = tokens[i];
|
||||
if (target.is_null() || !target.contains(sel)) {
|
||||
if (target.is_object() && target.contains(sel)) {
|
||||
target = target[sel];
|
||||
} else if (target.is_array()) {
|
||||
size_t sel_index;
|
||||
try {
|
||||
sel_index = std::stoul(sel);
|
||||
} catch (const std::invalid_argument & e) {
|
||||
sel_index = target.size();
|
||||
}
|
||||
if (sel_index >= target.size()) {
|
||||
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
||||
return;
|
||||
}
|
||||
target = target[sel_index];
|
||||
} else {
|
||||
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
||||
return;
|
||||
}
|
||||
target = target[sel];
|
||||
}
|
||||
_refs[ref] = target;
|
||||
}
|
||||
|
||||
+259
-84
@@ -29,12 +29,29 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
||||
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
|
||||
import gguf
|
||||
from gguf.vocab import MistralTokenizerType, MistralVocab
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
|
||||
_mistral_common_installed = True
|
||||
_mistral_import_error_msg = ""
|
||||
except ImportError:
|
||||
_MISTRAL_COMMON_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
_mistral_common_installed = False
|
||||
TokenizerVersion = None
|
||||
Tekkenizer = None
|
||||
SentencePieceTokenizer = None
|
||||
_mistral_import_error_msg = (
|
||||
"Mistral format requires `mistral-common` to be installed. Please run "
|
||||
"`pip install mistral-common[image,audio]` to install it."
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger("hf-to-gguf")
|
||||
@@ -73,10 +90,8 @@ class ModelBase:
|
||||
use_temp_file: bool
|
||||
lazy: bool
|
||||
dry_run: bool
|
||||
part_names: list[str]
|
||||
is_safetensors: bool
|
||||
hparams: dict[str, Any]
|
||||
tensor_names: set[str] | None
|
||||
model_tensors: dict[str, Callable[[], Tensor]]
|
||||
gguf_writer: gguf.GGUFWriter
|
||||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
@@ -107,6 +122,9 @@ class ModelBase:
|
||||
type(self) is MmprojModel:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
if self.is_mistral_format and not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
self.dir_model = dir_model
|
||||
self.ftype = ftype
|
||||
self.fname_out = fname_out
|
||||
@@ -117,25 +135,8 @@ class ModelBase:
|
||||
self.dry_run = dry_run
|
||||
self.remote_hf_model_id = remote_hf_model_id
|
||||
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
|
||||
if remote_hf_model_id is not None:
|
||||
self.is_safetensors = True
|
||||
|
||||
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
|
||||
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
|
||||
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
|
||||
self.tensor_names = set(name for name in remote_tensors.keys())
|
||||
for name, remote_tensor in remote_tensors.items():
|
||||
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
|
||||
|
||||
self.get_tensors = get_remote_tensors
|
||||
else:
|
||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||
self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||
self.is_safetensors = len(self.part_names) > 0
|
||||
if not self.is_safetensors:
|
||||
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
|
||||
self.tensor_names = None
|
||||
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
|
||||
self.metadata_override = metadata_override
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
@@ -151,6 +152,8 @@ class ModelBase:
|
||||
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
|
||||
|
||||
self.dequant_model()
|
||||
|
||||
# Configure GGUF Writer
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
|
||||
@@ -172,67 +175,215 @@ class ModelBase:
|
||||
return None
|
||||
raise KeyError(f"could not find any of: {keys}")
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
tensor_names_from_parts: set[str] = set()
|
||||
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
|
||||
tensors: dict[str, Callable[[], Tensor]] = {}
|
||||
|
||||
if remote_hf_model_id is not None:
|
||||
is_safetensors = True
|
||||
|
||||
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
|
||||
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
|
||||
for name, remote_tensor in remote_tensors.items():
|
||||
tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
|
||||
|
||||
return tensors
|
||||
|
||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||
is_safetensors: bool = len(part_names) > 0
|
||||
if not is_safetensors:
|
||||
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
|
||||
tensor_names_from_index: set[str] = set()
|
||||
|
||||
if not self.is_mistral_format:
|
||||
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
|
||||
index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
|
||||
index_name += ".index.json"
|
||||
index_file = self.dir_model / index_name
|
||||
|
||||
if index_file.is_file():
|
||||
self.tensor_names = set()
|
||||
logger.info(f"gguf: loading model weight map from '{index_name}'")
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index: dict[str, Any] = json.load(f)
|
||||
weight_map = index.get("weight_map")
|
||||
if weight_map is None or not isinstance(weight_map, dict):
|
||||
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
||||
self.tensor_names.update(weight_map.keys())
|
||||
tensor_names_from_index.update(weight_map.keys())
|
||||
else:
|
||||
self.tensor_names = tensor_names_from_parts
|
||||
weight_map = {}
|
||||
else:
|
||||
self.tensor_names = tensor_names_from_parts
|
||||
weight_map = {}
|
||||
|
||||
for part_name in self.part_names:
|
||||
logger.info(f"gguf: loading model part '{part_name}'")
|
||||
for part_name in part_names:
|
||||
logger.info(f"gguf: indexing model part '{part_name}'")
|
||||
ctx: ContextManager[Any]
|
||||
if self.is_safetensors:
|
||||
if is_safetensors:
|
||||
from safetensors import safe_open
|
||||
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
||||
else:
|
||||
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
||||
|
||||
with ctx as model_part:
|
||||
tensor_names_from_parts.update(model_part.keys())
|
||||
assert model_part is not None
|
||||
|
||||
for name in model_part.keys():
|
||||
if self.is_safetensors:
|
||||
if is_safetensors:
|
||||
if self.lazy:
|
||||
data = model_part.get_slice(name)
|
||||
data = LazyTorchTensor.from_safetensors_slice(data)
|
||||
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
|
||||
else:
|
||||
data = model_part.get_tensor(name)
|
||||
data_gen = lambda data=data: data # noqa: E731
|
||||
else:
|
||||
data = model_part[name]
|
||||
if self.lazy:
|
||||
data = LazyTorchTensor.from_eager(data)
|
||||
yield name, data
|
||||
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
|
||||
else:
|
||||
data_gen = lambda data=data: data # noqa: E731
|
||||
tensors[name] = data_gen
|
||||
|
||||
# verify tensor name presence and identify potentially missing files
|
||||
if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
|
||||
missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
|
||||
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
|
||||
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
||||
if len(extra) == 0 and len(missing_files) > 0:
|
||||
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
|
||||
f"Missing tensors: {missing}")
|
||||
if len(tensor_names_from_index) > 0:
|
||||
tensor_names_from_parts = set(tensors.keys())
|
||||
if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
|
||||
missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
|
||||
extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
|
||||
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
||||
if len(extra) == 0 and len(missing_files) > 0:
|
||||
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
|
||||
f"Missing tensors: {missing}")
|
||||
else:
|
||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||
f"Missing tensors: {missing}\n"
|
||||
f"Extra tensors: {extra}")
|
||||
|
||||
return tensors
|
||||
|
||||
def dequant_model(self):
|
||||
tensors_to_remove: list[str] = []
|
||||
new_tensors: dict[str, Callable[[], Tensor]] = {}
|
||||
|
||||
if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
|
||||
quant_method = quant_config.get("quant_method")
|
||||
|
||||
def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
|
||||
weight = weight.view(torch.uint8)
|
||||
orig_shape = weight.shape
|
||||
|
||||
shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
|
||||
data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
|
||||
data = data & 3
|
||||
data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
|
||||
|
||||
# The scale is inverted
|
||||
return data / scale.float()
|
||||
|
||||
def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
|
||||
scale = scale.float()
|
||||
|
||||
if (weight_block_size := quant_config.get("weight_block_size")):
|
||||
# TODO: make sure it's a list of integers
|
||||
for i, size in enumerate(weight_block_size):
|
||||
scale = scale.repeat_interleave(size, i)
|
||||
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
|
||||
scale = scale[tuple(slice(0, size) for size in weight.shape)]
|
||||
|
||||
return weight.float() * scale
|
||||
|
||||
# ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
|
||||
def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
|
||||
bits = quant_config["bits"]
|
||||
assert bits in (2, 3, 4, 8)
|
||||
assert qweight.dtype == qzeros.dtype
|
||||
maxq = (2 ** bits) - 1
|
||||
weight = None
|
||||
zeros = None
|
||||
pack_dtype_bits = qweight.dtype.itemsize * 8
|
||||
|
||||
if bits in [2, 4, 8]:
|
||||
pack_factor = pack_dtype_bits // bits
|
||||
wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
|
||||
if self.lazy:
|
||||
wf = LazyTorchTensor.from_eager(wf)
|
||||
|
||||
zeros = torch.bitwise_right_shift(
|
||||
qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
|
||||
wf.unsqueeze(0)
|
||||
).to(torch.int16 if bits == 8 else torch.int8)
|
||||
zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
|
||||
|
||||
weight = torch.bitwise_and(
|
||||
torch.bitwise_right_shift(
|
||||
qweight.unsqueeze(1).expand(-1, pack_factor, -1),
|
||||
wf.unsqueeze(-1)
|
||||
).to(torch.int16 if bits == 8 else torch.int8),
|
||||
maxq
|
||||
)
|
||||
elif bits == 3:
|
||||
raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
|
||||
|
||||
assert weight is not None
|
||||
assert zeros is not None
|
||||
|
||||
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
|
||||
|
||||
# gptq_v2 doesn't need to offset zeros
|
||||
if quant_config.get("checkpoint_format", "gptq") == "gptq":
|
||||
zeros += 1
|
||||
|
||||
return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
|
||||
|
||||
if quant_method == "bitnet":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale"):
|
||||
weight_name = name.removesuffix("_scale")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "fp8":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale_inv"):
|
||||
weight_name = name.removesuffix("_scale_inv")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "gptq":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".qweight"):
|
||||
base_name = name.removesuffix(".qweight")
|
||||
g_idx = self.model_tensors[base_name + ".g_idx"]
|
||||
qweight = self.model_tensors[base_name + ".qweight"]
|
||||
qzeros = self.model_tensors[base_name + ".qzeros"]
|
||||
scales = self.model_tensors[base_name + ".scales"]
|
||||
new_tensors[base_name + ".weight"] = (
|
||||
lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
|
||||
g(), w(), z(), s()
|
||||
)
|
||||
)
|
||||
tensors_to_remove += [
|
||||
base_name + n
|
||||
for n in (
|
||||
".g_idx",
|
||||
".qzeros",
|
||||
".qweight",
|
||||
".scales",
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||
f"Missing tensors: {missing}\n"
|
||||
f"Extra tensors: {extra}")
|
||||
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
|
||||
|
||||
for name in tensors_to_remove:
|
||||
if name in self.model_tensors:
|
||||
del self.model_tensors[name]
|
||||
|
||||
for name, value in new_tensors.items():
|
||||
self.model_tensors[name] = value
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
for name, gen in self.model_tensors.items():
|
||||
yield name, gen()
|
||||
|
||||
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
||||
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
||||
@@ -591,6 +742,12 @@ class TextModel(ModelBase):
|
||||
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
logger.info(f"gguf: experts used count = {n_experts_used}")
|
||||
if (n_expert_groups := self.hparams.get("n_group")) is not None:
|
||||
self.gguf_writer.add_expert_group_count(n_expert_groups)
|
||||
logger.info(f"gguf: expert groups count = {n_expert_groups}")
|
||||
if (n_group_used := self.hparams.get("topk_group")) is not None:
|
||||
self.gguf_writer.add_expert_group_used_count(n_group_used)
|
||||
logger.info(f"gguf: expert groups used count = {n_group_used}")
|
||||
|
||||
if (head_dim := self.hparams.get("head_dim")) is not None:
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
@@ -1346,6 +1503,17 @@ class MmprojModel(ModelBase):
|
||||
def set_type(self):
|
||||
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
|
||||
|
||||
def prepare_metadata(self, vocab_only: bool):
|
||||
super().prepare_metadata(vocab_only=vocab_only)
|
||||
|
||||
output_type: str = self.ftype.name.partition("_")[2]
|
||||
|
||||
if self.fname_out.is_dir():
|
||||
fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=output_type, model_type=None)
|
||||
self.fname_out = self.fname_out / f"mmproj-{fname_default}.gguf"
|
||||
else:
|
||||
self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
@@ -1363,8 +1531,8 @@ class MmprojModel(ModelBase):
|
||||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
|
||||
|
||||
# preprocessor config
|
||||
image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
|
||||
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
image_std = _MISTRAL_COMMON_DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
|
||||
|
||||
self.gguf_writer.add_vision_image_mean(image_mean)
|
||||
self.gguf_writer.add_vision_image_std(image_std)
|
||||
@@ -2033,6 +2201,9 @@ class LlamaModel(TextModel):
|
||||
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
|
||||
|
||||
def _set_vocab_mistral(self):
|
||||
if not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
vocab = MistralVocab(self.dir_model)
|
||||
logger.info(
|
||||
f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
|
||||
@@ -2289,18 +2460,21 @@ class ArceeModel(LlamaModel):
|
||||
)
|
||||
class LlavaVisionModel(MmprojModel):
|
||||
img_break_tok_id = -1
|
||||
use_break_tok = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.hparams.get("model_type") == "pixtral":
|
||||
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
|
||||
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
|
||||
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
|
||||
if self.use_break_tok:
|
||||
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
|
||||
elif self.is_mistral_format:
|
||||
# hparams is already vision config here so norm_eps is only defined in global_config.
|
||||
self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
|
||||
assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
|
||||
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
|
||||
if self.use_break_tok:
|
||||
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
|
||||
logger.info(f"Image break token id: {self.img_break_tok_id}")
|
||||
@@ -3791,6 +3965,10 @@ class Qwen3Model(Qwen2Model):
|
||||
return torch.stack([true_row, false_row], dim=0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "model.vision_" in name:
|
||||
# skip multimodal tensors
|
||||
return []
|
||||
|
||||
if self.is_rerank:
|
||||
is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
|
||||
is_real_head = not self.is_tied_embeddings and "lm_head" in name
|
||||
@@ -4358,27 +4536,6 @@ class CodeShellModel(TextModel):
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(1.0)
|
||||
|
||||
_has_tok_embd = False
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
|
||||
tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
|
||||
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
# assuming token_embd.weight is seen before output.weight
|
||||
if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
|
||||
# even though the tensor file(s) does not contain the word embeddings they are still in the weight map
|
||||
if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
|
||||
logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
|
||||
self.tensor_names.remove("transformer.wte.weight")
|
||||
elif new_name == tok_embd_name:
|
||||
self._has_tok_embd = True
|
||||
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("InternLM2ForCausalLM")
|
||||
class InternLM2Model(TextModel):
|
||||
@@ -8089,8 +8246,6 @@ class BailingMoeV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_group_count(hparams["n_group"])
|
||||
self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["score_function"] == "sigmoid":
|
||||
@@ -8810,6 +8965,13 @@ class SmolLM3Model(LlamaModel):
|
||||
class GptOssModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GPT_OSS
|
||||
|
||||
# TODO: remove once MXFP4 is supported more generally
|
||||
def dequant_model(self):
|
||||
quant_config = self.hparams.get("quantization_config")
|
||||
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
|
||||
return
|
||||
return super().dequant_model()
|
||||
|
||||
def transform_nibble_layout(self, tensor):
|
||||
assert tensor.dtype == torch.uint8
|
||||
assert tensor.shape[-1] == 16
|
||||
@@ -9212,7 +9374,7 @@ class MistralModel(LlamaModel):
|
||||
|
||||
@staticmethod
|
||||
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
|
||||
assert TokenizerVersion is not None, "mistral_common is not installed"
|
||||
assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
|
||||
assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
|
||||
f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
|
||||
)
|
||||
@@ -9280,6 +9442,21 @@ class PixtralModel(LlavaVisionModel):
|
||||
return super().map_tensor_name(name, try_suffixes)
|
||||
|
||||
|
||||
@ModelBase.register("LightOnOCRForConditionalGeneration")
|
||||
class LightOnOCRVisionModel(LlavaVisionModel):
|
||||
is_mistral_format = False
|
||||
use_break_tok = False
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("model.vision_encoder.", "vision_tower.")
|
||||
name = name.replace("model.vision_projection.", "multi_modal_projector.")
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("KimiVLForConditionalGeneration")
|
||||
class KimiVLModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -9589,11 +9766,9 @@ def main() -> None:
|
||||
|
||||
logger.info(f"Loading model: {dir_model.name}")
|
||||
|
||||
if args.mmproj:
|
||||
if "mmproj" not in fname_out.name:
|
||||
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
|
||||
|
||||
is_mistral_format = args.mistral_format
|
||||
if is_mistral_format and not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
disable_mistral_community_chat_template = args.disable_mistral_community_chat_template
|
||||
|
||||
with torch.inference_mode():
|
||||
|
||||
+6
-4
@@ -261,10 +261,12 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||
- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
|
||||
```bash
|
||||
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
||||
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
&& cmake --build build --config Release -- -j 16
|
||||
```
|
||||
|
||||
Note: `GPU_TARGETS` is optional, omitting it will build the code for all GPUs in the current system.
|
||||
|
||||
To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system.
|
||||
|
||||
The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager.
|
||||
@@ -282,17 +284,17 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||
```bash
|
||||
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \
|
||||
HIP_DEVICE_LIB_PATH=<directory-you-just-found> \
|
||||
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
&& cmake --build build -- -j 16
|
||||
```
|
||||
|
||||
- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
|
||||
```bash
|
||||
set PATH=%HIP_PATH%\bin;%PATH%
|
||||
cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release
|
||||
cmake -S . -B build -G Ninja -DGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
|
||||
If necessary, adapt `GPU_TARGETS` to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
|
||||
Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.
|
||||
|
||||
|
||||
|
||||
+1
-1
@@ -79,7 +79,7 @@ Legend:
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
|
||||
+4
-4
@@ -5637,25 +5637,25 @@
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=0.000000,broadcast=0,multi_add=0","support","1","yes","SYCL"
|
||||
|
||||
|
Can't render this file because it is too large.
|
@@ -38,6 +38,7 @@ The above command will output space-separated float values.
|
||||
| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
|
||||
| 'json' | openai style |
|
||||
| 'json+' | add cosine similarity matrix |
|
||||
| 'raw' | plain text output |
|
||||
|
||||
### --embd-separator $"string"$
|
||||
| $"string"$ | |
|
||||
|
||||
@@ -70,6 +70,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||
}
|
||||
}
|
||||
|
||||
// plain, pipe-friendly output: one embedding per line
|
||||
static void print_raw_embeddings(const float * emb,
|
||||
int n_embd_count,
|
||||
int n_embd,
|
||||
const llama_model * model,
|
||||
enum llama_pooling_type pooling_type,
|
||||
int embd_normalize) {
|
||||
const uint32_t n_cls_out = llama_model_n_cls_out(model);
|
||||
const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
|
||||
const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;
|
||||
|
||||
for (int j = 0; j < n_embd_count; ++j) {
|
||||
for (int i = 0; i < cols; ++i) {
|
||||
if (embd_normalize == 0) {
|
||||
LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
|
||||
} else {
|
||||
LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
|
||||
}
|
||||
}
|
||||
LOG("\n");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
@@ -372,6 +395,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (notArray) LOG("\n}\n");
|
||||
} else if (params.embd_out == "raw") {
|
||||
print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
|
||||
}
|
||||
|
||||
LOG("\n");
|
||||
|
||||
@@ -371,8 +371,17 @@ class SchemaConverter:
|
||||
raise ValueError(f'Unsupported ref {ref}')
|
||||
|
||||
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
if isinstance(target, list):
|
||||
try:
|
||||
sel_index = int(sel)
|
||||
except ValueError:
|
||||
raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
|
||||
assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel_index]
|
||||
else:
|
||||
assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
|
||||
self._refs[ref] = target
|
||||
else:
|
||||
@@ -547,7 +556,8 @@ class SchemaConverter:
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
ref_name = ref.split('/')[-1]
|
||||
ref_fragment = ref.split('#')[-1]
|
||||
ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
|
||||
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||
self._refs_being_resolved.add(ref)
|
||||
resolved = self._refs[ref]
|
||||
|
||||
@@ -138,7 +138,7 @@ if model_path is None:
|
||||
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
print("Model type: ", config.model_type)
|
||||
print("Vocab size: ", config.vocab_size)
|
||||
@@ -148,8 +148,8 @@ print("BOS token id: ", config.bos_token_id)
|
||||
print("EOS token id: ", config.eos_token_id)
|
||||
|
||||
print("Loading model and tokenizer using AutoTokenizer:", model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
@@ -171,7 +171,7 @@ if unreleased_model_name:
|
||||
exit(1)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", offload_folder="offload"
|
||||
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
|
||||
)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
|
||||
+11
-4
@@ -226,16 +226,23 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
|
||||
}
|
||||
|
||||
if (best_fit_block == -1) {
|
||||
// no suitable block found, try the last block (this will grow a chunks size)
|
||||
// no suitable block found, try the last block (this may grow a chunks size)
|
||||
int64_t best_reuse = INT64_MIN;
|
||||
for (int c = 0; c < alloc->n_chunks; ++c) {
|
||||
struct tallocr_chunk * chunk = alloc->chunks[c];
|
||||
if (chunk->n_free_blocks > 0) {
|
||||
struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
|
||||
max_avail = MAX(max_avail, block->size);
|
||||
if (block->size >= size) {
|
||||
int64_t reuse_factor = chunk->max_size - block->offset - size;
|
||||
// reuse_factor < 0 : amount of extra memory that needs to be allocated
|
||||
// reuse_factor = 0 : allocated free space exactly matches tensor size
|
||||
// reuse_factor > 0 : superfluous memory that will remain unused
|
||||
bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
|
||||
bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;
|
||||
if (block->size >= size && (better_reuse || better_fit)) {
|
||||
best_fit_chunk = c;
|
||||
best_fit_block = chunk->n_free_blocks - 1;
|
||||
break;
|
||||
best_reuse = reuse_factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -268,7 +275,7 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
|
||||
#ifdef GGML_ALLOCATOR_DEBUG
|
||||
add_allocated_tensor(alloc, addr, tensor);
|
||||
size_t cur_max = addr.offset + size;
|
||||
if (cur_max > alloc->max_size[addr.chunk]) {
|
||||
if (cur_max > chunk->max_size) {
|
||||
// sort allocated_tensors by chunk/offset
|
||||
for (int i = 0; i < 1024; i++) {
|
||||
for (int j = i + 1; j < 1024; j++) {
|
||||
|
||||
@@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
||||
ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
|
||||
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
theta_scale_ne, theta_scale_nb, 1);
|
||||
|
||||
float start = 0;
|
||||
float step = 1;
|
||||
@@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
||||
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
|
||||
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
|
||||
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
|
||||
theta_scale_nb, GGML_MAX_DIMS);
|
||||
theta_scale_nb, 1);
|
||||
float zero_value = 0, one_value = 1;
|
||||
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
||||
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
||||
|
||||
@@ -67,19 +67,30 @@
|
||||
GGML_ABORT("CANN error");
|
||||
}
|
||||
|
||||
// Thread-local variable to record the current device of this thread.
|
||||
thread_local int g_current_cann_device = -1;
|
||||
|
||||
/**
|
||||
* @brief Sets the device to be used by CANN.
|
||||
* @brief Set the CANN device to be used.
|
||||
*
|
||||
* @param device The device ID to set.
|
||||
* @param device The target device ID to set.
|
||||
*/
|
||||
void ggml_cann_set_device(const int32_t device) {
|
||||
int current_device = -1;
|
||||
aclrtGetDevice(¤t_device);
|
||||
// int current_device = -1;
|
||||
// Note: In some CANN versions, if no device has been set yet,
|
||||
// aclrtGetDevice(¤t_device) may return 0 by default.
|
||||
// aclrtGetDevice(¤t_device);
|
||||
|
||||
if (device == current_device) {
|
||||
// If the current device is already the target one, no need to switch.
|
||||
if (device == g_current_cann_device) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Switch to the new device.
|
||||
ACL_CHECK(aclrtSetDevice(device));
|
||||
|
||||
// Update the global device record.
|
||||
g_current_cann_device = device;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -7519,8 +7519,8 @@ static void ggml_compute_forward_upscale_f32(
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
pixel_offset = 0.0f;
|
||||
sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
|
||||
sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
|
||||
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
||||
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
||||
}
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
|
||||
@@ -1,5 +1,81 @@
|
||||
#include "argsort.cuh"
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/cub.cuh>
|
||||
using namespace cub;
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
|
||||
const int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int row = blockIdx.y;
|
||||
|
||||
if (col < ncols && row < nrows) {
|
||||
indices[row * ncols + col] = col;
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx <= nrows) {
|
||||
offsets[idx] = idx * ncols;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
|
||||
int * temp_indices = temp_indices_alloc.get();
|
||||
float * temp_keys = temp_keys_alloc.get();
|
||||
int * d_offsets = offsets_alloc.get();
|
||||
|
||||
static const int block_size = 256;
|
||||
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
|
||||
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
|
||||
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||
|
||||
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
|
||||
stream);
|
||||
} else {
|
||||
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
|
||||
sizeof(float) * 8, stream);
|
||||
}
|
||||
|
||||
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
|
||||
void * d_temp_storage = temp_storage_alloc.get();
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
|
||||
stream);
|
||||
} else {
|
||||
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
|
||||
0, sizeof(float) * 8, stream);
|
||||
}
|
||||
}
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
// Bitonic sort implementation
|
||||
template<typename T>
|
||||
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
||||
T tmp = a;
|
||||
@@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
|
||||
return n;
|
||||
}
|
||||
|
||||
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
|
||||
static void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
// bitonic sort requires ncols to be power of 2
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
|
||||
@@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
|
||||
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
|
||||
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
|
||||
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||
|
||||
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
|
||||
|
||||
if (shared_mem > max_shared_mem || ncols > 1024) {
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||
} else {
|
||||
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||
}
|
||||
#else
|
||||
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
|
||||
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
|
||||
|
||||
if (block_nums.z > 65535) {
|
||||
if (block_nums.z > 65535 || block_nums.y > 65535) {
|
||||
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
||||
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
|
||||
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
|
||||
|
||||
@@ -625,8 +625,11 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
|
||||
// and a shift:
|
||||
//
|
||||
// n/d = (mulhi(n, mp) + n) >> L;
|
||||
static const uint3 init_fastdiv_values(uint32_t d) {
|
||||
GGML_ASSERT(d != 0);
|
||||
static const uint3 init_fastdiv_values(uint64_t d_64) {
|
||||
GGML_ASSERT(d_64 != 0);
|
||||
GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
uint32_t d = (uint32_t)d_64;
|
||||
|
||||
// compute L = ceil(log2(d));
|
||||
uint32_t L = 0;
|
||||
@@ -1005,3 +1008,16 @@ struct ggml_backend_cuda_context {
|
||||
return pool(device);
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_cuda_mm_fusion_args_host {
|
||||
const ggml_tensor * x_bias = nullptr;
|
||||
const ggml_tensor * gate = nullptr;
|
||||
const ggml_tensor * gate_bias = nullptr;
|
||||
ggml_glu_op glu_op;
|
||||
};
|
||||
struct ggml_cuda_mm_fusion_args_device {
|
||||
const void * x_bias = nullptr;
|
||||
const void * gate = nullptr;
|
||||
const void * gate_bias = nullptr;
|
||||
ggml_glu_op glu_op;
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
|
||||
+69
-11
@@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const src_t * x = (const src_t *) cx;
|
||||
dst_t * dst = (dst_t *) cdst;
|
||||
|
||||
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_contiguous_cuda(
|
||||
const char * cx, char * cdst, const int64_t ne,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
@@ -285,7 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
char * src0_ddc = (char *) src0->data;
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||
|
||||
if (src0->type == src1->type && contiguous_srcs) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
||||
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
|
||||
@@ -296,11 +322,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
@@ -327,21 +361,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
|
||||
+399
-15
@@ -50,6 +50,7 @@
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml-cuda/set.cuh"
|
||||
#include "ggml-cuda/set-rows.cuh"
|
||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||
#include "ggml.h"
|
||||
@@ -1957,8 +1958,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
|
||||
|
||||
size_t src1_stride_size = sizeof(cuda_t);
|
||||
|
||||
dim3 block_dims(ne13, ne12);
|
||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||
const int threads_x = 16;
|
||||
const int threads_y = 16;
|
||||
dim3 block_dims(threads_x, threads_y);
|
||||
|
||||
dim3 grid_dims(
|
||||
(ne13 + threads_x - 1) / threads_x,
|
||||
(ne12 + threads_y - 1) / threads_y
|
||||
);
|
||||
k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
|
||||
src0_ptr, src1_ptr, dst_t,
|
||||
ptrs_src.get(), ptrs_dst.get(),
|
||||
ne12, ne13,
|
||||
@@ -2007,6 +2015,147 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
|
||||
const ggml_tensor * ffn_gate,
|
||||
const ggml_tensor * glu,
|
||||
const ggml_tensor * ffn_up_bias = nullptr,
|
||||
const ggml_tensor * ffn_gate_bias = nullptr) {
|
||||
const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
|
||||
|
||||
if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
|
||||
const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
|
||||
|
||||
GGML_ASSERT(ffn_up && ffn_gate && glu);
|
||||
|
||||
if (!is_mul_mat && !is_mul_mat_id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (has_bias) {
|
||||
if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (expected_bias_op == GGML_OP_ADD) {
|
||||
const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
|
||||
const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
|
||||
if (!up_has_mul || !gate_has_mul) {
|
||||
return false;
|
||||
}
|
||||
} else { // GGML_OP_ADD_ID
|
||||
if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
|
||||
return false;
|
||||
}
|
||||
if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
|
||||
!ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ffn_up->src[1] != ffn_gate->src[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
|
||||
|
||||
if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
|
||||
ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
|
||||
|
||||
//TODO: add support for fusion for split buffers
|
||||
if (split) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
|
||||
ggml_tensor * src0 = tensor->src[0];
|
||||
ggml_tensor * src1 = tensor->src[1];
|
||||
const ggml_tensor * dst = tensor;
|
||||
|
||||
const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
|
||||
|
||||
bool use_mul_mat_vec_f =
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
|
||||
src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
|
||||
|
||||
//we only support fusion for ncols_dst = 1
|
||||
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
return use_mul_mat_vec_f;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
|
||||
ggml_tensor * src0 = tensor->src[0];
|
||||
ggml_tensor * src1 = tensor->src[1];
|
||||
const ggml_tensor * dst = tensor;
|
||||
|
||||
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
|
||||
ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
|
||||
src0->view_src;
|
||||
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
|
||||
dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
|
||||
// fusion is not universally faster on Pascal
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
if (cc <= GGML_CUDA_CC_PASCAL) {
|
||||
return false;
|
||||
}
|
||||
//we only support fusion for ncols_dst = 1
|
||||
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return use_mul_mat_vec_q;
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
||||
|
||||
@@ -2268,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_SET_ROWS:
|
||||
ggml_cuda_op_set_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SET:
|
||||
ggml_cuda_op_set(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
ggml_cuda_dup(ctx, dst);
|
||||
break;
|
||||
@@ -2745,7 +2897,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_SCALE &&
|
||||
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
|
||||
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
||||
return false;
|
||||
}
|
||||
@@ -2826,9 +2978,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
|
||||
|
||||
if (ops.size() == topk_moe_ops_with_norm.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx+8];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
return true;
|
||||
@@ -2836,16 +2988,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) {
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx+4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_delayed_softmax, { node_idx + 2, node_idx + 5 })) {
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
|
||||
|
||||
@@ -2854,6 +3006,38 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
|
||||
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
|
||||
|
||||
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
|
||||
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
|
||||
|
||||
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
|
||||
|
||||
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
|
||||
const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3];
|
||||
const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
|
||||
|
||||
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
||||
return false;
|
||||
}
|
||||
@@ -2934,9 +3118,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
||||
if (!use_cuda_graph || cuda_graph_update_required) {
|
||||
|
||||
[[maybe_unused]] int prev_i = 0;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
|
||||
#ifdef GGML_CUDA_DEBUG
|
||||
const int nodes_fused = i - prev_i - 1;
|
||||
prev_i = i;
|
||||
if (nodes_fused > 0) {
|
||||
GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
continue;
|
||||
}
|
||||
@@ -2945,17 +3140,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
if (!disable_fusion) {
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i+8];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i+3];
|
||||
ggml_tensor * weights = cgraph->nodes[i + 9];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
|
||||
ggml_tensor * clamp = cgraph->nodes[i + 7];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
|
||||
/*delayed softmax*/ false);
|
||||
i += 8;
|
||||
/*delayed softmax*/ false, clamp);
|
||||
i += 9;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i+4];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i+3];
|
||||
ggml_tensor * weights = cgraph->nodes[i + 4];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
|
||||
/*delayed softmax*/ false);
|
||||
i += 4;
|
||||
@@ -3004,6 +3200,184 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
}
|
||||
}
|
||||
|
||||
bool fused_mul_mat_vec = false;
|
||||
int fused_node_count = 0;
|
||||
|
||||
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
|
||||
ggml_tensor * glu = cgraph->nodes[i + 4];
|
||||
ggml_tensor * gate_bias_n = glu->src[0];
|
||||
ggml_tensor * up_bias_n = glu->src[1];
|
||||
|
||||
//we don't assume the order for {gate, up}. Instead infer it from the bias tensor
|
||||
ggml_tensor * gate_n = nullptr;
|
||||
ggml_tensor * up_n = nullptr;
|
||||
|
||||
if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
|
||||
gate_n = cgraph->nodes[i];
|
||||
up_n = cgraph->nodes[i + 2];
|
||||
} else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
|
||||
gate_n = cgraph->nodes[i + 2];
|
||||
up_n = cgraph->nodes[i];
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
|
||||
if (op_bias == GGML_OP_ADD) {
|
||||
if (bias_node->src[0] == mul_node) {
|
||||
return bias_node->src[1];
|
||||
}
|
||||
if (bias_node->src[1] == mul_node) {
|
||||
return bias_node->src[0];
|
||||
}
|
||||
return (ggml_tensor *) nullptr;
|
||||
}
|
||||
GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
|
||||
GGML_ASSERT(bias_node->src[0] == mul_node);
|
||||
return bias_node->src[1];
|
||||
};
|
||||
|
||||
ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
|
||||
ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
|
||||
|
||||
if (!up_bias_tensor || !gate_bias_tensor) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = up_n->src[0];
|
||||
const ggml_tensor * src1 = up_n->src[1];
|
||||
const ggml_tensor * ids = up_n->src[2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate_n->src[0];
|
||||
fusion_data.x_bias = up_bias_tensor;
|
||||
fusion_data.gate_bias = gate_bias_tensor;
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 5;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate_n->src[0];
|
||||
fusion_data.x_bias = up_bias_tensor;
|
||||
fusion_data.gate_bias = gate_bias_tensor;
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 5;
|
||||
break;
|
||||
}
|
||||
} else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
|
||||
ggml_tensor * glu = cgraph->nodes[i + 2];
|
||||
ggml_tensor * gate = glu->src[0];
|
||||
ggml_tensor * up = glu->src[1];
|
||||
|
||||
bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
|
||||
|| (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
|
||||
|
||||
if (!ok) continue;
|
||||
|
||||
const ggml_tensor * src0 = up->src[0];
|
||||
const ggml_tensor * src1 = up->src[1];
|
||||
const ggml_tensor * ids = up->src[2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate->src[0];
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 3;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate->src[0];
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 3;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (fused_mul_mat_vec) {
|
||||
i += fused_node_count - 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
fused_mul_mat_vec = false;
|
||||
fused_node_count = 0;
|
||||
|
||||
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_tensor * mm_node = cgraph->nodes[i];
|
||||
ggml_tensor * bias_node = cgraph->nodes[i + 1];
|
||||
|
||||
ggml_tensor * bias_tensor = nullptr;
|
||||
if (bias_op == GGML_OP_ADD) {
|
||||
if (bias_node->src[0] == mm_node) {
|
||||
bias_tensor = bias_node->src[1];
|
||||
} else if (bias_node->src[1] == mm_node) {
|
||||
bias_tensor = bias_node->src[0];
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (bias_node->src[0] != mm_node) {
|
||||
continue;
|
||||
}
|
||||
bias_tensor = bias_node->src[1];
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = mm_node->src[0];
|
||||
const ggml_tensor * src1 = mm_node->src[1];
|
||||
const ggml_tensor * ids = mm_node->src[2];
|
||||
|
||||
if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.x_bias = bias_tensor;
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 2;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 2;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (fused_mul_mat_vec) {
|
||||
i += fused_node_count - 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
|
||||
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||
@@ -3483,6 +3857,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
{
|
||||
const ggml_type t = op->type;
|
||||
return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
|
||||
t == op->src[0]->type &&
|
||||
t == op->src[1]->type;
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
@@ -3642,8 +4023,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_ARGSORT:
|
||||
// TODO: Support arbitrary column width
|
||||
#ifndef GGML_CUDA_USE_CUB
|
||||
return op->src[0]->ne[0] <= 1024;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
|
||||
+321
-57
@@ -1,11 +1,12 @@
|
||||
#include "ggml.h"
|
||||
#include "common.cuh"
|
||||
#include "convert.cuh"
|
||||
#include "unary.cuh"
|
||||
#include "mmvf.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
template <typename T, typename type_acc, int ncols_dst, int block_size>
|
||||
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
|
||||
static __global__ void mul_mat_vec_f(
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
|
||||
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
||||
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
@@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f(
|
||||
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
|
||||
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
|
||||
|
||||
bool use_gate = false;
|
||||
bool use_bias = false;
|
||||
bool use_gate_bias = false;
|
||||
ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
|
||||
const T * gate_x = nullptr;
|
||||
const float * x_bias = nullptr;
|
||||
const float * gate_bias = nullptr;
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
use_gate = fusion.gate != nullptr;
|
||||
use_bias = fusion.x_bias != nullptr;
|
||||
use_gate_bias = fusion.gate_bias != nullptr;
|
||||
glu_op = fusion.glu_op;
|
||||
|
||||
if (use_gate) {
|
||||
gate_x = static_cast<const T *>(fusion.gate);
|
||||
}
|
||||
if (use_bias) {
|
||||
x_bias = static_cast<const float *>(fusion.x_bias);
|
||||
}
|
||||
if (use_gate_bias) {
|
||||
gate_bias = static_cast<const float *>(fusion.gate_bias);
|
||||
use_gate_bias = use_gate;
|
||||
} else {
|
||||
use_gate_bias = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_gate) {
|
||||
gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
||||
}
|
||||
if constexpr (has_fusion) {
|
||||
const int channel_bias = ids ? channel_x : channel_dst;
|
||||
if (use_bias) {
|
||||
x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
|
||||
}
|
||||
if (use_gate_bias) {
|
||||
gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
|
||||
}
|
||||
}
|
||||
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
float * buf_iw = (float *) data_mmv;
|
||||
float * buf_iw_gate = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
|
||||
}
|
||||
|
||||
if (block_size > warp_size) {
|
||||
if (tid < warp_size) {
|
||||
buf_iw[tid] = 0.0f;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
buf_iw_gate[tid] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float sumf[ncols_dst] = {0.0f};
|
||||
float sumf_gate[ncols_dst];
|
||||
if constexpr (has_fusion) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf_gate[j] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
const float2 * x2 = (const float2 *) x;
|
||||
const float2 * gate_x2 = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
gate_x2 = (const float2 *) gate_x;
|
||||
}
|
||||
}
|
||||
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmpx = x2[col2];
|
||||
float2 tmpx_gate = make_float2(0.0f, 0.0f);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = gate_x2[col2];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half>) {
|
||||
const half2 * x2 = (const half2 *) x;
|
||||
const half2 * gate_x2 = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
gate_x2 = (const half2 *) gate_x;
|
||||
}
|
||||
}
|
||||
|
||||
if (std::is_same_v<type_acc, float>) {
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmpx = __half22float2(x2[col2]);
|
||||
|
||||
float2 tmpx_gate = make_float2(0.0f, 0.0f);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = __half22float2(gate_x2[col2]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#ifdef FP16_AVAILABLE
|
||||
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
|
||||
half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
|
||||
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const half2 tmpx = x2[col2];
|
||||
|
||||
half2 tmpx_gate = make_half2(0.0f, 0.0f);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = gate_x2[col2];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,6 +190,15 @@ static __global__ void mul_mat_vec_f(
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
|
||||
}
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
@@ -91,8 +207,20 @@ static __global__ void mul_mat_vec_f(
|
||||
//TODO: add support for ggml_cuda_mad for hip_bfloat162
|
||||
#if defined(GGML_USE_HIP)
|
||||
const int * x2 = (const int *) x;
|
||||
const int * gate_x2 = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
gate_x2 = (const int *) gate_x;
|
||||
}
|
||||
}
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const int tmpx = x2[col2];
|
||||
int tmpx_gate = 0;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = gate_x2[col2];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
@@ -100,17 +228,45 @@ static __global__ void mul_mat_vec_f(
|
||||
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
|
||||
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
|
||||
const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
|
||||
const nv_bfloat162 * gate_x2 = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
gate_x2 = (const nv_bfloat162 *) gate_x;
|
||||
}
|
||||
}
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const nv_bfloat162 tmpx = x2[col2];
|
||||
nv_bfloat162 tmpx_gate;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = gate_x2[col2];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -122,13 +278,31 @@ static __global__ void mul_mat_vec_f(
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
|
||||
}
|
||||
}
|
||||
|
||||
if (block_size > warp_size) {
|
||||
buf_iw[tid/warp_size] = sumf[j];
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
buf_iw_gate[tid/warp_size] = sumf_gate[j];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid < warp_size) {
|
||||
sumf[j] = buf_iw[tid];
|
||||
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
sumf_gate[j] = buf_iw_gate[tid];
|
||||
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (j < ncols_dst) {
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -139,12 +313,74 @@ static __global__ void mul_mat_vec_f(
|
||||
return;
|
||||
}
|
||||
|
||||
dst[tid*stride_col_dst + row] = sumf[tid];
|
||||
float value = sumf[tid];
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_bias) {
|
||||
value += x_bias[tid*stride_col_dst + row];
|
||||
}
|
||||
|
||||
if (use_gate) {
|
||||
float gate_value = sumf_gate[tid];
|
||||
if (use_gate_bias) {
|
||||
gate_value += gate_bias[tid*stride_col_dst + row];
|
||||
}
|
||||
switch (glu_op) {
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
value *= ggml_cuda_op_silu_single(gate_value);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
value *= ggml_cuda_op_gelu_single(gate_value);
|
||||
break;
|
||||
case GGML_GLU_OP_SWIGLU_OAI: {
|
||||
value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst[tid*stride_col_dst + row] = value;
|
||||
|
||||
if constexpr (!has_fusion) {
|
||||
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename type_acc, int ncols_dst, int block_size>
|
||||
static void mul_mat_vec_f_switch_fusion(
|
||||
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int64_t ncols, const int64_t nrows,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
|
||||
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
if constexpr (ncols_dst == 1) {
|
||||
if (has_fusion) {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
|
||||
}
|
||||
|
||||
template <typename T, typename type_acc, int ncols_dst>
|
||||
static void launch_mul_mat_vec_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
void launch_mul_mat_vec_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int64_t ncols, const int64_t nrows,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
@@ -176,57 +412,59 @@ static void launch_mul_mat_vec_f_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
const int nbytes_shared = warp_size*sizeof(float);
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
|
||||
const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
|
||||
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
||||
const dim3 block_dims(block_size_best, 1, 1);
|
||||
switch (block_size_best) {
|
||||
case 32: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 64: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 96: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 128: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 160: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 192: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 224: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 256: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -236,7 +474,7 @@ static void launch_mul_mat_vec_f_cuda(
|
||||
|
||||
template <typename T, typename type_acc>
|
||||
static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
@@ -246,49 +484,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 2:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 3:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 5:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 7:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
@@ -300,29 +538,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
||||
|
||||
template<typename T>
|
||||
static void mul_mat_vec_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
enum ggml_prec prec, cudaStream_t stream) {
|
||||
|
||||
if constexpr(std::is_same_v<T, half>) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
|
||||
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
return;
|
||||
}
|
||||
}
|
||||
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
|
||||
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||
const ggml_cuda_mm_fusion_args_host * fusion) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
@@ -348,6 +588,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||
|
||||
if (fusion) {
|
||||
GGML_ASSERT( !ids || dst->ne[2] == 1);
|
||||
GGML_ASSERT( ids || dst->ne[1] == 1);
|
||||
if (fusion->x_bias) {
|
||||
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.x_bias = fusion->x_bias->data;
|
||||
}
|
||||
if (fusion->gate) {
|
||||
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
|
||||
fusion_local.gate = fusion->gate->data;
|
||||
}
|
||||
if (fusion->gate_bias) {
|
||||
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.gate_bias = fusion->gate_bias->data;
|
||||
}
|
||||
fusion_local.glu_op = fusion->glu_op;
|
||||
}
|
||||
|
||||
const int64_t s01 = src0->nb[1] / ts_src0;
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s1 = dst->nb[1] / ts_dst;
|
||||
@@ -370,19 +634,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0->data;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
@@ -409,7 +673,6 @@ void ggml_cuda_op_mul_mat_vec_f(
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||||
|
||||
|
||||
// ggml_cuda_op provides single, contiguous matrices
|
||||
const int64_t stride_row = ne00;
|
||||
const int64_t stride_col_y = ne10;
|
||||
@@ -426,22 +689,23 @@ void ggml_cuda_op_mul_mat_vec_f(
|
||||
const int64_t stride_sample_y = 0;
|
||||
const int64_t stride_sample_dst = 0;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device empty{};
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0_dd_i;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0_dd_i;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||
const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_f(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
|
||||
+223
-95
@@ -1,5 +1,6 @@
|
||||
#include "mmvq.cuh"
|
||||
#include "quantize.cuh"
|
||||
#include "unary.cuh"
|
||||
#include "vecdotq.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
@@ -82,7 +83,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
|
||||
return MMVQ_PARAMETERS_GENERIC;
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
|
||||
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
|
||||
if (table_id == MMVQ_PARAMETERS_GENERIC) {
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
@@ -136,11 +137,11 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <ggml_type type, int ncols_dst>
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
template <ggml_type type, int ncols_dst, bool has_fusion>
|
||||
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
|
||||
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
||||
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
||||
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
||||
@@ -169,8 +170,38 @@ static __global__ void mul_mat_vec_q(
|
||||
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
|
||||
const uint32_t sample_y = sample_dst;
|
||||
|
||||
bool use_gate = false;
|
||||
bool use_bias = false;
|
||||
bool use_gate_bias = false;
|
||||
const void * vgate = nullptr;
|
||||
const float * x_bias = nullptr;
|
||||
const float * gate_bias = nullptr;
|
||||
ggml_glu_op active_glu;
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
use_gate = fusion.gate != nullptr;
|
||||
use_bias = fusion.x_bias != nullptr;
|
||||
use_gate_bias = fusion.gate_bias != nullptr && use_gate;
|
||||
vgate = fusion.gate;
|
||||
x_bias = (const float *) fusion.x_bias;
|
||||
gate_bias = (const float *) fusion.gate_bias;
|
||||
active_glu = fusion.glu_op;
|
||||
}
|
||||
|
||||
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_bias) {
|
||||
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||
}
|
||||
if (use_gate_bias) {
|
||||
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||
}
|
||||
}
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
||||
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
||||
|
||||
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
|
||||
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
|
||||
@@ -187,17 +218,35 @@ static __global__ void mul_mat_vec_q(
|
||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||
tmp[j][i] += vec_dot_q_cuda(
|
||||
vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmp_gate[j][i] += vec_dot_q_cuda(
|
||||
vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
||||
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
||||
if constexpr (!has_fusion) {
|
||||
(void) tmp_shared_gate;
|
||||
} else if (!use_gate) {
|
||||
(void) tmp_shared_gate;
|
||||
}
|
||||
|
||||
if (threadIdx.y > 0) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -216,14 +265,55 @@ static __global__ void mul_mat_vec_q(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < nwarps-1; ++l) {
|
||||
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
|
||||
}
|
||||
}
|
||||
}
|
||||
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
||||
dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
|
||||
float result = tmp[j][threadIdx.x];
|
||||
if constexpr (has_fusion) {
|
||||
if (use_bias) {
|
||||
result += x_bias[j*stride_col_dst + threadIdx.x];
|
||||
}
|
||||
if (use_gate) {
|
||||
float gate_value = tmp_gate[j][threadIdx.x];
|
||||
if (use_gate_bias) {
|
||||
gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
|
||||
}
|
||||
switch (active_glu) {
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
result *= ggml_cuda_op_silu_single(gate_value);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
result *= ggml_cuda_op_gelu_single(gate_value);
|
||||
break;
|
||||
case GGML_GLU_OP_SWIGLU_OAI: {
|
||||
result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
result = result * gate_value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[j*stride_col_dst + threadIdx.x] = result;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (!has_fusion) {
|
||||
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
|
||||
}
|
||||
}
|
||||
|
||||
static std::pair<dim3, dim3> calc_launch_params(
|
||||
@@ -235,9 +325,37 @@ static std::pair<dim3, dim3> calc_launch_params(
|
||||
return {block_nums, block_dims};
|
||||
}
|
||||
|
||||
template<ggml_type type, int c_ncols_dst>
|
||||
static void mul_mat_vec_q_switch_fusion(
|
||||
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
||||
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
||||
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
||||
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
|
||||
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
if constexpr (c_ncols_dst == 1) {
|
||||
if (has_fusion) {
|
||||
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||
|
||||
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static void mul_mat_vec_q_switch_ncols_dst(
|
||||
const void * vx, const void * vy, const int32_t * ids, float * dst,
|
||||
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst,
|
||||
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
||||
@@ -256,80 +374,83 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
|
||||
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
|
||||
GGML_ASSERT(!ids || ncols_dst == 1);
|
||||
switch (ncols_dst) {
|
||||
case 1: {
|
||||
constexpr int c_ncols_dst = 1;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 2: {
|
||||
constexpr int c_ncols_dst = 2;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 3: {
|
||||
constexpr int c_ncols_dst = 3;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 4: {
|
||||
constexpr int c_ncols_dst = 4;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 5: {
|
||||
constexpr int c_ncols_dst = 5;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 6: {
|
||||
constexpr int c_ncols_dst = 6;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 7: {
|
||||
constexpr int c_ncols_dst = 7;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 8: {
|
||||
constexpr int c_ncols_dst = 8;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(has_fusion);
|
||||
}
|
||||
static void mul_mat_vec_q_switch_type(
|
||||
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
|
||||
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst,
|
||||
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
||||
@@ -339,143 +460,123 @@ static void mul_mat_vec_q_switch_type(
|
||||
switch (type_x) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -484,7 +585,8 @@ static void mul_mat_vec_q_switch_type(
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(
|
||||
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||
const ggml_cuda_mm_fusion_args_host * fusion) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
|
||||
@@ -508,6 +610,31 @@ void ggml_cuda_mul_mat_vec_q(
|
||||
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||
|
||||
if (fusion) {
|
||||
GGML_ASSERT( !ids || dst->ne[2] == 1);
|
||||
GGML_ASSERT( ids || dst->ne[1] == 1);
|
||||
|
||||
if (fusion->x_bias) {
|
||||
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.x_bias = fusion->x_bias->data;
|
||||
}
|
||||
if (fusion->gate) {
|
||||
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
|
||||
fusion_local.gate = fusion->gate->data;
|
||||
}
|
||||
if (fusion->gate_bias) {
|
||||
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.gate_bias = fusion->gate_bias->data;
|
||||
}
|
||||
fusion_local.glu_op = fusion->glu_op;
|
||||
}
|
||||
|
||||
// If src0 is a temporary compute buffer, clear any potential padding.
|
||||
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
|
||||
const size_t size_data = ggml_nbytes(src0);
|
||||
@@ -549,10 +676,10 @@ void ggml_cuda_mul_mat_vec_q(
|
||||
const int64_t stride_channel_y = ids ? s11 : s12;
|
||||
|
||||
mul_mat_vec_q_switch_type(
|
||||
src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
|
||||
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
|
||||
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, stream);
|
||||
ne03, ne3, s03, s13, s3, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_q(
|
||||
@@ -578,8 +705,9 @@ void ggml_cuda_op_mul_mat_vec_q(
|
||||
const int stride_row_x = ne00 / ggml_blck_size(src0->type);
|
||||
const int stride_col_y = src1_padded_row_size / QK8_1;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||
mul_mat_vec_q_switch_type(
|
||||
src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
||||
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
|
||||
|
||||
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_q(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
|
||||
+101
-47
@@ -4,30 +4,53 @@
|
||||
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
|
||||
|
||||
// Generic quantized set_rows kernel template
|
||||
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
|
||||
static __global__ void k_set_rows_quant(
|
||||
const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s10, const int64_t s11, const int64_t s12,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
|
||||
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
|
||||
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
|
||||
const idx_t * __restrict__ src1,
|
||||
block_type * __restrict__ dst,
|
||||
const int64_t ne_total,
|
||||
const int64_t ne10,
|
||||
const int64_t ne11,
|
||||
const int64_t ne12,
|
||||
const int64_t ne13,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
const int64_t s10,
|
||||
const int64_t s11,
|
||||
const int64_t s12,
|
||||
const int64_t s1,
|
||||
const int64_t s2,
|
||||
const int64_t s3,
|
||||
const uint3 ne00,
|
||||
const uint3 ne01,
|
||||
const uint3 ne02,
|
||||
const uint3 ne11_fd,
|
||||
const uint3 ne12_fd) {
|
||||
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
||||
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
|
||||
|
||||
if (i >= ne_total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i_base = i * qk;
|
||||
const int64_t i03 = i_base / (ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
|
||||
const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
|
||||
uint32_t tmp = (uint32_t) i_base;
|
||||
uint2 div_mod;
|
||||
|
||||
const int64_t i12 = i03 % ne12;
|
||||
const int64_t i11 = i02 % ne11;
|
||||
div_mod = fast_div_modulo(tmp, ne00);
|
||||
const int64_t i00 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne01);
|
||||
const int64_t i01 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne02);
|
||||
const int64_t i02 = div_mod.y;
|
||||
const int64_t i03 = div_mod.x;
|
||||
|
||||
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
|
||||
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
||||
const int64_t i10 = i01;
|
||||
|
||||
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
||||
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
|
||||
quantize_func(src_block, dst_block);
|
||||
|
||||
GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12);
|
||||
GGML_UNUSED(ne13);
|
||||
}
|
||||
|
||||
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
|
||||
const int64_t s2 = nb2;
|
||||
const int64_t s3 = nb3;
|
||||
|
||||
if (ne_total > 0) {
|
||||
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
|
||||
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
|
||||
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
|
||||
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
||||
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
||||
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
||||
|
||||
k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
s01, s02, s03,
|
||||
s10, s11, s12,
|
||||
s1, s2, s3);
|
||||
src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
|
||||
ne01_fd, ne02_fd, ne11_fd, ne12_fd);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename src_t, typename idx_t, typename dst_t>
|
||||
static __global__ void k_set_rows(
|
||||
const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s10, const int64_t s11, const int64_t s12,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
|
||||
template <typename src_t, typename idx_t, typename dst_t>
|
||||
static __global__ void k_set_rows(const src_t * __restrict__ src0,
|
||||
const idx_t * __restrict__ src1,
|
||||
dst_t * __restrict__ dst,
|
||||
const int64_t ne_total,
|
||||
const int64_t ne10,
|
||||
const int64_t ne11,
|
||||
const int64_t ne12,
|
||||
const int64_t ne13,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
const int64_t s10,
|
||||
const int64_t s11,
|
||||
const int64_t s12,
|
||||
const int64_t s1,
|
||||
const int64_t s2,
|
||||
const int64_t s3,
|
||||
const uint3 ne00,
|
||||
const uint3 ne01,
|
||||
const uint3 ne02,
|
||||
const uint3 ne11_fd,
|
||||
const uint3 ne12_fd) {
|
||||
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
||||
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
|
||||
|
||||
if (i >= ne_total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i03 = i / (ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
|
||||
const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
|
||||
uint32_t tmp = (uint32_t) i;
|
||||
uint2 div_mod;
|
||||
|
||||
const int64_t i12 = i03 % ne12;
|
||||
const int64_t i11 = i02 % ne11;
|
||||
div_mod = fast_div_modulo(tmp, ne00);
|
||||
const int64_t i00 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne01);
|
||||
const int64_t i01 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne02);
|
||||
const int64_t i02 = div_mod.y;
|
||||
const int64_t i03 = div_mod.x;
|
||||
|
||||
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
|
||||
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
||||
const int64_t i10 = i01;
|
||||
|
||||
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
||||
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
|
||||
dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
|
||||
|
||||
GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12);
|
||||
GGML_UNUSED(ne13);
|
||||
}
|
||||
|
||||
@@ -144,14 +196,16 @@ static void set_rows_cuda(
|
||||
const int64_t s2 = nb2/sizeof(dst_t);
|
||||
const int64_t s3 = nb3/sizeof(dst_t);
|
||||
|
||||
if (ne_total > 0) {
|
||||
k_set_rows<<<grid_size, block_size, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
s01, s02, s03,
|
||||
s10, s11, s12,
|
||||
s1, s2, s3);
|
||||
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
|
||||
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
|
||||
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
|
||||
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
||||
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
||||
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
||||
|
||||
k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
|
||||
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
|
||||
ne11_fd, ne12_fd);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
#include "set.cuh"
|
||||
#include "cpy.cuh"
|
||||
|
||||
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
|
||||
GGML_ASSERT(src1->type == src0->type);
|
||||
GGML_ASSERT(dst ->type == src0->type);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
const size_t nb1 = ((int32_t *) dst->op_params)[0];
|
||||
const size_t nb2 = ((int32_t *) dst->op_params)[1];
|
||||
const size_t nb3 = ((int32_t *) dst->op_params)[2];
|
||||
const size_t offset = ((int32_t *) dst->op_params)[3];
|
||||
const bool inplace= (bool) ((int32_t *) dst->op_params)[4];
|
||||
|
||||
if (!inplace) {
|
||||
ggml_cuda_cpy(ctx, src0, dst);
|
||||
}
|
||||
|
||||
ggml_tensor dst_view = *dst;
|
||||
dst_view.data = (void *)((char *)dst->data + offset);
|
||||
dst_view.ne[0] = src1->ne[0];
|
||||
dst_view.ne[1] = src1->ne[1];
|
||||
dst_view.ne[2] = src1->ne[2];
|
||||
dst_view.ne[3] = src1->ne[3];
|
||||
|
||||
dst_view.nb[0] = ggml_element_size(dst);
|
||||
dst_view.nb[1] = nb1;
|
||||
dst_view.nb[2] = nb2;
|
||||
dst_view.nb[3] = nb3;
|
||||
|
||||
ggml_cuda_cpy(ctx, src1, &dst_view);
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_SET_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "ggml.h"
|
||||
#include "topk-moe.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <initializer_list>
|
||||
|
||||
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
|
||||
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
const int n_rows,
|
||||
const int n_expert_used) {
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
const int row = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
if (row >= n_rows) {
|
||||
return;
|
||||
@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
|
||||
if constexpr (with_norm) {
|
||||
wt_sum = warp_reduce_sum(wt_sum);
|
||||
wt_sum = max(wt_sum, clamp_val);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
weights[idx] = output_weights[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (!with_norm) {
|
||||
GGML_UNUSED(clamp_val);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool with_norm, bool delayed_softmax = false>
|
||||
@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
int32_t * ids,
|
||||
const int n_rows,
|
||||
const int n_expert,
|
||||
const int n_expert_used) {
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
|
||||
|
||||
const int rows_per_block = 4;
|
||||
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
|
||||
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
||||
@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
switch (n_expert) {
|
||||
case 1:
|
||||
topk_moe_cuda<1, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 2:
|
||||
topk_moe_cuda<2, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 4:
|
||||
topk_moe_cuda<4, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 8:
|
||||
topk_moe_cuda<8, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 16:
|
||||
topk_moe_cuda<16, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 32:
|
||||
topk_moe_cuda<32, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 64:
|
||||
topk_moe_cuda<64, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 128:
|
||||
topk_moe_cuda<128, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 256:
|
||||
topk_moe_cuda<256, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 512:
|
||||
topk_moe_cuda<512, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "fatal error");
|
||||
@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax) {
|
||||
const bool delayed_softmax,
|
||||
ggml_tensor * clamp) {
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
|
||||
const int n_expert_used = weights->ne[1];
|
||||
|
||||
float clamp_val = -INFINITY;
|
||||
if (with_norm) {
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
if (clamp) {
|
||||
clamp_val = ggml_get_op_params_f32(clamp, 0);
|
||||
}
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
|
||||
} else {
|
||||
GGML_ASSERT(clamp == nullptr);
|
||||
if (delayed_softmax) {
|
||||
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
} else {
|
||||
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
|
||||
return false;
|
||||
}
|
||||
|
||||
if (clamp) {
|
||||
if (clamp->op != GGML_OP_CLAMP) {
|
||||
return false;
|
||||
}
|
||||
float max_val = ggml_get_op_params_f32(clamp, 1);
|
||||
|
||||
if (max_val != INFINITY) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
|
||||
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
||||
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
|
||||
GGML_OP_RESHAPE };
|
||||
|
||||
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
|
||||
@@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax = false);
|
||||
const bool delayed_softmax = false,
|
||||
ggml_tensor * weight_clamp = nullptr);
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
|
||||
|
||||
@@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu(float x) {
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
return ggml_cuda_op_gelu_single(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu_erf(float x) {
|
||||
@@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_silu(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
return ggml_cuda_op_silu_single(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_tanh(float x) {
|
||||
@@ -317,13 +314,8 @@ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, cons
|
||||
|
||||
float xi = x[j0];
|
||||
float gi = g[j1];
|
||||
xi = fminf(xi, limit);
|
||||
gi = fmaxf(fminf(gi, limit), -limit);
|
||||
|
||||
float out_glu = xi / (1.0f + expf(-xi * alpha));
|
||||
out_glu = out_glu * (1.0f + gi);
|
||||
|
||||
dst[i] = out_glu;
|
||||
dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_NEG_BLOCK_SIZE 256
|
||||
@@ -75,3 +76,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
|
||||
x = fminf(x, limit);
|
||||
g = fmaxf(fminf(g, limit), -limit);
|
||||
|
||||
float out_glu = x / (1.0f + expf(-x * alpha));
|
||||
out_glu = out_glu * (1.0f + g);
|
||||
return out_glu;
|
||||
}
|
||||
|
||||
@@ -126,8 +126,8 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
||||
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
||||
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
|
||||
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
|
||||
@@ -211,12 +211,15 @@ static inline void hex_format_op_names(char * str, const struct ggml_tensor * t)
|
||||
// ** backend sessions
|
||||
|
||||
struct ggml_hexagon_session {
|
||||
ggml_hexagon_session(int dev_id) noexcept(false);
|
||||
ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);
|
||||
~ggml_hexagon_session() noexcept(true);
|
||||
|
||||
void allocate(int dev_id) noexcept(false);
|
||||
void release() noexcept(true);
|
||||
|
||||
void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false);
|
||||
void flush();
|
||||
|
||||
ggml_backend_buffer_type buffer_type;
|
||||
ggml_backend_buffer_type repack_buffer_type;
|
||||
|
||||
@@ -237,15 +240,37 @@ struct ggml_hexagon_session {
|
||||
uint32_t prof_pkts;
|
||||
};
|
||||
|
||||
// Packet callback
|
||||
static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * context) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(context);
|
||||
void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {
|
||||
// Bump pending flag (cleared in the session::flush once we get the responce)
|
||||
this->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(this->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
n_bufs, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000 // Timeout
|
||||
);
|
||||
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
|
||||
}
|
||||
|
||||
if (sync) {
|
||||
flush();
|
||||
}
|
||||
}
|
||||
|
||||
// Flush HTP response queue i.e wait for all outstanding requests to complete
|
||||
void ggml_hexagon_session::flush() {
|
||||
dspqueue_t q = this->queue;
|
||||
|
||||
// Repeatedly read packets from the queue until it's empty. We don't
|
||||
// necessarily get a separate callback for each packet, and new packets
|
||||
// may arrive while we're processing the previous one.
|
||||
|
||||
while (1) {
|
||||
while (this->op_pending) {
|
||||
struct htp_general_rsp rsp;
|
||||
uint32_t rsp_size;
|
||||
uint32_t flags;
|
||||
@@ -253,22 +278,23 @@ static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * contex
|
||||
struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
uint32_t n_bufs;
|
||||
|
||||
// Read packet from queue
|
||||
int err = dspqueue_read_noblock(queue, &flags,
|
||||
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||
&n_bufs, // Number of buffer references
|
||||
bufs, // Buffer references
|
||||
sizeof(rsp), // Max message length
|
||||
&rsp_size, // Message length
|
||||
(uint8_t *) &rsp);
|
||||
// Read response packet from queue
|
||||
int err = dspqueue_read(q, &flags,
|
||||
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||
&n_bufs, // Number of buffer references
|
||||
bufs, // Buffer references
|
||||
sizeof(rsp), // Max message length
|
||||
&rsp_size, // Message length
|
||||
(uint8_t *) &rsp,
|
||||
1000000); // Timeout
|
||||
|
||||
if (err == AEE_EWOULDBLOCK) {
|
||||
// Consumed all packets available for now
|
||||
return;
|
||||
if (err == AEE_EEXPIRED) {
|
||||
// TODO: might need to bail out if the HTP is stuck on something
|
||||
continue;
|
||||
}
|
||||
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: dspqueue_read_noblock failed: 0x%08x\n", (unsigned) err);
|
||||
GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err);
|
||||
}
|
||||
|
||||
// Basic sanity checks
|
||||
@@ -281,21 +307,15 @@ static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * contex
|
||||
// TODO: handle errors
|
||||
}
|
||||
|
||||
// FIXME: update profiling implementation
|
||||
sess->prof_usecs = rsp.prof_usecs;
|
||||
sess->prof_cycles = rsp.prof_cycles;
|
||||
sess->prof_pkts = rsp.prof_pkts;
|
||||
// TODO: update profiling implementation, currently only works for opt_opsync mode
|
||||
this->prof_usecs = rsp.prof_usecs;
|
||||
this->prof_cycles = rsp.prof_cycles;
|
||||
this->prof_pkts = rsp.prof_pkts;
|
||||
|
||||
sess->op_pending--; // atomic dec
|
||||
this->op_pending--; // atomic dec
|
||||
}
|
||||
}
|
||||
|
||||
// Error callback - simply terminates with an error. Used where we don't
|
||||
// expect errors.
|
||||
[[noreturn]] static void htp_error_callback(dspqueue_t queue, AEEResult error, void * context) {
|
||||
GGML_ABORT("ggml-hex: dspcall general error 0x%x: for queue %p\n", error, (void *) queue);
|
||||
}
|
||||
|
||||
// ** backend buffers
|
||||
|
||||
struct ggml_backend_hexagon_buffer_type_context {
|
||||
@@ -1564,7 +1584,8 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
|
||||
0, // Flags
|
||||
128 * 1024, // Request queue size (in bytes)
|
||||
64 * 1024, // Response queue size (in bytes)
|
||||
htp_packet_callback, htp_error_callback,
|
||||
nullptr, // Read packet callback (we handle reads explicitly)
|
||||
nullptr, // Error callback (we handle errors during reads)
|
||||
(void *) this, // Callback context
|
||||
&queue);
|
||||
if (err != 0) {
|
||||
@@ -1631,10 +1652,13 @@ void ggml_hexagon_session::release() noexcept(true) {
|
||||
}
|
||||
}
|
||||
|
||||
ggml_hexagon_session::ggml_hexagon_session(int dev_id) noexcept(false) {
|
||||
ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {
|
||||
buffer_type.context = nullptr;
|
||||
repack_buffer_type.context = nullptr;
|
||||
|
||||
buffer_type.device = dev;
|
||||
repack_buffer_type.device = dev;
|
||||
|
||||
try {
|
||||
allocate(dev_id);
|
||||
|
||||
@@ -2202,7 +2226,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = DSPQUEUE_BUFFER_FLAG_REF;
|
||||
bufs[0].flags = 0;
|
||||
|
||||
// Second buffer Input Activations. This is a buffer that the CPU
|
||||
// writes and the DSP reads, so we'll need to flush CPU caches and
|
||||
@@ -2212,8 +2236,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Third buffer Output Activations. We'll handle DSP
|
||||
@@ -2224,7 +2247,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
||||
bufs[2].ptr = dst->data;
|
||||
bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[2].size = ggml_nbytes(dst);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
|
||||
// Primary DSP session from the src0 (normally weight) tensor
|
||||
auto sess = src0_buf->sess;
|
||||
@@ -2252,27 +2275,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
3, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000 // Timeout
|
||||
);
|
||||
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, 3, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2328,7 +2331,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = DSPQUEUE_BUFFER_FLAG_REF;
|
||||
bufs[0].flags = 0;
|
||||
|
||||
// Second buffer Input Activations. This is a buffer that the CPU
|
||||
// writes and the DSP reads, so we'll need to flush CPU caches and
|
||||
@@ -2338,8 +2341,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Third buffer expert IDs. This is a buffer that the CPU
|
||||
@@ -2350,8 +2352,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
bufs[2].ptr = src2->data;
|
||||
bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
|
||||
bufs[2].size = ggml_nbytes(src2);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Forth buffer Output Activations. We'll handle DSP
|
||||
@@ -2362,7 +2363,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
bufs[3].ptr = dst->data;
|
||||
bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[3].size = ggml_nbytes(dst);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
|
||||
// Primary DSP session from the src0 (normally weight) tensor
|
||||
auto sess = src0_buf->sess;
|
||||
@@ -2391,27 +2392,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
4, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000 // Timeout
|
||||
);
|
||||
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, 4, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2484,8 +2465,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
|
||||
// Second buffer = Second Operand of Binary op
|
||||
@@ -2497,8 +2477,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Third buffer = Output Activations. We'll handle DSP
|
||||
@@ -2509,7 +2488,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[2].ptr = dst->data;
|
||||
bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[2].size = ggml_nbytes(dst);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
ggml_hexagon_session * sess = src0_buf->sess;
|
||||
@@ -2537,26 +2516,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
3, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000); // Timeout
|
||||
|
||||
if (0 != err) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, 3, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2621,8 +2581,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
|
||||
// Second buffer = experts bias
|
||||
@@ -2630,8 +2589,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Third buffer = activated experts
|
||||
@@ -2639,8 +2597,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[2].ptr = src2->data;
|
||||
bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
|
||||
bufs[2].size = ggml_nbytes(src2);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Forth buffer = output activations
|
||||
@@ -2648,7 +2605,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[3].ptr = dst->data;
|
||||
bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[3].size = ggml_nbytes(dst);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
ggml_hexagon_session * sess = src0_buf->sess;
|
||||
@@ -2678,26 +2635,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
4, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000); // Timeout
|
||||
|
||||
if (0 != err) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, 4, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2795,8 +2733,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src0->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src0);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
++n_bufs;
|
||||
|
||||
@@ -2811,8 +2748,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src1->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src1);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
++n_bufs;
|
||||
}
|
||||
@@ -2827,7 +2763,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = dst->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(dst);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
++n_bufs;
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
@@ -2860,26 +2796,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
n_bufs, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000); // Timeout
|
||||
|
||||
if (0 != err) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, n_bufs, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2953,8 +2870,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src0->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src0);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
++n_bufs;
|
||||
|
||||
@@ -2968,8 +2884,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src1->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src1);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
++n_bufs;
|
||||
|
||||
@@ -2984,8 +2899,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src2->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src2->data - src2_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src2);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
++n_bufs;
|
||||
}
|
||||
@@ -3000,7 +2914,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = dst->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(dst);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
++n_bufs;
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
@@ -3033,26 +2947,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
n_bufs, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000); // Timeout
|
||||
|
||||
if (0 != err) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, n_bufs, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -3197,9 +3092,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
}
|
||||
|
||||
// Wait until all pending ops complete
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->flush();
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
@@ -3210,9 +3103,7 @@ static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) {
|
||||
HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str());
|
||||
|
||||
// Wait until all pending ops complete
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->flush();
|
||||
}
|
||||
|
||||
struct node_info {
|
||||
@@ -3628,7 +3519,7 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
|
||||
devices[i].iface = ggml_backend_hexagon_device_i;
|
||||
devices[i].reg = reg;
|
||||
try {
|
||||
devices[i].context = new ggml_hexagon_session(i);
|
||||
devices[i].context = new ggml_hexagon_session(i, &devices[i]);
|
||||
} catch (std::exception const &exc) {
|
||||
GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
|
||||
devices[i].context = nullptr;
|
||||
|
||||
@@ -395,28 +395,14 @@ static void proc_matmul_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs) {
|
||||
// Prep response buffer structs (needed for error responses, etc)
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -444,41 +430,21 @@ static void proc_matmul_req(struct htp_context * ctx,
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_matmul_id_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs) {
|
||||
// Prep response buffer structs (needed for error responses, etc)
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[3].fd = bufs[3].fd;
|
||||
rsp_bufs[3].ptr = bufs[3].ptr;
|
||||
rsp_bufs[3].size = bufs[3].size;
|
||||
rsp_bufs[3].offset = bufs[3].offset;
|
||||
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[3].fd;
|
||||
rsp_bufs[0].ptr = bufs[3].ptr;
|
||||
rsp_bufs[0].size = bufs[3].size;
|
||||
rsp_bufs[0].offset = bufs[3].offset;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -508,32 +474,18 @@ static void proc_matmul_id_req(struct htp_context * ctx,
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -561,38 +513,18 @@ static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * r
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[3].fd = bufs[3].fd;
|
||||
rsp_bufs[3].ptr = bufs[3].ptr;
|
||||
rsp_bufs[3].offset = bufs[3].offset;
|
||||
rsp_bufs[3].size = bufs[3].size;
|
||||
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[3].fd;
|
||||
rsp_bufs[0].ptr = bufs[3].ptr;
|
||||
rsp_bufs[0].offset = bufs[3].offset;
|
||||
rsp_bufs[0].size = bufs[3].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -622,26 +554,18 @@ static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * r
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -669,7 +593,7 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 2, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_activations_req(struct htp_context * ctx,
|
||||
@@ -677,33 +601,16 @@ static void proc_activations_req(struct htp_context * ctx,
|
||||
struct dspqueue_buffer * bufs,
|
||||
uint32_t n_bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
int write_idx = 1;
|
||||
if (3 == n_bufs) {
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
write_idx = 2;
|
||||
}
|
||||
int write_idx = (n_bufs == 3) ? 2 : 1;
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
rsp_bufs[0].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[0].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[0].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[0].size = bufs[write_idx].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
@@ -742,7 +649,7 @@ static void proc_activations_req(struct htp_context * ctx,
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_rope_req(struct htp_context * ctx,
|
||||
@@ -750,39 +657,16 @@ static void proc_rope_req(struct htp_context * ctx,
|
||||
struct dspqueue_buffer * bufs,
|
||||
uint32_t n_bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
int write_idx = 2;
|
||||
if (4 == n_bufs) {
|
||||
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||
rsp_bufs[write_idx].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
write_idx++;
|
||||
}
|
||||
int write_idx = (n_bufs == 4) ? 3 : 2;
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
rsp_bufs[0].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[0].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[0].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[0].size = bufs[write_idx].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
@@ -819,7 +703,7 @@ static void proc_rope_req(struct htp_context * ctx,
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
|
||||
@@ -29,10 +29,11 @@ if (CXX_IS_HIPCC)
|
||||
endif()
|
||||
else()
|
||||
# Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
|
||||
if(AMDGPU_TARGETS AND NOT GPU_TARGETS)
|
||||
set(GPU_TARGETS ${AMDGPU_TARGETS})
|
||||
endif()
|
||||
if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||
set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
|
||||
elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
|
||||
endif()
|
||||
cmake_minimum_required(VERSION 3.21)
|
||||
enable_language(HIP)
|
||||
|
||||
@@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <array>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
@@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
|
||||
}
|
||||
|
||||
// Return true if the edges in the graph match expectations.
|
||||
inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
|
||||
int start_idx,
|
||||
std::initializer_list<std::array<int, 3>> edges) {
|
||||
for (const auto & edge : edges) {
|
||||
int dst_node = edge[0];
|
||||
int src_idx = edge[1];
|
||||
int src_node = edge[2];
|
||||
if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// expose GGUF internals for test code
|
||||
GGML_API size_t gguf_type_size(enum gguf_type type);
|
||||
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
|
||||
|
||||
@@ -6156,8 +6156,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = (float)(ne0 - 1) / (ne00 - 1);
|
||||
sf1 = (float)(ne1 - 1) / (ne01 - 1);
|
||||
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
||||
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
|
||||
|
||||
@@ -32,8 +32,10 @@
|
||||
#include "pad.hpp"
|
||||
#include "quantize.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "roll.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "set_rows.hpp"
|
||||
#include "ssm_conv.hpp"
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "wkv.hpp"
|
||||
|
||||
@@ -42,13 +42,16 @@
|
||||
#include "ggml-sycl/backend.hpp"
|
||||
#include "ggml-sycl/common.hpp"
|
||||
#include "ggml-sycl/element_wise.hpp"
|
||||
#include "ggml-sycl/norm.hpp"
|
||||
#include "ggml-sycl/presets.hpp"
|
||||
#include "ggml-sycl/gemm.hpp"
|
||||
#include "ggml-sycl/set_rows.hpp"
|
||||
#include "ggml-sycl/set.hpp"
|
||||
#include "ggml-sycl/sycl_hw.hpp"
|
||||
#include "ggml-sycl/getrows.hpp"
|
||||
#include "ggml-sycl/repeat_back.hpp"
|
||||
#include "ggml-sycl/quantize.hpp"
|
||||
#include "ggml-sycl/ssm_conv.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
static bool g_sycl_loaded = false;
|
||||
@@ -2615,6 +2618,10 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_repeat_back(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
@@ -2631,6 +2638,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||
ggml_sycl_op_rms_norm(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_rms_norm_back(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_l2_norm(ctx, dst);
|
||||
@@ -3679,6 +3691,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_REPEAT:
|
||||
ggml_sycl_repeat(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
ggml_sycl_repeat_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_sycl_get_rows(ctx, dst);
|
||||
break;
|
||||
@@ -3818,6 +3833,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
ggml_sycl_leaky_relu(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
ggml_sycl_rms_norm_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
ggml_sycl_rms_norm(ctx, dst);
|
||||
break;
|
||||
@@ -3913,6 +3931,11 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_sycl_ssm_conv(ctx, dst);
|
||||
case GGML_OP_ROLL:
|
||||
ggml_sycl_roll(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARANGE:
|
||||
ggml_sycl_arange(ctx, dst);
|
||||
break;
|
||||
@@ -4516,6 +4539,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
}
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type == GGML_TYPE_F32;
|
||||
}
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_NONE:
|
||||
@@ -4552,6 +4580,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
||||
case GGML_OP_SCALE:
|
||||
return true;
|
||||
case GGML_OP_CONT:
|
||||
@@ -4586,6 +4616,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
return true;
|
||||
case GGML_OP_SSM_CONV:
|
||||
return op->type == GGML_TYPE_F32 &&
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
op->src[1]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ROLL:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ARANGE:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
default:
|
||||
|
||||
@@ -480,6 +480,162 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz
|
||||
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
float eps = 1e-5f;
|
||||
std::memcpy(&eps, dst->op_params, sizeof(float));
|
||||
if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f;
|
||||
|
||||
const float * g_base = static_cast<const float *>(dst->src[0]->data); // dz
|
||||
const float * x_base = static_cast<const float *>(dst->src[1]->data); // x
|
||||
float * dx_base = static_cast< float *>(dst->data);
|
||||
|
||||
const int64_t D = dst->ne[0];
|
||||
const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;
|
||||
const int64_t N = ggml_nrows(dst);
|
||||
if (D == 0 || N == 0) return;
|
||||
|
||||
const ggml_tensor *G = dst->src[0];
|
||||
const ggml_tensor *X = dst->src[1];
|
||||
const int ts = (int) ggml_type_size(X->type);
|
||||
GGML_ASSERT((size_t) X->nb[0] == (size_t) ts);
|
||||
GGML_ASSERT((size_t) G->nb[0] == (size_t) ts);
|
||||
GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);
|
||||
|
||||
const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts;
|
||||
const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts;
|
||||
const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;
|
||||
|
||||
dpct::queue_ptr q = ctx.stream();
|
||||
|
||||
// work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D
|
||||
const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
|
||||
auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };
|
||||
int wg_cap = 256;
|
||||
if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);
|
||||
int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));
|
||||
|
||||
// FP32 path: per-thread compensated accumulation + hierarchical reduction
|
||||
q->submit([&](sycl::handler &cgh) {
|
||||
const int nwarps_loc = std::max(1, WG / WARP_SIZE);
|
||||
// store one partial value per warp (xx and xg) for cross-warp reduction
|
||||
auto l_xx = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
|
||||
auto l_xg = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),
|
||||
sycl::range<3>(1, 1, WG)),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const int row = item_ct1.get_group(2);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
const int64_t i1 = row % n1;
|
||||
const int64_t i2 = (row / n1) % n2;
|
||||
const int64_t i3 = row / (n1 * n2);
|
||||
|
||||
const float *__restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1;
|
||||
const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;
|
||||
float *__restrict d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;
|
||||
|
||||
// per-thread accumulation (compensated by default)
|
||||
float sum_xx = 0.f, sum_xg = 0.f;
|
||||
#ifndef GGML_SYCL_RMS_BACK_FAST
|
||||
float c_xx = 0.f, c_xg = 0.f;
|
||||
#endif
|
||||
for (int64_t col = tid; col < D; col += WG) {
|
||||
const float xv = x_row[col];
|
||||
const float gv = g_row[col];
|
||||
#ifdef GGML_SYCL_RMS_BACK_FAST
|
||||
sum_xx += xv * xv;
|
||||
sum_xg += xv * gv;
|
||||
#else
|
||||
float y1 = xv * xv - c_xx;
|
||||
float t1 = sum_xx + y1;
|
||||
c_xx = (t1 - sum_xx) - y1;
|
||||
sum_xx = t1;
|
||||
|
||||
float y2 = xv * gv - c_xg;
|
||||
float t2 = sum_xg + y2;
|
||||
c_xg = (t2 - sum_xg) - y2;
|
||||
sum_xg = t2;
|
||||
#endif
|
||||
}
|
||||
|
||||
// warp-level reduction
|
||||
sycl::float2 xx = sycl::float2(sum_xx,
|
||||
#ifndef GGML_SYCL_RMS_BACK_FAST
|
||||
c_xx
|
||||
#else
|
||||
0.f
|
||||
#endif
|
||||
);
|
||||
sycl::float2 xg = sycl::float2(sum_xg,
|
||||
#ifndef GGML_SYCL_RMS_BACK_FAST
|
||||
c_xg
|
||||
#else
|
||||
0.f
|
||||
#endif
|
||||
);
|
||||
xx = warp_reduce_sum(xx, item_ct1);
|
||||
xg = warp_reduce_sum(xg, item_ct1);
|
||||
|
||||
// cross-warp reduction using local memory (single barrier)
|
||||
const auto sub_group = item_ct1.get_sub_group();
|
||||
const auto sg_id = sub_group.get_group_linear_id();
|
||||
const auto wi_in_sg = sub_group.get_local_linear_id();
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
sycl::float2 xx_total = xx;
|
||||
sycl::float2 xg_total = xg;
|
||||
if (nwarps > 1) {
|
||||
if (wi_in_sg == 0) {
|
||||
l_xx[sg_id] = xx;
|
||||
l_xg[sg_id] = xg;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (sg_id == 0) {
|
||||
const unsigned wi_u = wi_in_sg;
|
||||
sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
|
||||
sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
|
||||
xx_total = warp_reduce_sum(xx_first, item_ct1);
|
||||
xg_total = warp_reduce_sum(xg_first, item_ct1);
|
||||
} else {
|
||||
// other subgroups keep their local totals; they'll be ignored
|
||||
xx_total = xx;
|
||||
xg_total = xg;
|
||||
}
|
||||
// ensure all threads see the first-subgroup result via broadcast below
|
||||
}
|
||||
|
||||
// compute inv_r and coeff once per row and broadcast to the whole work-group
|
||||
float inv_r = 0.f;
|
||||
float coeff = 0.f;
|
||||
if (tid == 0) {
|
||||
const float sum_xx_f = xx_total.x() + xx_total.y();
|
||||
const float sum_xdz_f = xg_total.x() + xg_total.y();
|
||||
const float mean_eps = sum_xx_f / (float) D + eps;
|
||||
const float sum_eps = sum_xx_f + eps * (float) D;
|
||||
inv_r = sycl::rsqrt(mean_eps);
|
||||
coeff = -sum_xdz_f / sum_eps;
|
||||
}
|
||||
inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
|
||||
coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);
|
||||
|
||||
for (int64_t col = tid; col < D; col += WG) {
|
||||
d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
|
||||
@@ -19,6 +19,8 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
#include "repeat_back.hpp"
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float * src0_dd = (const float *) dst->src[0]->data;
|
||||
float * dst_dd = (float *) dst->data;
|
||||
|
||||
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
|
||||
const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
|
||||
ne03 = dst->src[0]->ne[3];
|
||||
|
||||
const int nr0 = (int) (ne00 / ne0);
|
||||
const int nr1 = (int) (ne01 / ne1);
|
||||
const int nr2 = (int) (ne02 / ne2);
|
||||
const int nr3 = (int) (ne03 / ne3);
|
||||
|
||||
const size_t total = ne0 * ne1 * ne2 * ne3;
|
||||
const int BLOCK_SIZE = 256;
|
||||
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
queue_ptr stream = ctx.stream();
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
const size_t i = item_ct1.get_global_linear_id();
|
||||
if (i >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i0 = i % ne0;
|
||||
const int i1 = (i / ne0) % ne1;
|
||||
const int i2 = (i / (ne0 * ne1)) % ne2;
|
||||
const int i3 = i / (ne0 * ne1 * ne2);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
for (int j3 = 0; j3 < nr3; ++j3) {
|
||||
for (int j2 = 0; j2 < nr2; ++j2) {
|
||||
for (int j1 = 0; j1 < nr1; ++j1) {
|
||||
for (int j0 = 0; j0 < nr0; ++j0) {
|
||||
acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
|
||||
(i3 + j3 * ne3) * ne00 * ne01 * ne02];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst_dd[i] = acc;
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
#ifndef GGML_SYCL_REPEAT_BACK_HPP
|
||||
#define GGML_SYCL_REPEAT_BACK_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_REPEAT_BACK_HPP
|
||||
@@ -0,0 +1,122 @@
|
||||
#include "roll.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using namespace sycl;
|
||||
|
||||
static inline int wrap_add(int i, int shift, int n) {
|
||||
|
||||
int s = i + shift;
|
||||
return (s >= n) ? (s - n) : s;
|
||||
}
|
||||
|
||||
static void kernel_roll_fused_i0_i1(
|
||||
queue &q,
|
||||
const float *src_d,
|
||||
float *dst_d,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int sh0, int sh1, int sh2, int sh3)
|
||||
{
|
||||
if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
|
||||
|
||||
|
||||
const int stride1 = ne0;
|
||||
const int stride2 = ne0 * ne1;
|
||||
const int stride3 = ne0 * ne1 * ne2;
|
||||
|
||||
|
||||
const int shNe0 = (ne0 - sh0) % ne0;
|
||||
const int shNe1 = (ne1 - sh1) % ne1;
|
||||
const int shNe2 = (ne2 - sh2) % ne2;
|
||||
const int shNe3 = (ne3 - sh3) % ne3;
|
||||
|
||||
|
||||
const size_t g0 = (size_t) ne3;
|
||||
const size_t g1 = (size_t) ne2;
|
||||
const size_t g2 = (size_t) (ne1 * ne0);
|
||||
|
||||
const range<3> global{ g0, g1, g2 };
|
||||
|
||||
q.submit([&](handler &h) {
|
||||
h.parallel_for(global, [=](id<3> idx) {
|
||||
const int i3 = (int) idx[0];
|
||||
const int i2 = (int) idx[1];
|
||||
|
||||
const int fused = (int) idx[2];
|
||||
const int i1 = fused / ne0;
|
||||
const int i0 = fused - i1 * ne0; // fused % ne0
|
||||
|
||||
|
||||
const int idx_dst = i0
|
||||
+ i1 * stride1
|
||||
+ i2 * stride2
|
||||
+ i3 * stride3;
|
||||
|
||||
|
||||
const int s0 = wrap_add(i0, shNe0, ne0);
|
||||
const int s1 = wrap_add(i1, shNe1, ne1);
|
||||
const int s2 = wrap_add(i2, shNe2, ne2);
|
||||
const int s3 = wrap_add(i3, shNe3, ne3);
|
||||
|
||||
const int idx_src = s0
|
||||
+ s1 * stride1
|
||||
+ s2 * stride2
|
||||
+ s3 * stride3;
|
||||
|
||||
dst_d[idx_dst] = src_d[idx_src];
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const ggml_tensor *src = dst->src[0];
|
||||
GGML_ASSERT(src && src->type == GGML_TYPE_F32);
|
||||
|
||||
const int ne0 = (int) dst->ne[0];
|
||||
const int ne1 = (int) dst->ne[1];
|
||||
const int ne2 = (int) dst->ne[2];
|
||||
const int ne3 = (int) dst->ne[3];
|
||||
|
||||
const int32_t *params = (const int32_t *) dst->op_params;
|
||||
int shift0 = params[0];
|
||||
int shift1 = params[1];
|
||||
int shift2 = params[2];
|
||||
int shift3 = params[3];
|
||||
|
||||
|
||||
if ((shift0 | shift1 | shift2 | shift3) == 0) {
|
||||
const size_t nb = ggml_nbytes(src);
|
||||
queue *q = ctx.stream();
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
|
||||
return;
|
||||
}
|
||||
|
||||
auto norm = [](int sh, int n) -> int {
|
||||
if (n <= 0) return 0;
|
||||
sh %= n;
|
||||
if (sh < 0) sh += n;
|
||||
return sh;
|
||||
};
|
||||
shift0 = norm(shift0, ne0);
|
||||
shift1 = norm(shift1, ne1);
|
||||
shift2 = norm(shift2, ne2);
|
||||
shift3 = norm(shift3, ne3);
|
||||
|
||||
try {
|
||||
queue *q = ctx.stream();
|
||||
|
||||
const float *src_d = (const float *) src->data;
|
||||
float *dst_d = (float *) dst->data;
|
||||
GGML_ASSERT(src_d && dst_d);
|
||||
|
||||
kernel_roll_fused_i0_i1(
|
||||
*q, src_d, dst_d,
|
||||
ne0, ne1, ne2, ne3,
|
||||
shift0, shift1, shift2, shift3
|
||||
);
|
||||
} catch (const std::exception &e) {
|
||||
std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_ROLL_HPP
|
||||
#define GGML_SYCL_ROLL_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
#endif // GGML_SYCL_ROLL_HPP
|
||||
@@ -0,0 +1,127 @@
|
||||
#include "ssm_conv.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
|
||||
using namespace sycl;
|
||||
|
||||
static void kernel_ssm_conv(
|
||||
queue &q,
|
||||
const float *src_data,
|
||||
const float *weights,
|
||||
float *dst_data,
|
||||
int d_conv,
|
||||
int d_inner,
|
||||
int n_t,
|
||||
int n_s,
|
||||
int ncs __attribute__((unused)),
|
||||
int src_stride_inner,
|
||||
int src_stride_seq,
|
||||
int dst_stride_token,
|
||||
int dst_stride_seq
|
||||
) {
|
||||
const size_t total_work = static_cast<size_t>(d_inner) * static_cast<size_t>(n_t) * static_cast<size_t>(n_s);
|
||||
const size_t work_group_size = 256;
|
||||
const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size;
|
||||
|
||||
const range<1> global_range(num_work_groups * work_group_size);
|
||||
const range<1> local_range(work_group_size);
|
||||
|
||||
q.submit([&](handler &h) {
|
||||
h.parallel_for(
|
||||
nd_range<1>(global_range, local_range),
|
||||
[=](nd_item<1> item) {
|
||||
const size_t idx = item.get_global_id(0);
|
||||
if (idx >= total_work) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = static_cast<int>(idx % d_inner);
|
||||
const int token = static_cast<int>((idx / d_inner) % n_t);
|
||||
const int seq = static_cast<int>(idx / (static_cast<size_t>(d_inner) * static_cast<size_t>(n_t)));
|
||||
|
||||
const float *s = src_data
|
||||
+ static_cast<size_t>(seq) * static_cast<size_t>(src_stride_seq)
|
||||
+ static_cast<size_t>(channel) * static_cast<size_t>(src_stride_inner)
|
||||
+ static_cast<size_t>(token);
|
||||
|
||||
const float *c = weights + static_cast<size_t>(channel) * static_cast<size_t>(d_conv);
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int i0 = 0; i0 < d_conv; ++i0) {
|
||||
sumf += s[i0] * c[i0];
|
||||
}
|
||||
|
||||
const size_t dst_idx =
|
||||
static_cast<size_t>(seq) * static_cast<size_t>(dst_stride_seq) +
|
||||
static_cast<size_t>(token) * static_cast<size_t>(dst_stride_token) +
|
||||
static_cast<size_t>(channel);
|
||||
|
||||
dst_data[dst_idx] = sumf;
|
||||
}
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int d_conv = src1->ne[0];
|
||||
const int ncs = src0->ne[0];
|
||||
const int d_inner = src0->ne[1];
|
||||
const int n_t = dst->ne[1];
|
||||
const int n_s = dst->ne[2];
|
||||
|
||||
GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);
|
||||
GGML_ASSERT(src0->ne[1] == d_inner);
|
||||
GGML_ASSERT(src1->ne[1] == d_inner);
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == d_inner);
|
||||
GGML_ASSERT(dst->ne[1] == n_t);
|
||||
GGML_ASSERT(dst->ne[2] == n_s);
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
|
||||
GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast<int>(sizeof(float)));
|
||||
|
||||
const int src_stride_inner = ncs;
|
||||
const int src_stride_seq = ncs * d_inner;
|
||||
const int dst_stride_token = d_inner;
|
||||
const int dst_stride_seq = d_inner * n_t;
|
||||
|
||||
try {
|
||||
queue *q = ctx.stream();
|
||||
|
||||
const float *src_data = static_cast<const float *>(src0->data);
|
||||
const float *weights = static_cast<const float *>(src1->data);
|
||||
float *dst_data = static_cast<float *>(dst->data);
|
||||
|
||||
GGML_ASSERT(src_data && weights && dst_data);
|
||||
|
||||
kernel_ssm_conv(
|
||||
*q,
|
||||
src_data,
|
||||
weights,
|
||||
dst_data,
|
||||
d_conv,
|
||||
d_inner,
|
||||
n_t,
|
||||
n_s,
|
||||
ncs,
|
||||
src_stride_inner,
|
||||
src_stride_seq,
|
||||
dst_stride_token,
|
||||
dst_stride_seq
|
||||
);
|
||||
|
||||
} catch (const std::exception &e) {
|
||||
std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what());
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
@@ -96,8 +96,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
||||
|
||||
#define GGML_VK_MAX_NODES 8192
|
||||
|
||||
#define MAX_VK_BUFFERS 256
|
||||
|
||||
#define VK_CHECK(err, msg) \
|
||||
do { \
|
||||
vk::Result err_ = (err); \
|
||||
@@ -387,12 +385,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
|
||||
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
|
||||
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
||||
|
||||
static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
||||
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
|
||||
GGML_OP_RESHAPE };
|
||||
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||
|
||||
//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
|
||||
//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
|
||||
//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
|
||||
//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
|
||||
//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
|
||||
//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
|
||||
//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
|
||||
//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
|
||||
//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
|
||||
//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
|
||||
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
|
||||
{ 1, 0, 0 }, // reshape->src[0] == softmax
|
||||
{ 2, 0, 0 }, // argsort->src[0] == softmax
|
||||
{ 3, 0, 2 }, // view->src[0] == argsort
|
||||
{ 4, 0, 1 }, // get_rows->src[0] == reshape
|
||||
{ 4, 1, 3 }, // get_rows->src[1] == view
|
||||
{ 5, 0, 4 }, // reshape->src[0] == get_rows
|
||||
{ 6, 0, 5 }, // sum_rows->src[0] == reshape
|
||||
{ 7, 0, 6 }, // clamp->src[0] == sum_rows
|
||||
{ 8, 0, 5 }, // div->src[0] == reshape
|
||||
{ 8, 1, 7 }, // div->src[1] == clamp
|
||||
{ 9, 0, 8 }, // reshape->src[0] == div
|
||||
};
|
||||
|
||||
// same as early_softmax_norm but ending after the get_rows
|
||||
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
|
||||
{ 1, 0, 0 }, // reshape->src[0] == softmax
|
||||
{ 2, 0, 0 }, // argsort->src[0] == softmax
|
||||
{ 3, 0, 2 }, // view->src[0] == argsort
|
||||
{ 4, 0, 1 }, // get_rows->src[0] == reshape
|
||||
{ 4, 1, 3 }, // get_rows->src[1] == view
|
||||
};
|
||||
|
||||
//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
|
||||
//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
|
||||
//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
|
||||
//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
|
||||
//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
|
||||
//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
|
||||
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
|
||||
{ 1, 0, 0 }, // view->src[0] == argsort
|
||||
{ 2, 1, 1 }, // get_rows->src[1] == view
|
||||
{ 3, 0, 2 }, // reshape->src[0] == get_rows
|
||||
{ 4, 0, 3 }, // soft_max->src[0] == reshape
|
||||
{ 5, 0, 4 }, // reshape->src[0] == soft_max
|
||||
};
|
||||
|
||||
enum topk_moe_mode {
|
||||
TOPK_MOE_EARLY_SOFTMAX,
|
||||
TOPK_MOE_EARLY_SOFTMAX_NORM,
|
||||
TOPK_MOE_LATE_SOFTMAX,
|
||||
TOPK_MOE_COUNT,
|
||||
};
|
||||
|
||||
static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
|
||||
topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
|
||||
num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
|
||||
TOPK_MOE_LATE_SOFTMAX;
|
||||
return mode;
|
||||
}
|
||||
|
||||
struct vk_device_struct {
|
||||
std::recursive_mutex mutex;
|
||||
@@ -488,6 +550,7 @@ struct vk_device_struct {
|
||||
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
||||
|
||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
|
||||
|
||||
vk_pipeline pipeline_matmul_split_k_reduce;
|
||||
vk_pipeline pipeline_quantize_q8_1;
|
||||
@@ -525,7 +588,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_add_id_f32;
|
||||
|
||||
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
|
||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32;
|
||||
vk_pipeline pipeline_scale_f32;
|
||||
vk_pipeline pipeline_sqr_f32;
|
||||
vk_pipeline pipeline_sqrt_f32;
|
||||
@@ -606,8 +669,7 @@ struct vk_device_struct {
|
||||
|
||||
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
||||
|
||||
// [2] is {!norm, norm}
|
||||
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
|
||||
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
|
||||
|
||||
std::vector<vk_pipeline_ref> all_pipelines;
|
||||
|
||||
@@ -955,6 +1017,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
|
||||
struct vk_op_topk_moe_push_constants {
|
||||
uint32_t n_rows;
|
||||
uint32_t n_expert_used;
|
||||
float clamp_min;
|
||||
float clamp_max;
|
||||
};
|
||||
|
||||
struct vk_op_add_id_push_constants {
|
||||
@@ -1240,6 +1304,7 @@ struct vk_op_upscale_push_constants {
|
||||
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
float pixel_offset;
|
||||
};
|
||||
|
||||
struct vk_op_sum_rows_push_constants
|
||||
@@ -1311,7 +1376,6 @@ struct ggml_vk_garbage_collector {
|
||||
std::vector<vk_semaphore> tl_semaphores;
|
||||
std::vector<vk_semaphore> semaphores;
|
||||
std::vector<vk::Event> events;
|
||||
std::vector<vk_buffer> temp_buffers;
|
||||
std::vector<vk_context> contexts;
|
||||
};
|
||||
|
||||
@@ -1482,8 +1546,6 @@ struct ggml_backend_vk_context {
|
||||
// and set to true after the buffer contents are consumed.
|
||||
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
|
||||
|
||||
vk_buffer buffer_pool[MAX_VK_BUFFERS];
|
||||
|
||||
vk_context_ref compute_ctx;
|
||||
vk_context_ref transfer_ctx;
|
||||
|
||||
@@ -2452,8 +2514,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
l_warptile_id, m_warptile_id, s_warptile_id,
|
||||
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
|
||||
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
|
||||
l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
|
||||
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
|
||||
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
|
||||
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
|
||||
l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
|
||||
l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
|
||||
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
|
||||
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
|
||||
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
|
||||
@@ -2516,10 +2581,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
||||
|
||||
// Integer MMQ has a smaller shared memory profile, but heavier register use
|
||||
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
||||
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
|
||||
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
|
||||
|
||||
// K-quants use even more registers, mitigate by setting WMITER to 1
|
||||
l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
|
||||
m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
|
||||
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
|
||||
|
||||
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
|
||||
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
|
||||
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
|
||||
@@ -2528,10 +2599,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
|
||||
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
|
||||
|
||||
l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
|
||||
m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
|
||||
s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
|
||||
|
||||
l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
|
||||
m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
|
||||
s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
|
||||
|
||||
// chip specific tuning
|
||||
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
|
||||
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
||||
m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
||||
m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
||||
}
|
||||
|
||||
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
||||
@@ -2916,18 +2995,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
} \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
} \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
} \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
@@ -2966,11 +3042,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
||||
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
||||
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -3000,6 +3084,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
||||
|
||||
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
||||
|
||||
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
||||
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
||||
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
||||
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
||||
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
|
||||
@@ -3026,6 +3128,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
|
||||
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
|
||||
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#undef CREATE_MM2
|
||||
#undef CREATE_MMQ
|
||||
@@ -3090,6 +3210,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -3149,7 +3275,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
}
|
||||
// reusing CREATE_MM from the fp32 path
|
||||
if ((device->coopmat2 || device->coopmat_support)
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
&& !device->coopmat_bf16_support
|
||||
#endif
|
||||
) {
|
||||
@@ -3498,7 +3624,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
@@ -3623,8 +3748,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
|
||||
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
|
||||
|
||||
@@ -3739,8 +3869,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
|
||||
}
|
||||
|
||||
for (auto &c : compiles) {
|
||||
@@ -4733,7 +4864,14 @@ static void ggml_vk_instance_init() {
|
||||
vk::PhysicalDeviceIDProperties old_id;
|
||||
old_props.pNext = &old_id;
|
||||
devices[k].getProperties2(&old_props);
|
||||
return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
|
||||
|
||||
bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
|
||||
equals = equals || (
|
||||
old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
|
||||
std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
|
||||
);
|
||||
|
||||
return equals;
|
||||
}
|
||||
);
|
||||
if (old_device == vk_instance.device_indices.end()) {
|
||||
@@ -4771,6 +4909,7 @@ static void ggml_vk_instance_init() {
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
driver_priorities[vk::DriverId::eMesaDozen] = 100;
|
||||
|
||||
if (driver_priorities.count(old_driver.driverID)) {
|
||||
old_priority = driver_priorities[old_driver.driverID];
|
||||
@@ -4920,7 +5059,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||
|
||||
// MMQ
|
||||
if (src1_type == GGML_TYPE_Q8_1) {
|
||||
vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
|
||||
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
|
||||
|
||||
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
|
||||
return nullptr;
|
||||
@@ -5067,6 +5206,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
||||
}
|
||||
}
|
||||
|
||||
// MMQ
|
||||
if (src1_type == GGML_TYPE_Q8_1) {
|
||||
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
|
||||
|
||||
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return pipelines;
|
||||
}
|
||||
|
||||
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
|
||||
|
||||
switch (src0_type) {
|
||||
@@ -5144,71 +5294,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
||||
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
|
||||
}
|
||||
|
||||
static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
|
||||
VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
|
||||
VK_LOG_MEMORY("ggml_vk_pool_malloc");
|
||||
|
||||
int best_i = -1;
|
||||
size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
|
||||
int worst_i = -1;
|
||||
size_t worst_size = 0; //largest unused buffer seen so far
|
||||
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
|
||||
vk_buffer &b = ctx->buffer_pool[i];
|
||||
if (b != nullptr && b->size >= size && b->size < best_size) {
|
||||
best_i = i;
|
||||
best_size = b->size;
|
||||
}
|
||||
if (b != nullptr && b->size > worst_size) {
|
||||
worst_i = i;
|
||||
worst_size = b->size;
|
||||
}
|
||||
}
|
||||
if(best_i != -1) {
|
||||
//found the smallest buffer that fits our needs
|
||||
vk_buffer b = ctx->buffer_pool[best_i];
|
||||
ctx->buffer_pool[best_i].reset();
|
||||
return b;
|
||||
}
|
||||
if(worst_i != -1) {
|
||||
//no buffer that fits our needs, resize largest one to save memory
|
||||
vk_buffer& b = ctx->buffer_pool[worst_i];
|
||||
ggml_vk_destroy_buffer(b);
|
||||
}
|
||||
|
||||
return ggml_vk_create_buffer_device(ctx->device, size);
|
||||
}
|
||||
|
||||
static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
|
||||
VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
|
||||
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
|
||||
vk_buffer& b = ctx->buffer_pool[i];
|
||||
if (b == nullptr) {
|
||||
b = buffer;
|
||||
return;
|
||||
}
|
||||
}
|
||||
std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
|
||||
ggml_vk_destroy_buffer(buffer);
|
||||
}
|
||||
|
||||
// Returns an available temporary buffer that may only be used temporarily, it will be reused
|
||||
static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
|
||||
// Try to find existing temp buffer with enough capacity
|
||||
for (auto& buffer : ctx->gc.temp_buffers) {
|
||||
if (buffer->size >= size) {
|
||||
return buffer;
|
||||
}
|
||||
}
|
||||
|
||||
VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
|
||||
|
||||
// Otherwise create new buffer
|
||||
vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
|
||||
ctx->gc.temp_buffers.push_back(buf);
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
|
||||
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
|
||||
vk_buffer buf = ggml_vk_create_buffer(device, size,
|
||||
@@ -5709,14 +5794,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
|
||||
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
|
||||
// Copy device to device
|
||||
ggml_vk_ensure_sync_staging_buffer(src->device, size);
|
||||
ggml_vk_ensure_sync_staging_buffer(dst->device, size);
|
||||
|
||||
// Copy to src staging buffer
|
||||
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
|
||||
// memcpy to dst staging buffer
|
||||
memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
|
||||
// Copy to dst buffer
|
||||
ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
|
||||
ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6937,10 +7019,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
|
||||
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
||||
|
||||
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
||||
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
|
||||
|
||||
// Check for mmq first
|
||||
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
|
||||
|
||||
if (mmp == nullptr) {
|
||||
// Fall back to f16 dequant mul mat
|
||||
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
||||
quantize_y = false;
|
||||
}
|
||||
|
||||
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
||||
const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
|
||||
const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
|
||||
|
||||
if (qx_needs_dequant) {
|
||||
// Fall back to dequant + f16 mulmat
|
||||
@@ -6950,8 +7041,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
// Not implemented
|
||||
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
||||
|
||||
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
|
||||
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
||||
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
|
||||
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
|
||||
|
||||
@@ -6964,12 +7055,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
||||
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
||||
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
||||
const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
|
||||
const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
|
||||
const uint64_t ids_sz = nbi2;
|
||||
const uint64_t d_sz = sizeof(float) * d_ne;
|
||||
|
||||
vk_pipeline to_fp16_vk_0 = nullptr;
|
||||
vk_pipeline to_fp16_vk_1 = nullptr;
|
||||
vk_pipeline to_q8_1 = nullptr;
|
||||
|
||||
if (x_non_contig) {
|
||||
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
||||
@@ -6984,9 +7076,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
||||
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
||||
|
||||
if (quantize_y) {
|
||||
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
|
||||
}
|
||||
|
||||
if (dryrun) {
|
||||
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
||||
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
||||
uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
||||
if (quantize_y) {
|
||||
y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
|
||||
}
|
||||
if (
|
||||
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
|
||||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
|
||||
@@ -6995,7 +7094,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
||||
ctx->prealloc_size_x = x_sz_upd;
|
||||
}
|
||||
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
|
||||
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
|
||||
ctx->prealloc_size_y = y_sz_upd;
|
||||
}
|
||||
|
||||
@@ -7007,6 +7106,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
if (qy_needs_dequant) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
|
||||
}
|
||||
if (quantize_y) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -7043,6 +7145,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
if (qy_needs_dequant) {
|
||||
d_Y = ctx->prealloc_y;
|
||||
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
|
||||
} else if (quantize_y) {
|
||||
d_Y = ctx->prealloc_y;
|
||||
GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
|
||||
} else {
|
||||
d_Y = d_Qy;
|
||||
y_buf_offset = qy_buf_offset;
|
||||
@@ -7074,6 +7179,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
}
|
||||
if (quantize_y) {
|
||||
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
|
||||
ctx->prealloc_y_last_tensor_used != src1) {
|
||||
if (ctx->prealloc_y_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
|
||||
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t stride_batch_x = ne00*ne01;
|
||||
uint32_t stride_batch_y = ne10*ne11;
|
||||
@@ -7082,14 +7198,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
||||
}
|
||||
|
||||
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
|
||||
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
|
||||
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
|
||||
}
|
||||
|
||||
uint32_t y_sz_total = y_sz * ne12 * ne13;
|
||||
if (quantize_y) {
|
||||
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
|
||||
}
|
||||
|
||||
// compute
|
||||
ggml_vk_matmul_id(
|
||||
ctx, subctx, pipeline,
|
||||
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
|
||||
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
|
||||
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
|
||||
ne01, ne21, ne10, ne10, ne10, ne01,
|
||||
stride_batch_x, stride_batch_y, ne20*ne21,
|
||||
@@ -7855,14 +7976,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return nullptr;
|
||||
case GGML_OP_UPSCALE:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
int mode = ggml_get_op_params_i32(dst, 0);
|
||||
ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF);
|
||||
switch (mode) {
|
||||
case GGML_SCALE_MODE_NEAREST:
|
||||
return ctx->device->pipeline_upscale_nearest_f32;
|
||||
case GGML_SCALE_MODE_BILINEAR:
|
||||
return ctx->device->pipeline_upscale_bilinear_f32;
|
||||
case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
|
||||
return ctx->device->pipeline_upscale_bilinear_ac_f32;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
@@ -8028,8 +8149,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
||||
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
|
||||
return ctx->device->pipeline_topk_moe[idx][with_norm];
|
||||
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||
return ctx->device->pipeline_topk_moe[idx][mode];
|
||||
}
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
||||
@@ -8084,6 +8205,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return nullptr;
|
||||
}
|
||||
case GGML_OP_ARGSORT:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
||||
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||
return ctx->device->pipeline_topk_moe[idx][mode];
|
||||
}
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
|
||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||
return ctx->device->pipeline_argsort_f32[idx];
|
||||
@@ -9351,22 +9479,26 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
float sf0 = (float)dst->ne[0] / src0->ne[0];
|
||||
float sf1 = (float)dst->ne[1] / src0->ne[1];
|
||||
float sf2 = (float)dst->ne[2] / src0->ne[2];
|
||||
float sf3 = (float)dst->ne[3] / src0->ne[3];
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
float sf0 = (float)ne0 / ne00;
|
||||
float sf1 = (float)ne1 / ne01;
|
||||
float sf2 = (float)ne2 / ne02;
|
||||
float sf3 = (float)ne3 / ne03;
|
||||
float pixel_offset = 0.5f;
|
||||
|
||||
if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
||||
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
||||
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
||||
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
|
||||
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
|
||||
(uint32_t)ggml_nelements(dst), 0, 0,
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
|
||||
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
|
||||
sf0, sf1, sf2, sf3,
|
||||
(uint32_t)ne00, (uint32_t)ne01,
|
||||
(uint32_t)nb00 / src0_type_size, (uint32_t)nb01 / src0_type_size, (uint32_t)nb02 / src0_type_size, (uint32_t)nb03 / src0_type_size,
|
||||
(uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
|
||||
sf0, sf1, sf2, sf3, pixel_offset
|
||||
}, dryrun);
|
||||
}
|
||||
|
||||
@@ -9619,10 +9751,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
|
||||
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
|
||||
|
||||
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
|
||||
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
|
||||
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
|
||||
ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
|
||||
(mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
|
||||
cgraph->nodes[node_idx + 5];
|
||||
ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
|
||||
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
@@ -9681,9 +9815,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
GGML_ASSERT(d_ids != nullptr);
|
||||
}
|
||||
|
||||
vk_op_topk_moe_push_constants pc;
|
||||
vk_op_topk_moe_push_constants pc {};
|
||||
pc.n_rows = n_rows;
|
||||
pc.n_expert_used = n_expert_used;
|
||||
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
|
||||
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
|
||||
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
|
||||
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_expert_used <= n_experts);
|
||||
|
||||
@@ -11278,7 +11417,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define ENABLE_SYNC_LOGGING 0
|
||||
|
||||
if (need_sync) {
|
||||
#if ENABLE_SYNC_LOGGING
|
||||
std::cerr << "sync" << std::endl;
|
||||
#endif
|
||||
ctx->unsynced_nodes_written.clear();
|
||||
ctx->unsynced_nodes_read.clear();
|
||||
ggml_vk_sync_buffers(ctx, compute_ctx);
|
||||
@@ -11296,6 +11441,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
}
|
||||
}
|
||||
}
|
||||
#if ENABLE_SYNC_LOGGING
|
||||
if (!dryrun) {
|
||||
for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
|
||||
auto *n = cgraph->nodes[node_idx + i];
|
||||
std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
|
||||
if (n->op == GGML_OP_GLU) {
|
||||
std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
switch (node->op) {
|
||||
case GGML_OP_REPEAT:
|
||||
@@ -11474,7 +11631,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
|
||||
} else {
|
||||
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
|
||||
}
|
||||
|
||||
break;
|
||||
case GGML_OP_SUM:
|
||||
@@ -11789,10 +11950,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
// Clean up after graph processing is done
|
||||
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
||||
VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
|
||||
for (auto& buffer : ctx->gc.temp_buffers) {
|
||||
ggml_vk_pool_free(ctx, buffer);
|
||||
}
|
||||
ctx->gc.temp_buffers.clear();
|
||||
ctx->prealloc_y_last_pipeline_used = {};
|
||||
|
||||
ctx->unsynced_nodes_written.clear();
|
||||
@@ -11835,10 +11992,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
||||
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
|
||||
ctx->prealloc_y_last_pipeline_used = nullptr;
|
||||
|
||||
for (auto& buffer : ctx->buffer_pool) {
|
||||
ggml_vk_destroy_buffer(buffer);
|
||||
}
|
||||
|
||||
ctx->prealloc_size_x = 0;
|
||||
ctx->prealloc_size_y = 0;
|
||||
ctx->prealloc_size_split_k = 0;
|
||||
@@ -12255,31 +12408,28 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
|
||||
}
|
||||
|
||||
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
||||
int node_idx, bool with_norm) {
|
||||
int node_idx, topk_moe_mode mode) {
|
||||
|
||||
if (with_norm) {
|
||||
if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
|
||||
if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < topk_moe.size(); ++i) {
|
||||
if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const ggml_tensor * softmax;
|
||||
const ggml_tensor * weights;
|
||||
|
||||
switch (mode) {
|
||||
case TOPK_MOE_EARLY_SOFTMAX_NORM:
|
||||
softmax = cgraph->nodes[node_idx + 0];
|
||||
weights = cgraph->nodes[node_idx + 9];
|
||||
break;
|
||||
case TOPK_MOE_EARLY_SOFTMAX:
|
||||
softmax = cgraph->nodes[node_idx + 0];
|
||||
weights = cgraph->nodes[node_idx + 4];
|
||||
break;
|
||||
case TOPK_MOE_LATE_SOFTMAX:
|
||||
softmax = cgraph->nodes[node_idx + 4];
|
||||
weights = cgraph->nodes[node_idx + 5];
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
|
||||
const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
|
||||
|
||||
const float * op_params = (const float *)softmax->op_params;
|
||||
|
||||
float scale = op_params[0];
|
||||
@@ -12304,60 +12454,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that the nodes don't have any unexpected uses
|
||||
const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
||||
const ggml_tensor * view = cgraph->nodes[node_idx + 3];
|
||||
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
||||
const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
|
||||
const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
|
||||
const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
|
||||
const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
|
||||
|
||||
// softmax is used by reshape and argsort
|
||||
if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
|
||||
reshape1->src[0] != softmax ||
|
||||
argsort->src[0] != softmax) {
|
||||
return false;
|
||||
}
|
||||
// reshape is used by get_rows
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
|
||||
get_rows->src[0] != reshape1) {
|
||||
return false;
|
||||
}
|
||||
// argsort is used by view
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
|
||||
view->src[0] != argsort) {
|
||||
return false;
|
||||
}
|
||||
// view is written (via argsort), we can skip checking it
|
||||
|
||||
if (with_norm) {
|
||||
// get_rows is used by reshape
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
|
||||
reshape5->src[0] != get_rows) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// reshape is used by sum_rows and div
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
|
||||
sum_rows->src[0] != reshape5 ||
|
||||
div->src[0] != reshape5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// sum_rows is used by div
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
|
||||
div->src[1] != sum_rows) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// div/reshape are written
|
||||
if (reshape8->src[0] != div) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ctx->device->subgroup_arithmetic ||
|
||||
!ctx->device->subgroup_shuffle ||
|
||||
!ctx->device->subgroup_require_full_support ||
|
||||
@@ -12443,10 +12539,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
ctx->num_additional_fused_ops = num_adds - 1;
|
||||
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
||||
ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
|
||||
}
|
||||
}
|
||||
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
||||
@@ -12544,10 +12648,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
ctx->num_additional_fused_ops = num_adds - 1;
|
||||
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
||||
ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12679,25 +12791,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
||||
while (first_unused < graph->n_nodes) {
|
||||
std::vector<int> current_set;
|
||||
|
||||
// Avoid reordering topk_moe_norm
|
||||
if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
|
||||
bool is_topk_moe_norm = true;
|
||||
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
||||
if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
|
||||
is_topk_moe_norm = false;
|
||||
// Check for fusion patterns and avoid reordering them
|
||||
auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
|
||||
if (start + (int)pattern.size() <= graph->n_nodes) {
|
||||
bool is_pattern = true;
|
||||
for (size_t j = 0; j < pattern.size(); ++j) {
|
||||
if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
|
||||
is_pattern = false;
|
||||
}
|
||||
}
|
||||
return is_pattern;
|
||||
}
|
||||
if (is_topk_moe_norm) {
|
||||
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
||||
return false;
|
||||
};
|
||||
|
||||
auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
|
||||
if (match_pattern(pattern, first_unused)) {
|
||||
for (size_t j = 0; j < pattern.size(); ++j) {
|
||||
new_order.push_back(graph->nodes[first_unused + j]);
|
||||
used[first_unused + j] = true;
|
||||
}
|
||||
while (first_unused < graph->n_nodes && used[first_unused]) {
|
||||
first_unused++;
|
||||
}
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if (keep_pattern(topk_moe_early_softmax_norm)) {
|
||||
continue;
|
||||
}
|
||||
if (keep_pattern(topk_moe_early_softmax)) {
|
||||
continue;
|
||||
}
|
||||
if (keep_pattern(topk_moe_late_softmax)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// First, grab the next unused node.
|
||||
current_set.push_back(first_unused);
|
||||
|
||||
@@ -12715,6 +12846,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
||||
if (is_empty(graph->nodes[j])) {
|
||||
continue;
|
||||
}
|
||||
// Don't pull forward nodes from fusion patterns
|
||||
if (match_pattern(topk_moe_early_softmax_norm, j) ||
|
||||
match_pattern(topk_moe_early_softmax, j) ||
|
||||
match_pattern(topk_moe_late_softmax, j)) {
|
||||
continue;
|
||||
}
|
||||
bool ok = true;
|
||||
for (int c = first_unused; c < j; ++c) {
|
||||
if (!used[c] &&
|
||||
|
||||
@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
#if defined(DATA_A_MXFP4)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
|
||||
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
vec2 v0 = dequantize(ib, iqs, a_offset);
|
||||
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
|
||||
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
|
||||
const uint scales = data_a[a_offset + ib].scales[scalesi];
|
||||
const vec2 d = vec2(data_a[a_offset + ib].d);
|
||||
const vec2 dm = vec2(data_a[a_offset + ib].dm);
|
||||
|
||||
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
|
||||
return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
|
||||
}
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
return vec2(1, 0);
|
||||
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint is = 2 * n + b; // 0..7
|
||||
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
|
||||
|
||||
const vec2 loadd = vec2(data_a[a_offset + ib].d);
|
||||
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
|
||||
const uint8_t hm = uint8_t(1 << (iqs / 16));
|
||||
|
||||
const vec2 loadd = vec2(data_a[a_offset + ib].d);
|
||||
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
|
||||
@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
|
||||
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
|
||||
const f16vec2 d = bl.block.d;
|
||||
const f16vec2 dm = bl.block.dm;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
|
||||
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
|
||||
qs = unpack8(qs)[idx & 1];
|
||||
|
||||
const uint scales = bl.block.scales[scalesi];
|
||||
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
|
||||
float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
||||
uint32_t qs = bl.block.qs[iqs];
|
||||
qs >>= shift;
|
||||
qs &= 0xF;
|
||||
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
|
||||
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -26,7 +26,7 @@ void main() {
|
||||
const float d = e8m0_to_fp32(data_a[ib].e);
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 8; ++l) {
|
||||
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
|
||||
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
|
||||
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
|
||||
data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,8 +24,8 @@ void main() {
|
||||
const uint ql_idx = 32 * ip + il;
|
||||
const uint8_t qs = data_a[i].qs[32 * ip + il];
|
||||
|
||||
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
|
||||
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
|
||||
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
|
||||
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
|
||||
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
|
||||
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
|
||||
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
|
||||
|
||||
@@ -20,8 +20,8 @@ void main() {
|
||||
const uint is = 2 * il;
|
||||
const uint n = 4;
|
||||
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
|
||||
|
||||
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
|
||||
const uint qs_idx = 32*il + n * ir;
|
||||
|
||||
@@ -19,8 +19,8 @@ void main() {
|
||||
const uint ir = tid % 16;
|
||||
const uint is = 2 * il;
|
||||
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
|
||||
|
||||
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
|
||||
const uint qs_idx = 32*il + 2 * ir;
|
||||
|
||||
@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
|
||||
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
||||
|
||||
vec2 d = vec2(data_a[ib0 + i].d);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||
const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
|
||||
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
|
||||
}
|
||||
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
|
||||
temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
vec2 d = vec2(data_a[ib0 + i].d);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
|
||||
|
||||
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
|
||||
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
|
||||
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
|
||||
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
|
||||
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
vec2 d = vec2(data_a[ib0 + i].d);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
|
||||
|
||||
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
|
||||
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
|
||||
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
|
||||
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
|
||||
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[BN];
|
||||
uint _ne1;
|
||||
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
shared uvec4 ballots_sh[NUM_WARPS];
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif // MUL_MAT_ID_USE_SUBGROUPS
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
#include "mul_mm_id_funcs.glsl"
|
||||
#include "mul_mm_funcs.glsl"
|
||||
|
||||
void main() {
|
||||
|
||||
@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
|
||||
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
|
||||
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
|
||||
const uint scalesi = iqs / 8; // 0..15
|
||||
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
|
||||
|
||||
const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
|
||||
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
|
||||
const uint scales = data_a[ib].scales[scalesi];
|
||||
const vec2 d = vec2(data_a[ib].d);
|
||||
const vec2 dm = vec2(data_a[ib].dm);
|
||||
|
||||
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
|
||||
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
|
||||
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint is = 2 * n + b; // 0..7
|
||||
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
|
||||
|
||||
const vec2 loadd = vec2(data_a[ib].d);
|
||||
const vec2 loadd = vec2(data_a[ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const uint8_t hm = uint8_t(1 << (iqs / 16));
|
||||
|
||||
const vec2 loadd = vec2(data_a[ib].d);
|
||||
const vec2 loadd = vec2(data_a[ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = (idx & 0x07) * 2;
|
||||
|
||||
const float d = e8m0_to_fp32(data_a[ib].e);
|
||||
const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
|
||||
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
|
||||
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[BN];
|
||||
uint _ne1;
|
||||
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
shared uvec4 ballots_sh[NUM_WARPS];
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif // MUL_MAT_ID_USE_SUBGROUPS
|
||||
#endif // MUL_MAT_ID
|
||||
@@ -10,10 +10,9 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -24,7 +23,10 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||
#endif
|
||||
@@ -76,40 +78,27 @@ layout (constant_id = 10) const uint WARP = 32;
|
||||
|
||||
#define BK 32
|
||||
|
||||
#ifdef COOPMAT
|
||||
#define SHMEM_STRIDE (BK / 4 + 4)
|
||||
#else
|
||||
#define SHMEM_STRIDE (BK / 4 + 1)
|
||||
#define MMQ_SHMEM
|
||||
|
||||
#include "mul_mmq_shmem_types.glsl"
|
||||
|
||||
#ifndef BK_STEP
|
||||
#define BK_STEP 4
|
||||
#endif
|
||||
|
||||
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
|
||||
// Shared memory cache
|
||||
shared block_a_cache buf_a[BM * BK_STEP];
|
||||
shared block_b_cache buf_b[BN * BK_STEP];
|
||||
// Register cache
|
||||
block_a_cache cache_a[WMITER * TM];
|
||||
block_b_cache cache_b;
|
||||
|
||||
#ifndef COOPMAT
|
||||
#if QUANT_AUXF == 1
|
||||
shared FLOAT_TYPE buf_a_dm[BM];
|
||||
#else
|
||||
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
|
||||
#endif
|
||||
#endif
|
||||
|
||||
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
|
||||
#ifndef COOPMAT
|
||||
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
|
||||
#endif
|
||||
|
||||
#define LOAD_VEC_A (4 * QUANT_R)
|
||||
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
|
||||
#define LOAD_VEC_B 16
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
#include "mul_mm_id_funcs.glsl"
|
||||
#include "mul_mmq_funcs.glsl"
|
||||
|
||||
void main() {
|
||||
@@ -139,26 +128,12 @@ void main() {
|
||||
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const uint WSUBM = WM / WMITER;
|
||||
const uint WSUBN = WN / WNITER;
|
||||
|
||||
#ifdef COOPMAT
|
||||
const uint warp_i = gl_SubgroupID;
|
||||
|
||||
const uint tiw = gl_SubgroupInvocationID;
|
||||
|
||||
const uint cms_per_row = WM / TM;
|
||||
const uint cms_per_col = WN / TN;
|
||||
|
||||
const uint storestride = WARP / TM;
|
||||
const uint store_r = tiw % TM;
|
||||
const uint store_c = tiw / TM;
|
||||
#else
|
||||
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
||||
|
||||
const uint tiw = gl_LocalInvocationID.x % WARP;
|
||||
|
||||
const uint tiwr = tiw % (WSUBM / TM);
|
||||
const uint tiwc = tiw / (WSUBM / TM);
|
||||
#endif
|
||||
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
@@ -172,17 +147,27 @@ void main() {
|
||||
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint _ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
#else
|
||||
_ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
row_ids[_ne1] = u16vec2(ii0, ii1);
|
||||
if (_ne1 >= ic * BN) {
|
||||
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
#endif
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
@@ -209,159 +194,70 @@ void main() {
|
||||
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
||||
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
int32_t cache_a_qs[WMITER * TM * BK / 4];
|
||||
|
||||
int32_t cache_b_qs[TN * BK / 4];
|
||||
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
FLOAT_TYPE cache_a_dm[WMITER * TM];
|
||||
#else
|
||||
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
|
||||
#endif
|
||||
|
||||
FLOAT_TYPE_VEC2 cache_b_ds[TN];
|
||||
|
||||
for (uint block = start_k; block < end_k; block += BK) {
|
||||
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
|
||||
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
|
||||
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
|
||||
const uint iqs = loadr_a;
|
||||
const uint buf_ib = loadc_a + l;
|
||||
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
|
||||
const uint iqs = loadr_a;
|
||||
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
buf_a_dm[buf_ib] = get_d(ib);
|
||||
#else
|
||||
buf_a_dm[buf_ib] = get_dm(ib);
|
||||
#endif
|
||||
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
|
||||
block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
|
||||
}
|
||||
#if QUANT_R == 1
|
||||
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
|
||||
#else
|
||||
const i32vec2 vals = repack(ib, iqs);
|
||||
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
|
||||
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x7;
|
||||
#else
|
||||
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
|
||||
const uint ib_outer = ib / 4;
|
||||
const uint ib_inner = ib % 4;
|
||||
|
||||
const uint iqs = loadr_b;
|
||||
#endif
|
||||
|
||||
const uint buf_ib = loadc_b + l;
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[buf_ib];
|
||||
const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
|
||||
#else
|
||||
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
|
||||
#endif
|
||||
const uint iqs = loadr_b;
|
||||
|
||||
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
|
||||
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
|
||||
}
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a_ib += 1;
|
||||
pos_b_ib += 1;
|
||||
pos_a_ib += BK_STEP;
|
||||
pos_b_ib += BK_STEP;
|
||||
|
||||
#ifdef COOPMAT
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
const uint ib_a = warp_r * WM + cm_row * TM;
|
||||
for (uint k_step = 0; k_step < BK_STEP; k_step++) {
|
||||
// Load from shared into cache
|
||||
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
// TODO: only cache values that are actually needed
|
||||
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
|
||||
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const uint ib_b = warp_c * WN + cm_col * TN;
|
||||
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
// TODO: only cache values that are actually needed
|
||||
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
|
||||
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
|
||||
}
|
||||
|
||||
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
|
||||
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
|
||||
}
|
||||
|
||||
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
|
||||
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
|
||||
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
||||
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
|
||||
cache_b_ds[cc] = buf_b_ds[ib];
|
||||
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
||||
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint cache_a_idx = wsir * TM + cr;
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
||||
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
|
||||
cache_b_qs[cc * (BK / 4) + idx_k]);
|
||||
}
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint reg_ib = wsir * TM + cr;
|
||||
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
|
||||
|
||||
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
|
||||
block_a_to_registers(reg_ib, k_step * BM + buf_ib);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
|
||||
block_b_to_registers(ib);
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint cache_a_idx = wsir * TM + cr;
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
|
||||
sums[sums_idx] += mmq_dot_product(cache_a_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
barrier();
|
||||
}
|
||||
@@ -373,54 +269,6 @@ void main() {
|
||||
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#ifdef MUL_MAT_ID
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
|
||||
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
|
||||
|
||||
if (is_aligned && is_in_bounds) {
|
||||
// Full coopMat is within bounds and stride_d is aligned with 16B
|
||||
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
|
||||
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
} else if (is_in_bounds) {
|
||||
// Full coopMat is within bounds, but stride_d is not aligned
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
||||
// Partial coopMat is within bounds
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
#else
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
@@ -431,19 +279,21 @@ void main() {
|
||||
const uint row_i = dc_warp + cc;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
|
||||
#ifdef MUL_MAT_ID
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
if (dr_warp + cr < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
#else
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // COOPMAT
|
||||
}
|
||||
|
||||
@@ -6,41 +6,89 @@
|
||||
|
||||
// Each iqs value maps to a 32-bit integer
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
// 2-byte loads for Q4_0 blocks (18 bytes)
|
||||
// 4-byte loads for Q4_1 blocks (20 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
|
||||
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
|
||||
data_a[ib].qs[iqs * 2 + 1]);
|
||||
#ifdef DATA_A_Q4_0
|
||||
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
#else // DATA_A_Q4_1
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef DATA_A_Q4_0
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
#else // DATA_A_Q4_1
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0)
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
}
|
||||
#else // DATA_A_Q4_1
|
||||
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
const uint32_t vui = cache_a[ib_a].qs[iqs];
|
||||
const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
|
||||
const int32_t qs_b0 = cache_b.qs[iqs];
|
||||
const int32_t qs_b1 = cache_b.qs[iqs + 4];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
|
||||
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
|
||||
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
// 2-byte loads for Q5_0 blocks (22 bytes)
|
||||
// 4-byte loads for Q5_1 blocks (24 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
|
||||
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
|
||||
data_a[ib].qs[iqs * 2 + 1]);
|
||||
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
|
||||
#ifdef DATA_A_Q5_0
|
||||
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
|
||||
#else // DATA_A_Q5_1
|
||||
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
|
||||
#endif
|
||||
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
|
||||
@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
#ifdef DATA_A_Q5_0
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_1)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
|
||||
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
|
||||
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
||||
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
||||
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
#else // DATA_A_Q5_1
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
#ifdef DATA_A_Q5_0
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
|
||||
}
|
||||
#else // DATA_A_Q5_1
|
||||
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
cache_a[reg_ib].qh = buf_a[buf_ib].qh;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
const uint32_t vui = cache_a[ib_a].qs[iqs];
|
||||
const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
|
||||
const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
||||
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
||||
|
||||
const int32_t qs_b0 = cache_b.qs[iqs];
|
||||
const int32_t qs_b1 = cache_b.qs[iqs + 4];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
|
||||
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
// 2-byte loads for Q8_0 blocks (34 bytes)
|
||||
int32_t repack(uint ib, uint iqs) {
|
||||
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
|
||||
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
|
||||
data_a[ib].qs[iqs * 2 + 1]));
|
||||
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * da * dsb.x);
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
const int32_t qs_b = cache_b.qs[iqs];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, qs_b);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
// 1-byte loads for mxfp4 blocks (17 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
|
||||
data_a[ib].qs[iqs * 4 + 1],
|
||||
data_a[ib].qs[iqs * 4 + 2],
|
||||
data_a[ib].qs[iqs * 4 + 3]));
|
||||
|
||||
return i32vec2( quants & 0x0F0F0F0F,
|
||||
(quants >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * dsb.x * float(q_sum));
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
|
||||
data_a[ib].qs[iqs * 4 + 1],
|
||||
data_a[ib].qs[iqs * 4 + 2],
|
||||
data_a[ib].qs[iqs * 4 + 3]));
|
||||
|
||||
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
||||
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
||||
|
||||
buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
|
||||
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].d = buf_a[buf_ib].d;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
|
||||
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
|
||||
#if defined(DATA_A_Q2_K)
|
||||
// 4-byte loads for Q2_K blocks (84 bytes)
|
||||
int32_t repack(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
|
||||
}
|
||||
|
||||
uint8_t get_scale(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
return data_a[ib_k].scales[iqs_k / 4];
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
|
||||
|
||||
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
// Repack 4x4 quants into one int
|
||||
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
|
||||
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
|
||||
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
|
||||
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
|
||||
|
||||
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
cache_a[reg_ib].scales = buf_a[buf_ib].scales;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t sum_d = 0;
|
||||
int32_t sum_m = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
|
||||
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
|
||||
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
|
||||
|
||||
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
|
||||
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
|
||||
}
|
||||
|
||||
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q3_K)
|
||||
// 2-byte loads for Q3_K blocks (110 bytes)
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint hm_idx = iqs * QUANT_R_MMQ;
|
||||
const uint iqs_k = (ib % 8) * 8 + hm_idx;
|
||||
|
||||
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
||||
const uint hm_shift = iqs_k / 8;
|
||||
|
||||
// Repack 2x4 quants into one int
|
||||
// Add the 3rd bit instead of subtracting it to allow packing the quants
|
||||
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
|
||||
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
|
||||
|
||||
if (iqs == 0) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
|
||||
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
float result = 0.0;
|
||||
int32_t q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
// Subtract 4 from the quants to correct the 3rd bit offset
|
||||
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
|
||||
q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
|
||||
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
|
||||
|
||||
return ACC_TYPE(cache_b.ds.x * result);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
|
||||
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
|
||||
|
||||
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
|
||||
|
||||
// Repack 2x4 quants into one int
|
||||
#if defined(DATA_A_Q4_K)
|
||||
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
|
||||
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
|
||||
|
||||
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
|
||||
#else // defined(DATA_A_Q5_K)
|
||||
const uint qh_idx = iqs * QUANT_R_MMQ;
|
||||
const uint qh_shift = iqs_k / 8;
|
||||
|
||||
buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
|
||||
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
|
||||
#endif
|
||||
|
||||
|
||||
if (iqs == 0) {
|
||||
// Scale index
|
||||
const uint is = iqs_k / 8;
|
||||
u8vec2 scale_dm;
|
||||
if (is < 4) {
|
||||
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
|
||||
} else {
|
||||
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
|
||||
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
|
||||
}
|
||||
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
#if defined(DATA_A_Q4_K)
|
||||
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
|
||||
#else // defined(DATA_A_Q5_K)
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
#endif
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_outer = ib / 4;
|
||||
const uint ib_inner = ib % 4;
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
||||
}
|
||||
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
|
||||
}
|
||||
|
||||
void block_b_to_registers(const uint ib) {
|
||||
cache_b.ds = buf_b[ib].ds;
|
||||
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
|
||||
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q6_K)
|
||||
// 2-byte loads for Q6_K blocks (210 bytes)
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
|
||||
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
|
||||
|
||||
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
|
||||
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
|
||||
|
||||
if (iqs == 0) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
float result = 0.0;
|
||||
int32_t q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
|
||||
q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
|
||||
|
||||
return ACC_TYPE(cache_b.ds.x * result);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q2_K)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
const uint ib_k = ib / 8;
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
uint32_t qh;
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
uint32_t qh;
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
#define QUANT_R_MMQ 1
|
||||
// AMD likes 4, Intel likes 1 and Nvidia likes 2
|
||||
#define BK_STEP 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[32/4];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE d;
|
||||
};
|
||||
#elif defined(DATA_A_Q2_K)
|
||||
#define QUANT_R_MMQ 4
|
||||
struct block_a_cache {
|
||||
uint32_t qs[2];
|
||||
u8vec2 scales;
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[4];
|
||||
FLOAT_TYPE_VEC2 d_scales;
|
||||
};
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[4];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
#define QUANT_R_MMQ 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define QUANT_R_MMQ 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 d_scales;
|
||||
};
|
||||
#endif
|
||||
|
||||
struct block_b_cache
|
||||
{
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 ds;
|
||||
};
|
||||
@@ -1,6 +1,9 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#if USE_SUBGROUP_ADD
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
@@ -84,35 +87,47 @@ void main() {
|
||||
}
|
||||
|
||||
barrier();
|
||||
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
|
||||
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
|
||||
const uint k = (tid % (w >> 1)) +
|
||||
(D_STATE * (tid / (w >> 1))) +
|
||||
j * D_STATE * (D_STATE / (w >> 1));
|
||||
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
|
||||
stateC[k] += stateC[k + (w >> 1)];
|
||||
[[unroll]]
|
||||
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
|
||||
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
|
||||
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
|
||||
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
|
||||
stateC[k] += stateC[k + w];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
|
||||
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
|
||||
const uint idx = (tid % SUBGROUP_SIZE) +
|
||||
D_STATE * (tid / SUBGROUP_SIZE) +
|
||||
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
|
||||
const uint max_idx = SUBGROUP_SIZE - 1 +
|
||||
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
|
||||
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
|
||||
|
||||
uint lane = tid % SUBGROUP_SIZE;
|
||||
|
||||
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
|
||||
if (idx + offset < SPLIT_H * D_STATE) {
|
||||
stateC[idx] += stateC[idx + offset];
|
||||
if (idx < SPLIT_H * D_STATE ||
|
||||
max_idx < SPLIT_H * D_STATE) {
|
||||
float sc;
|
||||
#if USE_SUBGROUP_ADD
|
||||
sc = stateC[idx];
|
||||
sc = subgroupAdd(sc);
|
||||
#else
|
||||
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
|
||||
if (idx + offset < SPLIT_H * D_STATE) {
|
||||
stateC[idx] += stateC[idx + offset];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid % SUBGROUP_SIZE == 0) {
|
||||
sc = stateC[idx];
|
||||
}
|
||||
#endif
|
||||
|
||||
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
|
||||
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
|
||||
d[y_base_idx + i * stride_y + k] = stateC[idx];
|
||||
if (tid % SUBGROUP_SIZE == 0) {
|
||||
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
|
||||
d[y_base_idx + i * stride_y + k] = sc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
|
||||
{
|
||||
uint n_rows;
|
||||
uint n_expert_used;
|
||||
float clamp_min;
|
||||
float clamp_max;
|
||||
};
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
||||
layout(constant_id = 1) const uint n_experts = 512;
|
||||
layout(constant_id = 2) const bool with_norm = true;
|
||||
layout(constant_id = 3) const bool late_softmax = false;
|
||||
|
||||
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
||||
|
||||
@@ -25,6 +28,52 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
|
||||
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
|
||||
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
|
||||
|
||||
const float INFINITY = 1.0 / 0.0;
|
||||
|
||||
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
|
||||
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
|
||||
float max_val = -INFINITY;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const uint idx = lane + i * WARP_SIZE;
|
||||
const bool is_active = !use_limit || (idx < limit);
|
||||
if (is_active) {
|
||||
max_val = max(max_val, vals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
max_val = subgroupMax(max_val);
|
||||
|
||||
float sum = 0.f;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const uint idx = lane + i * WARP_SIZE;
|
||||
const bool is_active = !use_limit || (idx < limit);
|
||||
if (is_active) {
|
||||
const float val = exp(vals[i] - max_val);
|
||||
vals[i] = val;
|
||||
sum += val;
|
||||
} else {
|
||||
vals[i] = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
sum = subgroupAdd(sum);
|
||||
|
||||
const float inv_sum = 1.0f / sum;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const uint idx = lane + i * WARP_SIZE;
|
||||
const bool is_active = !use_limit || (idx < limit);
|
||||
if (is_active) {
|
||||
vals[i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
|
||||
if (row >= n_rows) {
|
||||
@@ -35,43 +84,16 @@ void main() {
|
||||
const uint weights_offset = n_expert_used * row;
|
||||
const uint ids_offset = n_experts * row;
|
||||
|
||||
float logits_r[experts_per_thread];
|
||||
|
||||
const float INFINITY = 1.0 / 0.0;
|
||||
float wt[experts_per_thread];
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const uint expert = i + gl_LocalInvocationID.x;
|
||||
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
|
||||
const uint expert = i + gl_LocalInvocationID.x;
|
||||
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
|
||||
}
|
||||
|
||||
float max_val = logits_r[0];
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
max_val = max(val, max_val);
|
||||
}
|
||||
|
||||
max_val = subgroupMax(max_val);
|
||||
|
||||
float wt[experts_per_thread];
|
||||
float tmp = 0.f;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
wt[i] = exp(val - max_val);
|
||||
tmp += wt[i];
|
||||
}
|
||||
|
||||
tmp = subgroupAdd(tmp);
|
||||
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
wt[i] = wt[i] * inv_sum;
|
||||
if (!late_softmax) {
|
||||
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
|
||||
}
|
||||
|
||||
// at this point, each thread holds a portion of softmax,
|
||||
@@ -82,6 +104,11 @@ void main() {
|
||||
|
||||
float output_weights[experts_per_thread];
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
output_weights[i] = 0.f;
|
||||
}
|
||||
|
||||
for (int k = 0; k < n_expert_used; k++) {
|
||||
float max_val = wt[0];
|
||||
uint max_expert = gl_LocalInvocationID.x;
|
||||
@@ -121,6 +148,7 @@ void main() {
|
||||
|
||||
if (with_norm) {
|
||||
wt_sum = subgroupAdd(wt_sum);
|
||||
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
|
||||
[[unroll]]
|
||||
@@ -129,6 +157,10 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
if (late_softmax) {
|
||||
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < experts_per_thread; ++i) {
|
||||
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
|
||||
|
||||
@@ -66,6 +66,7 @@ struct block_q4_0_packed16
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_q4_0
|
||||
#define A_TYPE_PACKED16 block_q4_0_packed16
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q4_1 32
|
||||
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
|
||||
#define A_TYPE block_q4_1
|
||||
#define A_TYPE_PACKED16 block_q4_1_packed16
|
||||
#define A_TYPE_PACKED32 block_q4_1_packed32
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q5_0 32
|
||||
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_q5_0
|
||||
#define A_TYPE_PACKED16 block_q5_0_packed16
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q5_1 32
|
||||
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
|
||||
#define A_TYPE block_q5_1
|
||||
#define A_TYPE_PACKED16 block_q5_1_packed16
|
||||
#define A_TYPE_PACKED32 block_q5_1_packed32
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q8_0 32
|
||||
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
|
||||
#define A_TYPE block_q8_0
|
||||
#define A_TYPE_PACKED16 block_q8_0_packed16
|
||||
#define A_TYPE_PACKED32 block_q8_0_packed32
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q8_1 32
|
||||
@@ -226,21 +231,21 @@ struct block_q2_K
|
||||
{
|
||||
uint8_t scales[QUANT_K_Q2_K/16];
|
||||
uint8_t qs[QUANT_K_Q2_K/4];
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
};
|
||||
|
||||
struct block_q2_K_packed16
|
||||
{
|
||||
uint16_t scales[QUANT_K_Q2_K/16/2];
|
||||
uint16_t qs[QUANT_K_Q2_K/4/2];
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
};
|
||||
|
||||
struct block_q2_K_packed32
|
||||
{
|
||||
uint32_t scales[QUANT_K_Q2_K/16/4];
|
||||
uint32_t qs[QUANT_K_Q2_K/4/4];
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
};
|
||||
|
||||
#if defined(DATA_A_Q2_K)
|
||||
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
|
||||
#define A_TYPE block_q2_K
|
||||
#define A_TYPE_PACKED16 block_q2_K_packed16
|
||||
#define A_TYPE_PACKED32 block_q2_K_packed32
|
||||
#define SCALES_PER_32 2
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q3_K 256
|
||||
@@ -274,27 +281,28 @@ struct block_q3_K_packed16
|
||||
#define QUANT_R 1
|
||||
#define A_TYPE block_q3_K
|
||||
#define A_TYPE_PACKED16 block_q3_K_packed16
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q4_K 256
|
||||
|
||||
struct block_q4_K
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint8_t scales[3*QUANT_K_Q4_K/64];
|
||||
uint8_t qs[QUANT_K_Q4_K/2];
|
||||
};
|
||||
|
||||
struct block_q4_K_packed16
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint16_t scales[3*QUANT_K_Q4_K/64/2];
|
||||
uint16_t qs[QUANT_K_Q4_K/2/2];
|
||||
};
|
||||
|
||||
struct block_q4_K_packed32
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint32_t scales[3*QUANT_K_Q4_K/64/4];
|
||||
uint32_t qs[QUANT_K_Q4_K/2/4];
|
||||
};
|
||||
@@ -310,13 +318,14 @@ struct block_q4_K_packed128
|
||||
#define A_TYPE block_q4_K
|
||||
#define A_TYPE_PACKED16 block_q4_K_packed16
|
||||
#define A_TYPE_PACKED32 block_q4_K_packed32
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q5_K 256
|
||||
|
||||
struct block_q5_K
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint8_t scales[12];
|
||||
uint8_t qh[QUANT_K_Q5_K/8];
|
||||
uint8_t qs[QUANT_K_Q5_K/2];
|
||||
@@ -324,12 +333,20 @@ struct block_q5_K
|
||||
|
||||
struct block_q5_K_packed16
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint16_t scales[12/2];
|
||||
uint16_t qh[QUANT_K_Q5_K/8/2];
|
||||
uint16_t qs[QUANT_K_Q5_K/2/2];
|
||||
};
|
||||
|
||||
struct block_q5_K_packed32
|
||||
{
|
||||
f16vec2 dm;
|
||||
uint32_t scales[12/4];
|
||||
uint32_t qh[QUANT_K_Q5_K/8/4];
|
||||
uint32_t qs[QUANT_K_Q5_K/2/4];
|
||||
};
|
||||
|
||||
struct block_q5_K_packed128
|
||||
{
|
||||
uvec4 q5k[11];
|
||||
@@ -340,6 +357,8 @@ struct block_q5_K_packed128
|
||||
#define QUANT_R 1
|
||||
#define A_TYPE block_q5_K
|
||||
#define A_TYPE_PACKED16 block_q5_K_packed16
|
||||
#define A_TYPE_PACKED32 block_q5_K_packed32
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q6_K 256
|
||||
@@ -356,7 +375,7 @@ struct block_q6_K_packed16
|
||||
{
|
||||
uint16_t ql[QUANT_K_Q6_K/2/2];
|
||||
uint16_t qh[QUANT_K_Q6_K/4/2];
|
||||
int8_t scales[QUANT_K_Q6_K/16];
|
||||
int16_t scales[QUANT_K_Q6_K/16/2];
|
||||
float16_t d;
|
||||
};
|
||||
|
||||
@@ -365,6 +384,7 @@ struct block_q6_K_packed16
|
||||
#define QUANT_R 1
|
||||
#define A_TYPE block_q6_K
|
||||
#define A_TYPE_PACKED16 block_q6_K_packed16
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
// IQuants
|
||||
@@ -1363,18 +1383,11 @@ struct block_mxfp4
|
||||
uint8_t qs[QUANT_K_MXFP4/2];
|
||||
};
|
||||
|
||||
//struct block_mxfp4_packed16
|
||||
//{
|
||||
// uint8_t e;
|
||||
// uint16_t qs[QUANT_K_MXFP4/2/2];
|
||||
//};
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
#define QUANT_K QUANT_K_MXFP4
|
||||
#define QUANT_R QUANT_R_MXFP4
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_mxfp4
|
||||
//#define A_TYPE_PACKED16 block_mxfp4_packed16
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
|
||||
@@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize)
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
const FLOAT_TYPE kvalues_mxfp4_const[16] = {
|
||||
FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
|
||||
FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
|
||||
const int8_t kvalues_mxfp4_const[16] = {
|
||||
int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
|
||||
int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
|
||||
};
|
||||
|
||||
shared FLOAT_TYPE kvalues_mxfp4[16];
|
||||
shared int8_t kvalues_mxfp4[16];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
|
||||
@@ -7,6 +7,7 @@ layout (push_constant) uniform parameter
|
||||
uint nb00; uint nb01; uint nb02; uint nb03;
|
||||
uint ne10; uint ne11; uint ne12; uint ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
float pixel_offset;
|
||||
} p;
|
||||
|
||||
#include "types.glsl"
|
||||
@@ -19,7 +20,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
|
||||
#define NEAREST 0
|
||||
#define BILINEAR 1
|
||||
#define ALIGN_CORNERS (1 << 8)
|
||||
|
||||
layout (constant_id = 0) const uint scale_mode = 0;
|
||||
|
||||
@@ -52,7 +52,7 @@ float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
|
||||
float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
|
||||
const ivec2 ne0 = ivec2(p.ne00, p.ne01);
|
||||
|
||||
const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
|
||||
const vec2 c = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset;
|
||||
const vec2 c0f = floor(c);
|
||||
const vec2 d = c - c0f;
|
||||
const ivec2 c0 = max(ivec2(c0f), 0);
|
||||
@@ -61,16 +61,6 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
|
||||
return fetch_bilinear(c0, c1, d, i12, i13);
|
||||
}
|
||||
|
||||
float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
|
||||
const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
|
||||
const vec2 c0f = floor(c);
|
||||
const vec2 d = c - c0f;
|
||||
const ivec2 c0 = ivec2(c0f);
|
||||
const ivec2 c1 = c0 + 1;
|
||||
|
||||
return fetch_bilinear(c0, c1, d, i12, i13);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
@@ -91,9 +81,6 @@ void main() {
|
||||
case BILINEAR:
|
||||
result = interpolate_bilinear(i10, i11, i12, i13);
|
||||
break;
|
||||
case BILINEAR | ALIGN_CORNERS:
|
||||
result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
|
||||
break;
|
||||
}
|
||||
|
||||
data_d[p.d_offset + idx] = D_TYPE(result);
|
||||
|
||||
@@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
|
||||
// Integer dot mmq performs better with f32 accumulators
|
||||
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
|
||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
#endif
|
||||
@@ -574,7 +575,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
}
|
||||
|
||||
void process_shaders() {
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
|
||||
|
||||
// matmul
|
||||
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
|
||||
@@ -916,7 +917,8 @@ void process_shaders() {
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
||||
|
||||
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
|
||||
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
|
||||
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
|
||||
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
|
||||
|
||||
|
||||
@@ -3062,6 +3062,7 @@ class VisionProjectorType:
|
||||
VOXTRAL = "voxtral"
|
||||
LFM2 = "lfm2"
|
||||
KIMIVL = "kimivl"
|
||||
LIGHTONOCR = "lightonocr"
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
||||
@@ -14,12 +14,12 @@ except ImportError:
|
||||
SentencePieceProcessor = None
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from mistral_common.tokens.tokenizers.utils import (
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
|
||||
_filter_valid_tokenizer_files,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
except ImportError:
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
{{- bos_token -}}
|
||||
{%- set system_prompt = "" -%}
|
||||
{%- set ns = namespace(system_prompt="") -%}
|
||||
{%- if messages[0]["role"] == "system" -%}
|
||||
{%- set ns.system_prompt = messages[0]["content"] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{%- if tools -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool is not string -%}
|
||||
{%- set tool = tool | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + tool -%}
|
||||
{%- if not loop.last -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
|
||||
{%- endif -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message["role"] + "\n" -}}
|
||||
{%- set content = message["content"] -%}
|
||||
{%- if content is not string -%}
|
||||
{%- set content = content | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- if message["role"] == "tool" -%}
|
||||
{%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
|
||||
{%- endif -%}
|
||||
{{- content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
@@ -1,5 +1,3 @@
|
||||
mistral-common>=1.8.3
|
||||
|
||||
-r ./requirements-convert_legacy_llama.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
|
||||
@@ -35,5 +35,6 @@ adb $adbserial shell " \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
|
||||
-t 4 --batch-size 128 -ngl 99 $@ \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--batch-size 128 -ngl 99 $@ \
|
||||
"
|
||||
|
||||
@@ -45,8 +45,9 @@ adb $adbserial shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
|
||||
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
-t 4 --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
|
||||
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
|
||||
-ngl 99 --device $device $cli_opts $@ \
|
||||
"
|
||||
|
||||
+42
-27
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
|
||||
/*.n_seq_tokens =*/ (uint32_t) 1,
|
||||
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
|
||||
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
|
||||
/*.n_pos =*/ n_pos_per_embd,
|
||||
/*.token =*/ batch.token,
|
||||
/*.embd =*/ batch.embd,
|
||||
/*.pos =*/ batch.pos,
|
||||
@@ -251,45 +252,57 @@ bool llama_batch_allocr::init(
|
||||
// consistency checks
|
||||
//
|
||||
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
||||
|
||||
if (p0 >= 0) {
|
||||
bool ok = true;
|
||||
|
||||
if (batch.token) {
|
||||
if (seq_pos_min(s) != p0 + 1) {
|
||||
ok = false;
|
||||
}
|
||||
} else {
|
||||
assert(batch.embd);
|
||||
|
||||
// for embeddings (typically used as vision input), we allow them to have repeating positions
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
|
||||
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
|
||||
ok = false;
|
||||
}
|
||||
if (n_pos_per_embd > 1) {
|
||||
// M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
||||
|
||||
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
||||
" for M-RoPE, it is required that the position satisfies: X < Y\n",
|
||||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
||||
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
||||
return false;
|
||||
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
||||
|
||||
if (p0 >= 0) {
|
||||
bool ok = true;
|
||||
|
||||
if (seq_pos_min(s) != p0 + 1) {
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
||||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
||||
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -389,6 +402,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
|
||||
/*.n_seq_tokens =*/ n_seq_tokens,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.n_seqs_unq =*/ n_seqs,
|
||||
/*.n_pos =*/ n_pos_per_embd,
|
||||
|
||||
/*.token =*/ udata->token.data(),
|
||||
/*.embd =*/ nullptr,
|
||||
@@ -710,6 +724,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
/*.n_seq_tokens =*/ n_tokens/n_seqs,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
|
||||
/*.n_pos =*/ n_pos_per_embd,
|
||||
|
||||
/*.token =*/ batch.token ? udata->token.data() : nullptr,
|
||||
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
|
||||
|
||||
+12
-1
@@ -17,6 +17,16 @@ struct llama_ubatch {
|
||||
return b_equal_seqs != 0;
|
||||
}
|
||||
|
||||
// typical for M-RoPE cases:
|
||||
// 0 - sequantial position of the tokens/embeddings in the sequence
|
||||
// 1 - y position in the image
|
||||
// 2 - x position in the image
|
||||
// 3 - other
|
||||
bool is_pos_2d() const {
|
||||
// TODO @ngxson : we may need to check for model arch when more models use >1 positions
|
||||
return n_pos >= 3;
|
||||
}
|
||||
|
||||
uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
|
||||
// otherwise address sanitizer complains
|
||||
// TODO: whole_seqs for embeddings?
|
||||
@@ -25,6 +35,7 @@ struct llama_ubatch {
|
||||
uint32_t n_seq_tokens; // tokens per sequence set
|
||||
uint32_t n_seqs; // sequence sets in the ubatch
|
||||
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
|
||||
uint32_t n_pos; // number of position inputs for each token/embedding
|
||||
|
||||
// seq_id_unq: unique sequence ids in the ubatch
|
||||
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
|
||||
@@ -33,7 +44,7 @@ struct llama_ubatch {
|
||||
// // size | idx | val
|
||||
llama_token * token; // [n_tokens] | i | id, token
|
||||
float * embd; // [n_embd, n_tokens] | i | embd
|
||||
llama_pos * pos; // [n_tokens] | i | pos
|
||||
llama_pos * pos; // [n_tokens*n_pos] | i | pos
|
||||
int32_t * n_seq_id; // [n_tokens] | i | -
|
||||
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
|
||||
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
|
||||
|
||||
@@ -268,9 +268,7 @@ llama_context::llama_context(
|
||||
if (pipeline_parallel) {
|
||||
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
|
||||
}
|
||||
}
|
||||
|
||||
if (!hparams.vocab_only) {
|
||||
llama_memory_context_ptr mctx;
|
||||
if (memory) {
|
||||
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
|
||||
@@ -343,7 +341,14 @@ llama_context::llama_context(
|
||||
{
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
if (pipeline_parallel) {
|
||||
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
|
||||
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
|
||||
gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
}
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
}
|
||||
|
||||
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
||||
|
||||
+9
-4
@@ -810,6 +810,9 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
//expand here so that we can fuse ffn gate
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
cur = ggml_mul(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_gate_par", il);
|
||||
@@ -1006,10 +1009,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
|
||||
cb(weights_sum, "ffn_moe_weights_sum", il);
|
||||
|
||||
if (arch == LLM_ARCH_BAILINGMOE2) {
|
||||
weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20);
|
||||
cb(weights_sum, "ffn_moe_weights_sum_biased", il);
|
||||
}
|
||||
// Avoid division by zero, clamp to smallest number representable by F16
|
||||
weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
|
||||
cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
|
||||
|
||||
weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
|
||||
cb(weights, "ffn_moe_weights_norm", il);
|
||||
@@ -1091,6 +1093,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
//expand here so that we can fuse ffn gate
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||
cb(experts, "ffn_moe_down", il);
|
||||
|
||||
|
||||
+55
-20
@@ -8,6 +8,7 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <stdexcept>
|
||||
@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(
|
||||
|
||||
const uint32_t n_layer_kv = hparams.n_layer_kv();
|
||||
|
||||
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
|
||||
struct ggml_backend_buft_comparator {
|
||||
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
|
||||
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
|
||||
}
|
||||
};
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
|
||||
|
||||
// create a context for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ctx_map[buft] = ctx;
|
||||
ctxs.emplace_back(ctx);
|
||||
ctx_map.emplace(buft, ctx);
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
return it->second;
|
||||
return it->second.get();
|
||||
};
|
||||
|
||||
GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
|
||||
@@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache(
|
||||
}
|
||||
|
||||
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
||||
for (auto it : ctx_map) {
|
||||
auto * buft = it.first;
|
||||
auto * ctx = it.second;
|
||||
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
for (auto & [buft, ctx] : ctx_map) {
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
|
||||
if (!buf) {
|
||||
throw std::runtime_error("failed to allocate buffer for kv cache");
|
||||
}
|
||||
@@ -179,7 +183,7 @@ llama_kv_cache::llama_kv_cache(
|
||||
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
||||
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
bufs.emplace_back(buf);
|
||||
ctxs_bufs.emplace_back(std::move(ctx), buf);
|
||||
}
|
||||
|
||||
{
|
||||
@@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) {
|
||||
}
|
||||
|
||||
if (data) {
|
||||
for (auto & buf : bufs) {
|
||||
for (auto & [_, buf] : ctxs_bufs) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
}
|
||||
}
|
||||
@@ -334,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
||||
llama_pos pos = v_cells[s0].pos_get(i);
|
||||
llama_pos shift = v_cells[s0].get_shift(i);
|
||||
|
||||
llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
|
||||
|
||||
if (shift != 0) {
|
||||
pos -= shift;
|
||||
assert(pos >= 0);
|
||||
@@ -345,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
||||
if (shift != 0) {
|
||||
v_cells[s1].pos_add(i, shift);
|
||||
}
|
||||
|
||||
v_cells[s1].ext_set(i, ext);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
||||
|
||||
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
|
||||
|
||||
auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
auto & head = v_heads[seq_to_stream[seq_id]];
|
||||
@@ -423,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
|
||||
|
||||
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
|
||||
|
||||
auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
|
||||
@@ -472,8 +482,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
||||
for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
|
||||
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
|
||||
for (const auto & [_, buf] : ctxs_bufs) {
|
||||
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -896,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
|
||||
|
||||
cells.pos_set(idx, ubatch.pos[i]);
|
||||
|
||||
if (ubatch.is_pos_2d()) {
|
||||
llama_kv_cell_ext ext {
|
||||
/*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
|
||||
/*.y =*/ ubatch.pos[i + ubatch.n_tokens],
|
||||
};
|
||||
cells.ext_set(idx, ext);
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
||||
cells.seq_add(idx, ubatch.seq_id[i][s]);
|
||||
}
|
||||
@@ -957,10 +975,14 @@ bool llama_kv_cache::get_has_shift() const {
|
||||
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
|
||||
uint32_t result = 0;
|
||||
|
||||
// pad the n_kv value so that the graph remains constant across batches and can be reused
|
||||
// note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
|
||||
const uint32_t n_pad_cur = std::max(n_pad, 256u);
|
||||
|
||||
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
|
||||
const auto & cells = v_cells[sinfo.strm[s]];
|
||||
|
||||
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
|
||||
result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
|
||||
}
|
||||
|
||||
return result;
|
||||
@@ -1239,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
||||
|
||||
const llama_pos p1 = ubatch->pos[i];
|
||||
|
||||
// for M-RoPE
|
||||
const bool is_2d = ubatch->is_pos_2d();
|
||||
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
||||
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
||||
|
||||
const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
|
||||
|
||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
||||
@@ -1258,6 +1285,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
||||
continue;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (causal_attn && is_2d && p0 == p1) {
|
||||
const auto & p0_ext = cells.ext_get(j);
|
||||
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (is_masked_swa(p0, p1)) {
|
||||
continue;
|
||||
@@ -1298,7 +1333,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
|
||||
size_t llama_kv_cache::total_size() const {
|
||||
size_t size = 0;
|
||||
|
||||
for (const auto & buf : bufs) {
|
||||
for (const auto & [_, buf] : ctxs_bufs) {
|
||||
size += ggml_backend_buffer_get_size(buf.get());
|
||||
}
|
||||
|
||||
@@ -1551,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
|
||||
io.write(&pos, sizeof(pos));
|
||||
io.write(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
// TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
|
||||
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
||||
|
||||
for (const auto & seq_id : seq_ids) {
|
||||
io.write(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
@@ -1696,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
|
||||
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
||||
apply_ubatch(sinfo, ubatch);
|
||||
|
||||
const auto head_cur = sinfo.head();
|
||||
@@ -2010,8 +2050,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
||||
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_pos_bucket(dst, ubatch);
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
|
||||
// the FA kernels require padding to avoid extra runtime boundary checks
|
||||
return cparams.flash_attn ? 256u : 32u;
|
||||
}
|
||||
|
||||
@@ -19,8 +19,6 @@ struct llama_context;
|
||||
|
||||
class llama_kv_cache : public llama_memory_i {
|
||||
public:
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
struct stream_copy_info {
|
||||
bool empty() const {
|
||||
assert(ssrc.size() == sdst.size());
|
||||
@@ -217,8 +215,8 @@ private:
|
||||
// this is the SWA type of the cache - not to be confused with the model SWA type
|
||||
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
// ggml contexts for the KV cache along with the allocated backend buffers:
|
||||
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
|
||||
+44
-2
@@ -5,9 +5,27 @@
|
||||
|
||||
#include <bitset>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
struct llama_kv_cell_ext {
|
||||
// 2D spatial positions, typically used for M-RoPE
|
||||
llama_pos x = 0;
|
||||
llama_pos y = 0;
|
||||
|
||||
// return true if the current 2D spatial position is greater than other
|
||||
bool is_2d_gt(llama_pos ox, llama_pos oy) const {
|
||||
return (y > oy) || (y == oy && x > ox);
|
||||
}
|
||||
|
||||
void reset() {
|
||||
static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
|
||||
|
||||
memset(this, 0, sizeof(*this));
|
||||
}
|
||||
};
|
||||
|
||||
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||
// TODO: add unit tests
|
||||
@@ -16,6 +34,7 @@ public:
|
||||
void reset() {
|
||||
for (uint32_t i = 0; i < pos.size(); ++i) {
|
||||
pos[i] = -1;
|
||||
ext[i].reset();
|
||||
shift[i] = 0;
|
||||
seq[i].reset();
|
||||
}
|
||||
@@ -43,6 +62,7 @@ public:
|
||||
|
||||
void resize(uint32_t n) {
|
||||
pos.resize(n);
|
||||
ext.resize(n);
|
||||
shift.resize(n);
|
||||
seq.resize(n);
|
||||
|
||||
@@ -108,6 +128,7 @@ public:
|
||||
const auto idx = i + j;
|
||||
|
||||
res.pos[j] = pos[idx];
|
||||
res.ext[j] = ext[idx];
|
||||
res.seq[j] = seq[idx];
|
||||
|
||||
assert(shift[idx] == 0);
|
||||
@@ -126,6 +147,7 @@ public:
|
||||
const auto idx = idxs[j];
|
||||
|
||||
res.pos[j] = pos[idx];
|
||||
res.ext[j] = ext[idx];
|
||||
res.seq[j] = seq[idx];
|
||||
|
||||
assert(shift[idx] == 0);
|
||||
@@ -154,6 +176,7 @@ public:
|
||||
}
|
||||
|
||||
pos[idx] = other.pos[j];
|
||||
ext[idx] = other.ext[j];
|
||||
seq[idx] = other.seq[j];
|
||||
|
||||
if (pos[idx] != -1) {
|
||||
@@ -184,6 +207,7 @@ public:
|
||||
}
|
||||
|
||||
pos[idx] = other.pos[j];
|
||||
ext[idx] = other.ext[j];
|
||||
seq[idx] = other.seq[j];
|
||||
|
||||
if (pos[idx] != -1) {
|
||||
@@ -203,6 +227,7 @@ public:
|
||||
seq[i].reset();
|
||||
|
||||
pos[i] = -1;
|
||||
ext[i].reset();
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
@@ -221,6 +246,7 @@ public:
|
||||
|
||||
if (seq[i].none()) {
|
||||
pos[i] = -1;
|
||||
ext[i].reset();
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
@@ -250,6 +276,7 @@ public:
|
||||
seq[i].reset();
|
||||
|
||||
pos[i] = -1;
|
||||
ext[i].reset();
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
@@ -340,6 +367,13 @@ public:
|
||||
return pos[i];
|
||||
}
|
||||
|
||||
const llama_kv_cell_ext & ext_get(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
return ext[i];
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
llama_pos get_shift(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
@@ -368,6 +402,11 @@ public:
|
||||
used.insert(i);
|
||||
}
|
||||
|
||||
void ext_set(uint32_t i, llama_kv_cell_ext p) {
|
||||
assert(i < ext.size());
|
||||
ext[i] = p;
|
||||
}
|
||||
|
||||
// pos[i] = pos[i] + d
|
||||
// sets "has_shift" to true
|
||||
// note: call only if the cell is not empty
|
||||
@@ -424,6 +463,9 @@ private:
|
||||
|
||||
std::vector<llama_pos> pos;
|
||||
|
||||
// stores extra info per cell
|
||||
std::vector<llama_kv_cell_ext> ext;
|
||||
|
||||
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
|
||||
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
|
||||
//
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <stdexcept>
|
||||
@@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent(
|
||||
cells.clear();
|
||||
cells.resize(mem_size);
|
||||
|
||||
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
|
||||
struct ggml_backend_buft_comparator {
|
||||
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
|
||||
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
|
||||
}
|
||||
};
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
|
||||
|
||||
// create a context for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
@@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ctx_map[buft] = ctx;
|
||||
ctxs.emplace_back(ctx);
|
||||
ctx_map.emplace(buft, ctx);
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
return it->second;
|
||||
return it->second.get();
|
||||
};
|
||||
|
||||
r_l.resize(n_layer);
|
||||
@@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent(
|
||||
}
|
||||
|
||||
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
||||
for (auto it : ctx_map) {
|
||||
auto * buft = it.first;
|
||||
auto * ctx = it.second;
|
||||
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
for (auto & [buft, ctx] : ctx_map) {
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
|
||||
if (!buf) {
|
||||
throw std::runtime_error("failed to allocate buffer for rs cache");
|
||||
}
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
||||
bufs.emplace_back(buf);
|
||||
ctxs_bufs.emplace_back(std::move(ctx), buf);
|
||||
}
|
||||
|
||||
{
|
||||
@@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) {
|
||||
used = 0;
|
||||
|
||||
if (data) {
|
||||
for (auto & buf : bufs) {
|
||||
for (auto & [_, buf] : ctxs_bufs) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
}
|
||||
}
|
||||
@@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
||||
for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
|
||||
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
|
||||
for (const auto & [_, buf] : ctxs_bufs) {
|
||||
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const {
|
||||
|
||||
size_t llama_memory_recurrent::total_size() const {
|
||||
size_t size = 0;
|
||||
for (const auto & buf : bufs) {
|
||||
for (const auto & [_, buf] : ctxs_bufs) {
|
||||
size += ggml_backend_buffer_get_size(buf.get());
|
||||
}
|
||||
|
||||
|
||||
@@ -109,8 +109,8 @@ private:
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
// ggml contexts for the KV cache along with the allocated backend buffers:
|
||||
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
|
||||
+26
-32
@@ -15,7 +15,6 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cfloat>
|
||||
#include <cstring>
|
||||
#include <cmath>
|
||||
@@ -438,7 +437,7 @@ struct llama_model::impl {
|
||||
llama_mlocks mlock_mmaps;
|
||||
|
||||
// contexts where the model tensors metadata is stored as well ass the corresponding buffers:
|
||||
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
|
||||
std::vector<std::pair<ggml_context_ptr, std::vector<ggml_backend_buffer_ptr>>> ctxs_bufs;
|
||||
|
||||
buft_list_t cpu_buft_list;
|
||||
std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
|
||||
@@ -2232,7 +2231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
|
||||
struct ggml_backend_buft_comparator {
|
||||
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
|
||||
return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
|
||||
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
|
||||
}
|
||||
};
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
|
||||
@@ -6186,7 +6185,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
|
||||
bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
|
||||
|
||||
ggml_backend_buffer_t buf = nullptr;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
|
||||
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
|
||||
// only the mmap region containing the tensors in the model is mapped to the backend buffer
|
||||
@@ -6199,15 +6198,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
continue;
|
||||
}
|
||||
const size_t max_size = ggml_get_max_tensor_size(ctx);
|
||||
buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
|
||||
ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
|
||||
if (buf == nullptr) {
|
||||
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
|
||||
}
|
||||
bufs.emplace_back(buf);
|
||||
buf_map.emplace(idx, buf);
|
||||
}
|
||||
}
|
||||
else {
|
||||
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
if (buf == nullptr) {
|
||||
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
|
||||
}
|
||||
@@ -6217,11 +6217,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
mlock_buf->init (ggml_backend_buffer_get_base(buf));
|
||||
mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
|
||||
}
|
||||
bufs.emplace_back(buf);
|
||||
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
|
||||
buf_map.emplace(idx, buf);
|
||||
}
|
||||
}
|
||||
pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf);
|
||||
pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs));
|
||||
|
||||
for (auto & buf : buf_map) {
|
||||
// indicate that this buffer contains weights
|
||||
@@ -6247,8 +6248,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
}
|
||||
|
||||
// print memory requirements per buffer type
|
||||
for (auto & [_, buf] : pimpl->ctxs_bufs) {
|
||||
LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
|
||||
for (auto & [_, bufs] : pimpl->ctxs_bufs) {
|
||||
for (auto & buf: bufs) {
|
||||
LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n",
|
||||
__func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
|
||||
}
|
||||
}
|
||||
|
||||
// populate tensors_by_name
|
||||
@@ -6300,8 +6304,10 @@ size_t llama_model::n_devices() const {
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
||||
for (const auto & [_, buf] : pimpl->ctxs_bufs) {
|
||||
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
|
||||
for (const auto & [_, bufs] : pimpl->ctxs_bufs) {
|
||||
for (const auto & buf : bufs) {
|
||||
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -6369,6 +6375,8 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
|
||||
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
|
||||
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
|
||||
LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
|
||||
LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
|
||||
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
|
||||
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
|
||||
@@ -6469,8 +6477,6 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
|
||||
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
|
||||
LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
|
||||
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
|
||||
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
|
||||
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
|
||||
@@ -17965,6 +17971,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
|
||||
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
@@ -19337,6 +19345,7 @@ struct llm_build_smallthinker : public llm_graph_context{
|
||||
|
||||
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
@@ -19632,7 +19641,7 @@ struct llm_build_apertus : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
||||
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const {
|
||||
llama_memory_i * res;
|
||||
|
||||
switch (arch) {
|
||||
@@ -19683,17 +19692,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
};
|
||||
}
|
||||
|
||||
const auto padding = llama_kv_cache::get_padding(cparams);
|
||||
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
||||
|
||||
res = new llama_memory_hybrid(
|
||||
/* model */ *this,
|
||||
/* attn_type_k */ params.type_k,
|
||||
/* attn_type_v */ params.type_v,
|
||||
/* attn_v_trans */ !cparams.flash_attn,
|
||||
/* attn_kv_size */ cparams.n_ctx,
|
||||
/* attn_n_pad */ padding,
|
||||
/* attn_n_pad */ 1,
|
||||
/* attn_n_swa */ hparams.n_swa,
|
||||
/* attn_swa_type */ hparams.swa_type,
|
||||
/* recurrent_type_k */ GGML_TYPE_F32,
|
||||
@@ -19705,23 +19710,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
/* filter_attn */ std::move(filter_attn),
|
||||
/* filter_recr */ std::move(filter_recr));
|
||||
} else {
|
||||
const auto padding = llama_kv_cache::get_padding(cparams);
|
||||
|
||||
uint32_t n_ctx_per_stream = cparams.n_ctx;
|
||||
|
||||
if (!cparams.kv_unified) {
|
||||
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
|
||||
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
|
||||
|
||||
cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
|
||||
} else {
|
||||
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
|
||||
|
||||
cparams.n_ctx = n_ctx_per_stream;
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA3N) {
|
||||
@@ -19748,7 +19742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
padding,
|
||||
1,
|
||||
nullptr,
|
||||
reuse);
|
||||
} else {
|
||||
@@ -19763,7 +19757,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
cparams.kv_unified,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
padding,
|
||||
1,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type,
|
||||
nullptr,
|
||||
|
||||
+1
-2
@@ -500,9 +500,8 @@ struct llama_model {
|
||||
|
||||
ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
|
||||
|
||||
// note: can mutate `cparams`
|
||||
// TODO: move this to new llm_arch_model_i interface
|
||||
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
|
||||
llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;
|
||||
|
||||
// TODO: move this to new llm_arch_model_i interface
|
||||
ggml_cgraph * build_graph(const llm_graph_params & params) const;
|
||||
|
||||
+211
-13
@@ -511,7 +511,7 @@ struct test_result {
|
||||
};
|
||||
|
||||
// Printer classes for different output formats
|
||||
enum class test_status_t { NOT_SUPPORTED, OK, FAIL };
|
||||
enum class test_status_t { NOT_SUPPORTED, OK, FAIL, SKIPPED };
|
||||
|
||||
struct test_operation_info {
|
||||
std::string op_name;
|
||||
@@ -687,6 +687,8 @@ struct printer {
|
||||
virtual void print_backend_status(const backend_status_info & info) { (void) info; }
|
||||
|
||||
virtual void print_overall_summary(const overall_summary_info & info) { (void) info; }
|
||||
|
||||
virtual void print_failed_tests(const std::vector<std::string> & failed_tests) { (void) failed_tests; }
|
||||
};
|
||||
|
||||
struct console_printer : public printer {
|
||||
@@ -804,6 +806,17 @@ struct console_printer : public printer {
|
||||
}
|
||||
}
|
||||
|
||||
void print_failed_tests(const std::vector<std::string> & failed_tests) override {
|
||||
if (failed_tests.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
printf("\nFailing tests:\n");
|
||||
for (const auto & test_name : failed_tests) {
|
||||
printf(" %s\n", test_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void print_test_console(const test_result & result) {
|
||||
printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
|
||||
@@ -1056,6 +1069,8 @@ struct test_case {
|
||||
|
||||
std::vector<ggml_tensor *> sentinels;
|
||||
|
||||
std::string current_op_name;
|
||||
|
||||
void add_sentinel(ggml_context * ctx) {
|
||||
if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
|
||||
return;
|
||||
@@ -1127,7 +1142,10 @@ struct test_case {
|
||||
}
|
||||
}
|
||||
|
||||
bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
|
||||
test_status_t eval(ggml_backend_t backend1,
|
||||
ggml_backend_t backend2,
|
||||
const char * op_names_filter,
|
||||
printer * output_printer) {
|
||||
mode = MODE_TEST;
|
||||
|
||||
ggml_init_params params = {
|
||||
@@ -1144,11 +1162,12 @@ struct test_case {
|
||||
add_sentinel(ctx);
|
||||
|
||||
ggml_tensor * out = build_graph(ctx);
|
||||
std::string current_op_name = op_desc(out);
|
||||
current_op_name = op_desc(out);
|
||||
|
||||
if (!matches_filter(out, op_names_filter)) {
|
||||
//printf(" %s: skipping\n", op_desc(out).c_str());
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
return test_status_t::SKIPPED;
|
||||
}
|
||||
|
||||
// check if the backends support the ops
|
||||
@@ -1172,7 +1191,7 @@ struct test_case {
|
||||
}
|
||||
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
return test_status_t::NOT_SUPPORTED;
|
||||
}
|
||||
|
||||
// post-graph sentinel
|
||||
@@ -1184,7 +1203,7 @@ struct test_case {
|
||||
if (buf == NULL) {
|
||||
printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
|
||||
ggml_free(ctx);
|
||||
return false;
|
||||
return test_status_t::FAIL;
|
||||
}
|
||||
|
||||
// build graph
|
||||
@@ -1289,7 +1308,7 @@ struct test_case {
|
||||
output_printer->print_test_result(result);
|
||||
}
|
||||
|
||||
return test_passed;
|
||||
return test_passed ? test_status_t::OK : test_status_t::FAIL;
|
||||
}
|
||||
|
||||
bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
|
||||
@@ -1306,7 +1325,7 @@ struct test_case {
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
ggml_tensor * out = build_graph(ctx.get());
|
||||
std::string current_op_name = op_desc(out);
|
||||
current_op_name = op_desc(out);
|
||||
if (!matches_filter(out, op_names_filter)) {
|
||||
//printf(" %s: skipping\n", op_desc(out).c_str());
|
||||
return true;
|
||||
@@ -1435,8 +1454,9 @@ struct test_case {
|
||||
ggml_context_ptr ctx(ggml_init(params)); // smart ptr
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
ggml_tensor * out = build_graph(ctx.get());
|
||||
std::string current_op_name = op_desc(out);
|
||||
ggml_tensor * out = build_graph(ctx.get());
|
||||
current_op_name = op_desc(out);
|
||||
|
||||
if (!matches_filter(out, op_names_filter)) {
|
||||
return true;
|
||||
}
|
||||
@@ -4712,6 +4732,7 @@ struct test_topk_moe: public test_case {
|
||||
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
|
||||
|
||||
weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
|
||||
out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
|
||||
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
@@ -4721,6 +4742,140 @@ struct test_topk_moe: public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
struct test_mul_mat_vec_fusion : public test_case {
|
||||
const ggml_type type;
|
||||
const ggml_glu_op glu_op;
|
||||
const int64_t m;
|
||||
const int64_t n;
|
||||
const int64_t k;
|
||||
const bool use_id;
|
||||
const int n_mats;
|
||||
const int n_used;
|
||||
const bool b; // broadcast b matrix (only for use_id)
|
||||
const bool with_bias;
|
||||
const bool with_gate;
|
||||
|
||||
test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
|
||||
bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
|
||||
: type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
|
||||
if (use_id) {
|
||||
GGML_ASSERT(n_used <= n_mats);
|
||||
}
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
|
||||
}
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "MUL_MAT_VEC_FUSION";
|
||||
}
|
||||
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
ggml_tensor * build_gate(ggml_context * ctx, ggml_tensor * ffn_gate, ggml_tensor * ffn_up) {
|
||||
ggml_tensor * out = nullptr;
|
||||
if (with_gate) {
|
||||
if (glu_op == GGML_GLU_OP_SWIGLU_OAI) {
|
||||
constexpr float alpha = 1.702f;
|
||||
constexpr float limit = 7.0f;
|
||||
out = ggml_swiglu_oai(ctx, ffn_gate, ffn_up, alpha, limit);
|
||||
} else {
|
||||
out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
if (!use_id) {
|
||||
std::array<int64_t, 4> ne = {k, m, 1, 1};
|
||||
std::array<int64_t, 4> ne0 = {k, n, 1, 1};
|
||||
|
||||
ggml_tensor * cur = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||
ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr;
|
||||
ggml_tensor * up = ggml_new_tensor(ctx, type, 4, ne0.data());
|
||||
|
||||
ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);
|
||||
if (with_bias) {
|
||||
std::array<int64_t, 4> bias_ne = {ffn_up->ne[0], 1, 1, 1};
|
||||
ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
|
||||
ffn_up = ggml_add(ctx, ffn_up, up_bias);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr;
|
||||
if (with_bias && with_gate) {
|
||||
std::array<int64_t, 4> bias_ne = {ffn_gate->ne[0], 1, 1, 1};
|
||||
ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
|
||||
ffn_gate = ggml_add(ctx, ffn_gate, gate_bias);
|
||||
}
|
||||
|
||||
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
} else {
|
||||
ggml_tensor * gates = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
|
||||
ggml_tensor * ups = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
|
||||
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, m);
|
||||
|
||||
if (n_used != n_mats) {
|
||||
ids = ggml_view_2d(ctx, ids, n_used, m, ids->nb[1], 0);
|
||||
}
|
||||
|
||||
ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m);
|
||||
ggml_set_name(cur, "cur");
|
||||
|
||||
ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids);
|
||||
if (with_bias) {
|
||||
ggml_tensor * up_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_up->ne[0], n_mats);
|
||||
ffn_up = ggml_add_id(ctx, ffn_up, up_bias_param, ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_gate = with_gate? ggml_mul_mat_id(ctx, gates, cur, ids) : nullptr;
|
||||
if (with_bias && with_gate) {
|
||||
ggml_tensor * gate_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_gate->ne[0], n_mats);
|
||||
ffn_gate = ggml_add_id(ctx, ffn_gate, gate_bias_param, ids);
|
||||
}
|
||||
|
||||
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
if (!use_id) {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
} else {
|
||||
std::random_device rd;
|
||||
std::default_random_engine rng(rd());
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
if (ggml_is_view_op(t->op)) { continue; }
|
||||
// ids
|
||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||
std::vector<int32_t> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = i % n_mats;
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
|
||||
}
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 5e-3;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SUM
|
||||
struct test_sum : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -6407,6 +6562,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
|
||||
add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
|
||||
add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {64, 262144, 1, 1}, {1, 1, 1, 1});
|
||||
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
|
||||
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
|
||||
}
|
||||
@@ -6562,6 +6718,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
|
||||
|
||||
// test cases with large batch size
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {1536, 1}, {1, 1}));
|
||||
}
|
||||
}
|
||||
for (ggml_type type_a : other_types) {
|
||||
@@ -6890,6 +7049,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode));
|
||||
}
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS));
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS));
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS));
|
||||
|
||||
test_cases.emplace_back(new test_sum());
|
||||
test_cases.emplace_back(new test_sum_rows());
|
||||
@@ -6982,6 +7143,33 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
|
||||
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
|
||||
|
||||
for (ggml_type type : base_types) {
|
||||
for (bool with_gate : {false, true}) {
|
||||
for (bool use_id : {false, true}) {
|
||||
for (bool b : {false, true}) {
|
||||
if (!use_id && b) {
|
||||
continue;
|
||||
}
|
||||
for (bool with_bias : {false, true}) {
|
||||
if (!with_gate && !with_bias) {
|
||||
continue;
|
||||
}
|
||||
for (ggml_glu_op glu_op : {GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU}) {
|
||||
if (!with_bias && glu_op == GGML_GLU_OP_SWIGLU_OAI) {
|
||||
continue;
|
||||
}
|
||||
if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) {
|
||||
continue;
|
||||
}
|
||||
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
|
||||
use_id, 16, 8, b, with_bias, with_gate));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (bool with_norm : {false, true}) {
|
||||
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
|
||||
@@ -7194,16 +7382,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
}
|
||||
|
||||
size_t n_ok = 0;
|
||||
size_t tests_run = 0;
|
||||
std::vector<std::string> failed_tests;
|
||||
for (auto & test : test_cases) {
|
||||
if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) {
|
||||
test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);
|
||||
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
|
||||
continue;
|
||||
}
|
||||
tests_run++;
|
||||
if (status == test_status_t::OK) {
|
||||
n_ok++;
|
||||
} else if (status == test_status_t::FAIL) {
|
||||
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
|
||||
}
|
||||
}
|
||||
output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false));
|
||||
output_printer->print_summary(test_summary_info(n_ok, tests_run, false));
|
||||
output_printer->print_failed_tests(failed_tests);
|
||||
|
||||
ggml_backend_free(backend_cpu);
|
||||
|
||||
return n_ok == test_cases.size();
|
||||
return n_ok == tests_run;
|
||||
}
|
||||
|
||||
if (mode == MODE_GRAD) {
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
@@ -2138,6 +2139,154 @@ static void test_template_output_parsers() {
|
||||
|
||||
assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get()));
|
||||
}
|
||||
{
|
||||
// LFM2 format tests
|
||||
auto tmpls = read_templates("models/templates/llama-cpp-lfm2.jinja");
|
||||
std::vector<std::string> end_tokens{ "<|im_end|>" };
|
||||
|
||||
auto inputs_tools_forced_json_schema = std::invoke([&]() -> common_chat_templates_inputs {
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = {
|
||||
std::invoke([&]() -> common_chat_msg {
|
||||
common_chat_msg msg;
|
||||
msg.role = "system";
|
||||
msg.content = "force json schema.\n";
|
||||
return msg;
|
||||
}),
|
||||
message_user,
|
||||
};
|
||||
inputs.tools = {special_function_tool};
|
||||
return inputs;
|
||||
});
|
||||
|
||||
{
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs_no_tools);
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format);
|
||||
assert_equals(false, params.grammar_lazy);
|
||||
assert_equals(std::string(R"(<|im_start|>user
|
||||
Hey there!<|im_end|>
|
||||
<|im_start|>assistant
|
||||
)"), params.prompt);
|
||||
}
|
||||
|
||||
{
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs_tools);
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format);
|
||||
assert_equals(false, params.grammar_lazy);
|
||||
assert_equals(std::string(R"(<|im_start|>system
|
||||
List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|>
|
||||
<|im_start|>user
|
||||
Hey there!<|im_end|>
|
||||
<|im_start|>assistant
|
||||
)"), params.prompt);
|
||||
assert_equals(true, params.grammar.empty());
|
||||
}
|
||||
|
||||
{
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs_tools_forced_json_schema);
|
||||
assert_equals(COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, params.format);
|
||||
assert_equals(true, params.grammar_lazy);
|
||||
assert_equals(std::string(R"(<|im_start|>system
|
||||
List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|>
|
||||
<|im_start|>user
|
||||
Hey there!<|im_end|>
|
||||
<|im_start|>assistant
|
||||
)"), params.prompt);
|
||||
assert_equals(false, params.grammar.empty());
|
||||
}
|
||||
|
||||
// Test parsing regular content
|
||||
assert_msg_equals(message_assist,
|
||||
common_chat_parse(
|
||||
"Hello, world!\nWhat's up?",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS}));
|
||||
|
||||
// Test single tool call with JSON format
|
||||
common_chat_msg msg_single_tool_call;
|
||||
msg_single_tool_call.role = "assistant";
|
||||
msg_single_tool_call.tool_calls.push_back({"special_function", "{\"arg1\":1}", ""});
|
||||
assert_msg_equals(
|
||||
msg_single_tool_call,
|
||||
common_chat_parse(
|
||||
"<|tool_call_start|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]<|tool_call_end|>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS}));
|
||||
|
||||
// Test tool call with string argument
|
||||
common_chat_msg msg_tool_call_string;
|
||||
msg_tool_call_string.role = "assistant";
|
||||
msg_tool_call_string.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""});
|
||||
assert_msg_equals(
|
||||
msg_tool_call_string,
|
||||
common_chat_parse(
|
||||
"<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS}));
|
||||
|
||||
// Test tool call with multiple arguments
|
||||
common_chat_msg msg_multi_args;
|
||||
msg_multi_args.role = "assistant";
|
||||
msg_multi_args.tool_calls.push_back({"calculate", "{\"x\":10,\"y\":20,\"operation\":\"add\"}", ""});
|
||||
assert_msg_equals(
|
||||
msg_multi_args,
|
||||
common_chat_parse(
|
||||
"<|tool_call_start|>[{\"name\": \"calculate\", \"arguments\": {\"x\": 10, \"y\": 20, \"operation\": \"add\"}}]<|tool_call_end|>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS}));
|
||||
|
||||
// Test multiple tool calls in single array
|
||||
common_chat_msg msg_multiple_tools;
|
||||
msg_multiple_tools.role = "assistant";
|
||||
msg_multiple_tools.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""});
|
||||
msg_multiple_tools.tool_calls.push_back({"get_time", "{\"timezone\":\"UTC\"}", ""});
|
||||
assert_msg_equals(
|
||||
msg_multiple_tools,
|
||||
common_chat_parse(
|
||||
"<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}, {\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]<|tool_call_end|>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS}));
|
||||
|
||||
// Test tool call with content before
|
||||
common_chat_msg msg_content_before_tool;
|
||||
msg_content_before_tool.role = "assistant";
|
||||
msg_content_before_tool.content = "Let me check the weather for you.";
|
||||
msg_content_before_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""});
|
||||
assert_msg_equals(
|
||||
msg_content_before_tool,
|
||||
common_chat_parse(
|
||||
"Let me check the weather for you.<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS}));
|
||||
|
||||
// Test tool call with content after
|
||||
common_chat_msg msg_content_after_tool;
|
||||
msg_content_after_tool.role = "assistant";
|
||||
msg_content_after_tool.content = "Here's the result.";
|
||||
msg_content_after_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""});
|
||||
assert_msg_equals(
|
||||
msg_content_after_tool,
|
||||
common_chat_parse(
|
||||
"<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>Here's the result.",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS}));
|
||||
|
||||
// Test tool call with newlines (common in LLM output)
|
||||
common_chat_msg msg_tool_call_newlines;
|
||||
msg_tool_call_newlines.role = "assistant";
|
||||
msg_tool_call_newlines.tool_calls.push_back({"get_current_time", "{\"location\":\"Paris\"}", ""});
|
||||
assert_msg_equals(
|
||||
msg_tool_call_newlines,
|
||||
common_chat_parse(
|
||||
"<|tool_call_start|>[{\n \"name\": \"get_current_time\",\n \"arguments\": {\n \"location\": \"Paris\"\n }\n}]<|tool_call_end|>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS}));
|
||||
|
||||
// Note: LFM2 uses JSON format for tool calls: [{"name": "...", "arguments": {...}}]
|
||||
// Unlike other formats, LFM2 template does not render tool calls in conversation history,
|
||||
// so we don't use test_templates() for tool call generation. Instead, the parsing tests
|
||||
// above verify edge cases and format variations for the tool call output format.
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -1124,9 +1124,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
foo ::= "{" space foo-a-kv "}" space
|
||||
foo-a-kv ::= "\"a\"" space ":" space string
|
||||
root ::= foo
|
||||
ref-definitions-foo ::= "{" space ref-definitions-foo-a-kv "}" space
|
||||
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space string
|
||||
root ::= ref-definitions-foo
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
@@ -1151,20 +1151,58 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
alternative-0 ::= foo
|
||||
alternative-1 ::= bar
|
||||
bar ::= "{" space (bar-b-kv )? "}" space
|
||||
bar-b-kv ::= "\"b\"" space ":" space number
|
||||
alternative-0 ::= ref-definitions-foo
|
||||
alternative-1 ::= ref-definitions-bar
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
foo ::= "{" space (foo-a-kv )? "}" space
|
||||
foo-a-kv ::= "\"a\"" space ":" space number
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
ref-definitions-bar ::= "{" space (ref-definitions-bar-b-kv )? "}" space
|
||||
ref-definitions-bar-b-kv ::= "\"b\"" space ":" space number
|
||||
ref-definitions-foo ::= "{" space (ref-definitions-foo-a-kv )? "}" space
|
||||
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space number
|
||||
root ::= alternative-0 | alternative-1
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"anyOf $ref",
|
||||
R"""({
|
||||
"properties": {
|
||||
"a": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "number"}
|
||||
]
|
||||
},
|
||||
"b": {
|
||||
"anyOf": [
|
||||
{"$ref": "#/properties/a/anyOf/0"},
|
||||
{"type": "boolean"}
|
||||
]
|
||||
}
|
||||
},
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
a ::= string | number
|
||||
a-kv ::= "\"a\"" space ":" space a
|
||||
a-rest ::= ( "," space b-kv )?
|
||||
b ::= b-0 | boolean
|
||||
b-0 ::= string
|
||||
b-kv ::= "\"b\"" space ":" space b
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (a-kv a-rest | b-kv )? "}" space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"mix of allOf, anyOf and $ref (similar to https://json.schemastore.org/tsconfig.json)",
|
||||
|
||||
@@ -6,3 +6,8 @@ target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "AIX")
|
||||
# AIX's flock() function comes from libbsd.a
|
||||
target_link_libraries(${TARGET} PRIVATE -lbsd)
|
||||
endif()
|
||||
|
||||
@@ -82,6 +82,9 @@ Using the `-d <n>` option, each test can be run at a specified context depth, pr
|
||||
|
||||
For a description of the other options, see the [main example](../main/README.md).
|
||||
|
||||
> [!NOTE]
|
||||
> The measurements with `llama-bench` do not include the times for tokenization and for sampling.
|
||||
|
||||
## Examples
|
||||
|
||||
### Text generation with different models
|
||||
@@ -131,7 +134,7 @@ $ ./llama-bench -n 0 -n 16 -p 64 -t 1,2,4,8,16,32
|
||||
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 16 | pp 64 | 33.52 ± 0.03 |
|
||||
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 16 | tg 16 | 15.32 ± 0.05 |
|
||||
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | pp 64 | 59.00 ± 1.11 |
|
||||
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | tg 16 | 16.41 ± 0.79 ||
|
||||
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | tg 16 | 16.41 ± 0.79 |
|
||||
|
||||
### Different numbers of layers offloaded to the GPU
|
||||
|
||||
|
||||
@@ -139,6 +139,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_VOXTRAL,
|
||||
PROJECTOR_TYPE_LFM2,
|
||||
PROJECTOR_TYPE_KIMIVL,
|
||||
PROJECTOR_TYPE_LIGHTONOCR,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -161,6 +162,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
|
||||
{ PROJECTOR_TYPE_LFM2, "lfm2"},
|
||||
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
|
||||
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
||||
+31
-7
@@ -171,7 +171,7 @@ struct clip_hparams {
|
||||
int32_t n_head;
|
||||
int32_t n_layer;
|
||||
// idefics3
|
||||
int32_t preproc_image_size = 0;
|
||||
int32_t preproc_image_size = 0; // aka max_dimension
|
||||
int32_t proj_scale_factor = 0;
|
||||
|
||||
float image_mean[3];
|
||||
@@ -621,7 +621,7 @@ struct clip_graph {
|
||||
}
|
||||
|
||||
// arrangement of the [IMG_BREAK] token
|
||||
{
|
||||
if (model.token_embd_img_break) {
|
||||
// not efficient, but works
|
||||
// the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
|
||||
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
|
||||
@@ -2095,6 +2095,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
res = graph.build_siglip();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
res = graph.build_pixtral();
|
||||
} break;
|
||||
@@ -2380,6 +2381,7 @@ struct clip_model_loader {
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
hparams.rope_theta = 10000.0f;
|
||||
hparams.warmup_image_size = hparams.patch_size * 8;
|
||||
@@ -2722,6 +2724,15 @@ struct clip_model_loader {
|
||||
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
|
||||
model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
|
||||
model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
|
||||
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
|
||||
model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
{
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
@@ -3210,8 +3221,8 @@ struct image_manipulation {
|
||||
return {0, 0};
|
||||
}
|
||||
|
||||
float scale = std::min(1.0f, std::min(static_cast<float>(max_dimension) / inp_size.width,
|
||||
static_cast<float>(max_dimension) / inp_size.height));
|
||||
float scale = std::min(static_cast<float>(max_dimension) / inp_size.width,
|
||||
static_cast<float>(max_dimension) / inp_size.height);
|
||||
|
||||
float target_width_f = static_cast<float>(inp_size.width) * scale;
|
||||
float target_height_f = static_cast<float>(inp_size.height) * scale;
|
||||
@@ -3374,7 +3385,7 @@ struct llava_uhd {
|
||||
|
||||
// resize to overview size
|
||||
clip_image_u8_ptr resized_img(clip_image_u8_init());
|
||||
image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height);
|
||||
image_manipulation::resize_and_pad_image(*img, *resized_img, inst.overview_size);
|
||||
output.push_back(std::move(resized_img));
|
||||
if (inst.slices.empty()) {
|
||||
// no slices, just return the resized image
|
||||
@@ -3576,6 +3587,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
// CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737
|
||||
const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio(
|
||||
original_size, params.image_size, params.preproc_image_size);
|
||||
// LOG_INF("%s: original size: %d x %d, refined size: %d x %d\n",
|
||||
// __func__, original_size.width, original_size.height,
|
||||
// refined_size.width, refined_size.height);
|
||||
|
||||
llava_uhd::slice_instructions instructions;
|
||||
instructions.overview_size = clip_image_size{params.image_size, params.image_size};
|
||||
@@ -3586,6 +3600,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
};
|
||||
for (int y = 0; y < refined_size.height; y += params.image_size) {
|
||||
for (int x = 0; x < refined_size.width; x += params.image_size) {
|
||||
// LOG_INF("%s: adding slice at x=%d, y=%d\n", __func__, x, y);
|
||||
instructions.slices.push_back(llava_uhd::slice_coordinates{
|
||||
/* x */x,
|
||||
/* y */y,
|
||||
@@ -3622,7 +3637,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
res_imgs->entries.push_back(std::move(img_f32));
|
||||
return true;
|
||||
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
|
||||
|| ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR
|
||||
) {
|
||||
clip_image_u8 resized_image;
|
||||
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
|
||||
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
|
||||
@@ -3865,12 +3882,17 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
n_patches = x_patch * y_patch;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
// dynamic size
|
||||
int n_merge = params.spatial_merge_size;
|
||||
int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1);
|
||||
int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1);
|
||||
n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
|
||||
if (ctx->model.token_embd_img_break) {
|
||||
n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
|
||||
} else {
|
||||
n_patches = n_patches_y * n_patches_x;
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
@@ -4247,6 +4269,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
// set the 2D positions
|
||||
int n_patches_per_col = image_size_width / patch_size;
|
||||
@@ -4377,6 +4400,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->model.mm_model_peg_0_b->ne[0];
|
||||
case PROJECTOR_TYPE_MLP:
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_MLP_NORM:
|
||||
return ctx->model.mm_3_b->ne[0];
|
||||
|
||||
+32
-15
@@ -76,9 +76,11 @@ struct mtmd_cli_context {
|
||||
|
||||
mtmd::bitmaps bitmaps;
|
||||
|
||||
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
|
||||
// so here we don't need to keep track of chat history
|
||||
// chat template
|
||||
common_chat_templates_ptr tmpls;
|
||||
std::vector<common_chat_msg> chat_history;
|
||||
bool use_jinja = false;
|
||||
// TODO: support for --system-prompt with /clear command
|
||||
|
||||
// support for legacy templates (models not having EOT token)
|
||||
llama_tokens antiprompt_tokens;
|
||||
@@ -108,6 +110,8 @@ struct mtmd_cli_context {
|
||||
}
|
||||
|
||||
tmpls = common_chat_templates_init(model, params.chat_template);
|
||||
use_jinja = params.use_jinja;
|
||||
chat_history.clear();
|
||||
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());
|
||||
|
||||
init_vision_context(params);
|
||||
@@ -193,19 +197,33 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::string generated_text = common_detokenize(ctx.lctx, generated_tokens);
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = generated_text;
|
||||
ctx.chat_history.push_back(std::move(msg));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
|
||||
common_chat_templates_inputs tmpl_inputs;
|
||||
tmpl_inputs.messages = {msg};
|
||||
tmpl_inputs.add_generation_prompt = true;
|
||||
tmpl_inputs.use_jinja = false; // jinja is buggy here
|
||||
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
|
||||
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
|
||||
static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & new_msg) {
|
||||
LOG_DBG("chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n",
|
||||
new_msg.role.c_str(), new_msg.content.c_str());
|
||||
auto formatted = common_chat_format_single(ctx.tmpls.get(), ctx.chat_history,
|
||||
new_msg, new_msg.role == "user",
|
||||
ctx.use_jinja);
|
||||
ctx.chat_history.push_back(new_msg);
|
||||
return formatted;
|
||||
}
|
||||
|
||||
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
|
||||
bool add_bos = ctx.chat_history.empty();
|
||||
auto formatted_chat = chat_add_and_format(ctx, msg);
|
||||
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str());
|
||||
|
||||
mtmd_input_text text;
|
||||
text.text = formatted_chat.prompt.c_str();
|
||||
text.text = formatted_chat.c_str();
|
||||
text.add_special = add_bos;
|
||||
text.parse_special = true;
|
||||
|
||||
@@ -303,7 +321,7 @@ int main(int argc, char ** argv) {
|
||||
return 1; // error is already printed by libmtmd
|
||||
}
|
||||
}
|
||||
if (eval_message(ctx, msg, true)) {
|
||||
if (eval_message(ctx, msg)) {
|
||||
return 1;
|
||||
}
|
||||
if (!g_is_interrupted && generate_response(ctx, n_predict)) {
|
||||
@@ -322,7 +340,6 @@ int main(int argc, char ** argv) {
|
||||
LOG("\n /quit or /exit exit the program");
|
||||
LOG("\n");
|
||||
|
||||
bool is_first_msg = true;
|
||||
std::string content;
|
||||
|
||||
while (!g_is_interrupted) {
|
||||
@@ -342,7 +359,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
if (line == "/clear") {
|
||||
ctx.n_past = 0;
|
||||
llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
|
||||
ctx.chat_history.clear();
|
||||
llama_memory_clear(llama_get_memory(ctx.lctx), true);
|
||||
LOG("Chat history cleared\n\n");
|
||||
continue;
|
||||
}
|
||||
@@ -367,7 +385,7 @@ int main(int argc, char ** argv) {
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
msg.content = content;
|
||||
int ret = eval_message(ctx, msg, is_first_msg);
|
||||
int ret = eval_message(ctx, msg);
|
||||
if (ret) {
|
||||
return 1;
|
||||
}
|
||||
@@ -376,7 +394,6 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
content.clear();
|
||||
is_first_msg = false;
|
||||
}
|
||||
}
|
||||
if (g_is_interrupted) LOG("\nInterrupted by user\n");
|
||||
|
||||
+17
-1
@@ -5,6 +5,15 @@
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <cerrno>
|
||||
#include <cstdio>
|
||||
@@ -275,6 +284,11 @@ struct mtmd_context {
|
||||
img_beg = "<img>";
|
||||
img_end = "</img>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_LIGHTONOCR) {
|
||||
// <|im_start|> ... (image embeddings) ... <|im_end|>
|
||||
img_beg = "<|im_start|>";
|
||||
img_end = "<|im_end|>";
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1026,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
|
||||
|
||||
llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
|
||||
if (image_tokens->use_mrope_pos) {
|
||||
return 1; // for M-RoPE, the whole image is 1 in temporal dimension
|
||||
// for M-RoPE, temporal dimension = max(t,h,w)
|
||||
// t is omitted as we don't support video input
|
||||
return std::max(image_tokens->nx, image_tokens->ny);
|
||||
}
|
||||
return image_tokens->n_tokens();
|
||||
}
|
||||
|
||||
+2
-2
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
|
||||
MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
|
||||
// returns nullptr for ID on text chunk
|
||||
MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk);
|
||||
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
|
||||
// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
|
||||
MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk);
|
||||
|
||||
// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
|
||||
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
|
||||
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
|
||||
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
|
||||
// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
|
||||
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
|
||||
|
||||
// tokenize an input text prompt and a list of bitmaps (images/audio)
|
||||
|
||||
+5
-1
@@ -70,6 +70,7 @@ add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
|
||||
add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
|
||||
|
||||
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
|
||||
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
|
||||
@@ -138,7 +139,10 @@ for i in "${!arr_hf[@]}"; do
|
||||
|
||||
echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log
|
||||
|
||||
if echo "$output" | grep -iq "new york"; then
|
||||
# either contains "new york" or both "men" and "walk"
|
||||
if echo "$output" | grep -iq "new york" \
|
||||
|| (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk")
|
||||
then
|
||||
result="$prefix \033[32mOK\033[0m: $bin $hf"
|
||||
else
|
||||
result="$prefix \033[31mFAIL\033[0m: $bin $hf"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user