forked from wylab/llama.cpp
Compare commits
48 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 85cc1ae998 | |||
| 1d8d83deaa | |||
| c4e9239064 | |||
| 39842a7f73 | |||
| 0fd90db585 | |||
| 4c37636b3e | |||
| 34bdbbd7c2 | |||
| 74f52f77f2 | |||
| f7207b0415 | |||
| 4d917cd4f6 | |||
| 886b97a5d6 | |||
| 111f8d06f0 | |||
| 5eff6ec9b1 | |||
| dfd9b5f6c7 | |||
| 5a6bc6b1a6 | |||
| 6b64f74b55 | |||
| 0d5a470223 | |||
| b0ba31f525 | |||
| 7da9fed0d6 | |||
| c247d06f38 | |||
| 043fb27d38 | |||
| b730706a49 | |||
| c9a24fb932 | |||
| a9c6ffcbfa | |||
| e78cf0d4b1 | |||
| 710dfc465a | |||
| 611f419cff | |||
| b1afcab804 | |||
| 9ef536907d | |||
| 21dc4ddaf2 | |||
| 289bf4113e | |||
| b55f06e1aa | |||
| 0a9b43e507 | |||
| 330c3d2d21 | |||
| e92734d51b | |||
| 45363632cb | |||
| 32732f2459 | |||
| 92f7f0a53c | |||
| b1ab91821f | |||
| 9ebebef62f | |||
| ad5c975c2d | |||
| 4afb0a746f | |||
| e288693669 | |||
| a0f98dd604 | |||
| 54a241f505 | |||
| cd36b5e5c7 | |||
| 3f196be84b | |||
| 97ae5961a4 |
@@ -2,14 +2,30 @@ ARG UBUNTU_VERSION=24.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
# Install build tools
|
||||
RUN apt update && apt install -y git build-essential cmake wget
|
||||
# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
|
||||
|
||||
# Install Vulkan SDK and cURL
|
||||
RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
|
||||
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
|
||||
apt update -y && \
|
||||
apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
|
||||
# Install build tools
|
||||
RUN apt update && apt install -y git build-essential cmake wget xz-utils
|
||||
|
||||
# Install Vulkan SDK
|
||||
ARG VULKAN_VERSION=1.4.321.1
|
||||
RUN ARCH=$(uname -m) && \
|
||||
wget -qO /tmp/vulkan-sdk.tar.xz https://sdk.lunarg.com/sdk/download/${VULKAN_VERSION}/linux/vulkan-sdk-linux-${ARCH}-${VULKAN_VERSION}.tar.xz && \
|
||||
mkdir -p /opt/vulkan && \
|
||||
tar -xf /tmp/vulkan-sdk.tar.xz -C /tmp --strip-components=1 && \
|
||||
mv /tmp/${ARCH}/* /opt/vulkan/ && \
|
||||
rm -rf /tmp/*
|
||||
|
||||
# Install cURL and Vulkan SDK dependencies
|
||||
RUN apt install -y libcurl4-openssl-dev curl \
|
||||
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev
|
||||
|
||||
# Set environment variables
|
||||
ENV VULKAN_SDK=/opt/vulkan
|
||||
ENV PATH=$VULKAN_SDK/bin:$PATH
|
||||
ENV LD_LIBRARY_PATH=$VULKAN_SDK/lib:$LD_LIBRARY_PATH
|
||||
ENV CMAKE_PREFIX_PATH=$VULKAN_SDK:$CMAKE_PREFIX_PATH
|
||||
ENV PKG_CONFIG_PATH=$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH
|
||||
|
||||
# Build it
|
||||
WORKDIR /app
|
||||
|
||||
@@ -151,6 +151,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [Bunny](https://github.com/BAAI-DCAI/Bunny)
|
||||
- [x] [GLM-EDGE](https://huggingface.co/models?search=glm-edge)
|
||||
- [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
|
||||
- [x] [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
+5
-3
@@ -1755,7 +1755,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.warmup = false;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
|
||||
add_opt(common_arg(
|
||||
{"--spm-infill"},
|
||||
string_format(
|
||||
@@ -2254,9 +2254,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
|
||||
add_opt(common_arg(
|
||||
{"-dt", "--defrag-thold"}, "N",
|
||||
string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
|
||||
string_format("KV cache defragmentation threshold (DEPRECATED)"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.defrag_thold = std::stof(value);
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(value);
|
||||
LOG_WRN("DEPRECATED: --defrag-thold is deprecated and no longer necessary to specify\n");
|
||||
}
|
||||
).set_env("LLAMA_ARG_DEFRAG_THOLD"));
|
||||
add_opt(common_arg(
|
||||
|
||||
+21
-1
@@ -1361,6 +1361,26 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
"<|end|>",
|
||||
};
|
||||
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
data.grammar_lazy = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
|
||||
auto not_end = builder.add_rule("not-end",
|
||||
"[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]");
|
||||
auto analysis = builder.add_rule("analysis",
|
||||
"\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\"");
|
||||
auto constraint = builder.add_rule("constraint", "\"<|constrain|>\"? [a-zA-Z0-9_-]+");
|
||||
auto final = builder.add_rule("final",
|
||||
"\"<|channel|>final\" ( \" \" " + constraint + " )? \"<|message|>\" " +
|
||||
builder.add_schema("response", schema)
|
||||
);
|
||||
|
||||
builder.add_rule("root", "( " + analysis + " \"<|start|>assistant\" )? " + final);
|
||||
});
|
||||
}
|
||||
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
@@ -2121,7 +2141,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
}
|
||||
|
||||
// GPT-OSS
|
||||
if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
|
||||
if (src.find("<|channel|>") != std::string::npos) {
|
||||
return common_chat_params_init_gpt_oss(tmpl, params);
|
||||
}
|
||||
|
||||
|
||||
@@ -1152,7 +1152,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
cparams.attention_type = params.attention_type;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.cb_eval = params.cb_eval;
|
||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||
cparams.offload_kqv = !params.no_kv_offload;
|
||||
|
||||
@@ -288,7 +288,6 @@ struct common_params {
|
||||
float yarn_beta_fast = 32.0f; // YaRN low correction dim
|
||||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||
float defrag_thold = 0.1f; // KV cache defragmentation threshold
|
||||
|
||||
// offload params
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
+71
-69
@@ -1216,6 +1216,55 @@ class TextModel(ModelBase):
|
||||
raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
|
||||
self.gguf_writer.add_pooling_type(pooling_type)
|
||||
|
||||
def _set_vocab_interns1(self):
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab))
|
||||
assert max(vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
else:
|
||||
token: str = reverse_vocab[i]
|
||||
if token in added_vocab:
|
||||
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
if added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
tokens.append(token)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab._set_special_token("bos", 151643)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
|
||||
class MmprojModel(ModelBase):
|
||||
model_type = ModelType.MMPROJ
|
||||
@@ -2932,7 +2981,8 @@ class Qwen2Model(TextModel):
|
||||
if "language_model." in name:
|
||||
name = name.replace("language_model.", "") # for InternVL
|
||||
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower"):
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower") \
|
||||
or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
|
||||
# skip vision and audio tensors
|
||||
return []
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
@@ -3109,7 +3159,7 @@ class LLaDAModel(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Ernie4_5_ForCausalLM")
|
||||
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
|
||||
class Ernie4_5Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.ERNIE4_5
|
||||
|
||||
@@ -3604,6 +3654,19 @@ class Qwen2MoeModel(TextModel):
|
||||
class Qwen3Model(Qwen2Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
|
||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||
|
||||
def set_vocab(self):
|
||||
# deal with intern-s1-mini
|
||||
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
|
||||
self._set_vocab_interns1()
|
||||
return
|
||||
|
||||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3MoeForCausalLM")
|
||||
class Qwen3MoeModel(Qwen2MoeModel):
|
||||
@@ -3620,73 +3683,7 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
||||
self._set_vocab_interns1()
|
||||
return
|
||||
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def _set_vocab_interns1(self):
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab))
|
||||
assert max(vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
else:
|
||||
token: str = reverse_vocab[i]
|
||||
if token in added_vocab:
|
||||
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
if added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
tokens.append(token)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
|
||||
additional_special_tokens = []
|
||||
if special_tokens_map_file.is_file():
|
||||
with open(special_tokens_map_file, encoding = 'utf-8') as f:
|
||||
additional_special_tokens = json.load(f).get('additional_special_tokens', [])
|
||||
tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
|
||||
if tokenizer_cfg_file.is_file():
|
||||
with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
|
||||
added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
|
||||
token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
|
||||
for token in additional_special_tokens:
|
||||
if token in token2ids_map:
|
||||
special_vocab._set_special_token(token, token2ids_map[token])
|
||||
special_vocab._set_special_token('eos', 151645)
|
||||
special_vocab._set_special_token("bos", 151643)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("GPT2LMHeadModel")
|
||||
@@ -5854,6 +5851,11 @@ class OlmoModel(TextModel):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("SeedOssForCausalLM")
|
||||
class SeedOssModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.SEED_OSS
|
||||
|
||||
|
||||
@ModelBase.register("Olmo2ForCausalLM")
|
||||
class Olmo2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.OLMO2
|
||||
|
||||
+4
-3
@@ -265,8 +265,9 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
| BF16 | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| Q4_0 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q4_1 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q5_0 | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| Q5_1 | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| MXFP4 | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| Q5_0 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q5_1 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q8_0 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q2_K | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| Q3_K | ✅ | ✅ | ❓ | ❓ |
|
||||
@@ -291,4 +292,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
- 🚫 - acceleration unavailable, will still run using scalar implementation
|
||||
- ❓ - acceleration unknown, please contribute if you can test it yourself
|
||||
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on July 31, 2025.
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Aug 22, 2025.
|
||||
|
||||
@@ -6,7 +6,7 @@ Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model
|
||||
|
||||
|
||||
### Build llama.cpp
|
||||
Readme modification time: 20250206
|
||||
Readme modification time: 20250731
|
||||
|
||||
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
|
||||
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
## MiniCPM-V 4.5
|
||||
|
||||
### Prepare models and code
|
||||
|
||||
Download [MiniCPM-V-4_5](https://huggingface.co/openbmb/MiniCPM-V-4_5) PyTorch model from huggingface to "MiniCPM-V-4_5" folder.
|
||||
|
||||
|
||||
### Build llama.cpp
|
||||
Readme modification time: 20250826
|
||||
|
||||
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
|
||||
|
||||
Clone llama.cpp:
|
||||
```bash
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
cd llama.cpp
|
||||
```
|
||||
|
||||
Build llama.cpp using `CMake`:
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
|
||||
### Usage of MiniCPM-V 4
|
||||
|
||||
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-4_5-gguf) by us)
|
||||
|
||||
```bash
|
||||
python ./tools/mtmd/legacy-models/minicpmv-surgery.py -m ../MiniCPM-V-4_5
|
||||
python ./tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-4_5 --minicpmv-projector ../MiniCPM-V-4_5/minicpmv.projector --output-dir ../MiniCPM-V-4_5/ --minicpmv_version 6
|
||||
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_5/model
|
||||
|
||||
# quantize int4 version
|
||||
./build/bin/llama-quantize ../MiniCPM-V-4_5/model/ggml-model-f16.gguf ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
```
|
||||
|
||||
|
||||
Inference on Linux or Mac
|
||||
```bash
|
||||
# run in single-turn mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run in conversation mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf
|
||||
```
|
||||
+1
-1
@@ -17,7 +17,7 @@
|
||||
"
|
||||
" start the llama.cpp server with a FIM-compatible model. for example:
|
||||
"
|
||||
" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256
|
||||
" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa --ubatch-size 512 --batch-size 1024 --cache-reuse 256
|
||||
"
|
||||
" --batch-size [512, model max context]
|
||||
"
|
||||
|
||||
@@ -144,6 +144,15 @@ perplexity-run:
|
||||
hf-create-model:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}"
|
||||
|
||||
hf-create-model-dry-run:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -d
|
||||
|
||||
hf-create-model-embedding:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e
|
||||
|
||||
hf-create-model-embedding-dry-run:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e -d
|
||||
|
||||
hf-create-model-private:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -p
|
||||
|
||||
|
||||
@@ -285,13 +285,21 @@ For the following targets a `HF_TOKEN` environment variable is required.
|
||||
This will create a new model repsository on Hugging Face with the specified
|
||||
model name.
|
||||
```console
|
||||
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev"
|
||||
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
|
||||
Repository ID: danbev/TestModel-GGUF
|
||||
Repository created: https://huggingface.co/danbev/TestModel-GGUF
|
||||
```
|
||||
Note that we append a `-GGUF` suffix to the model name to ensure a consistent
|
||||
naming convention for GGUF models.
|
||||
|
||||
An embedding model can be created using the following command:
|
||||
```console
|
||||
(venv) $ make hf-create-model-embedding MODEL_NAME='TestEmbeddingModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
|
||||
```
|
||||
The only difference is that the model card for an embedding model will be different
|
||||
with regards to the llama-server command and also how to access/call the embedding
|
||||
endpoint.
|
||||
|
||||
### Upload a GGUF model to model repository
|
||||
The following target uploads a model to an existing Hugging Face model repository.
|
||||
```console
|
||||
|
||||
@@ -112,6 +112,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_params.no_perf = false;
|
||||
if (embedding_mode) {
|
||||
ctx_params.embeddings = true;
|
||||
ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
ctx_params.n_ubatch = ctx_params.n_batch;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
---
|
||||
base_model:
|
||||
- {base_model}
|
||||
---
|
||||
# {model_name} GGUF
|
||||
|
||||
Recommended way to run this model:
|
||||
|
||||
```sh
|
||||
llama-server -hf {namespace}/{model_name}-GGUF
|
||||
```
|
||||
|
||||
Then the endpoint can be accessed at http://localhost:8080/embedding, for
|
||||
example using `curl`:
|
||||
```console
|
||||
curl --request POST \
|
||||
--url http://localhost:8080/embedding \
|
||||
--header "Content-Type: application/json" \
|
||||
--data '{{"input": "Hello embeddings"}}' \
|
||||
--silent
|
||||
```
|
||||
|
||||
Alternatively, the `llama-embedding` command line tool can be used:
|
||||
```sh
|
||||
llama-embedding -hf {namespace}/{model_name}-GGUF --verbose-prompt -p "Hello embeddings"
|
||||
```
|
||||
|
||||
#### embd_normalize
|
||||
When a model uses pooling, or the pooling method is specified using `--pooling`,
|
||||
the normalization can be controlled by the `embd_normalize` parameter.
|
||||
|
||||
The default value is `2` which means that the embeddings are normalized using
|
||||
the Euclidean norm (L2). Other options are:
|
||||
* -1 No normalization
|
||||
* 0 Max absolute
|
||||
* 1 Taxicab
|
||||
* 2 Euclidean/L2
|
||||
* \>2 P-Norm
|
||||
|
||||
This can be passed in the request body to `llama-server`, for example:
|
||||
```sh
|
||||
--data '{{"input": "Hello embeddings", "embd_normalize": -1}}' \
|
||||
```
|
||||
|
||||
And for `llama-embedding`, by passing `--embd-normalize <value>`, for example:
|
||||
```sh
|
||||
llama-embedding -hf {namespace}/{model_name}-GGUF --embd-normalize -1 -p "Hello embeddings"
|
||||
```
|
||||
@@ -26,21 +26,31 @@ parser.add_argument('--namespace', '-ns', help='Namespace to add the model to',
|
||||
parser.add_argument('--org-base-model', '-b', help='Original Base model name', default="")
|
||||
parser.add_argument('--no-card', action='store_true', help='Skip creating model card')
|
||||
parser.add_argument('--private', '-p', action='store_true', help='Create private model')
|
||||
parser.add_argument('--embedding', '-e', action='store_true', help='Use embedding model card template')
|
||||
parser.add_argument('--dry-run', '-d', action='store_true', help='Print repository info and template without creating repository')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
repo_id = f"{args.namespace}/{args.model_name}-GGUF"
|
||||
print("Repository ID: ", repo_id)
|
||||
|
||||
repo_url = api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
private=args.private,
|
||||
exist_ok=False
|
||||
)
|
||||
repo_url = None
|
||||
if not args.dry_run:
|
||||
repo_url = api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
private=args.private,
|
||||
exist_ok=False
|
||||
)
|
||||
|
||||
if not args.no_card:
|
||||
template_path = "scripts/readme.md.template"
|
||||
if args.embedding:
|
||||
template_path = "scripts/embedding/modelcard.template"
|
||||
else:
|
||||
template_path = "scripts/causal/modelcard.template"
|
||||
|
||||
print("Template path: ", template_path)
|
||||
|
||||
model_card_content = load_template_and_substitute(
|
||||
template_path,
|
||||
model_name=args.model_name,
|
||||
@@ -48,16 +58,21 @@ if not args.no_card:
|
||||
base_model=args.org_base_model,
|
||||
)
|
||||
|
||||
if model_card_content:
|
||||
api.upload_file(
|
||||
path_or_fileobj=model_card_content.encode('utf-8'),
|
||||
path_in_repo="README.md",
|
||||
repo_id=repo_id
|
||||
)
|
||||
print("Model card created successfully.")
|
||||
if args.dry_run:
|
||||
print("\nTemplate Content:\n")
|
||||
print(model_card_content)
|
||||
else:
|
||||
print("Failed to create model card.")
|
||||
if model_card_content:
|
||||
api.upload_file(
|
||||
path_or_fileobj=model_card_content.encode('utf-8'),
|
||||
path_in_repo="README.md",
|
||||
repo_id=repo_id
|
||||
)
|
||||
print("Model card created successfully.")
|
||||
else:
|
||||
print("Failed to create model card.")
|
||||
|
||||
print(f"Repository created: {repo_url}")
|
||||
if not args.dry_run and repo_url:
|
||||
print(f"Repository created: {repo_url}")
|
||||
|
||||
|
||||
|
||||
@@ -512,6 +512,7 @@ extern "C" {
|
||||
GGML_OP_IM2COL,
|
||||
GGML_OP_IM2COL_BACK,
|
||||
GGML_OP_CONV_2D,
|
||||
GGML_OP_CONV_3D,
|
||||
GGML_OP_CONV_2D_DW,
|
||||
GGML_OP_CONV_TRANSPOSE_2D,
|
||||
GGML_OP_POOL_1D,
|
||||
@@ -1940,6 +1941,23 @@ extern "C" {
|
||||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
|
||||
struct ggml_tensor * b, // input [W, H, D, C * N]
|
||||
int s0, // stride
|
||||
int s1,
|
||||
int s2,
|
||||
int p0, // padding
|
||||
int p1,
|
||||
int p2,
|
||||
int d0, // dilation
|
||||
int d1,
|
||||
int d2,
|
||||
int n_channels,
|
||||
int n_batch,
|
||||
int n_channels_out);
|
||||
|
||||
enum ggml_op_pool {
|
||||
GGML_OP_POOL_MAX,
|
||||
GGML_OP_POOL_AVG,
|
||||
|
||||
@@ -1355,15 +1355,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
std::vector<int32_t> ids;
|
||||
std::vector<ggml_bitset_t> used_ids;
|
||||
|
||||
for (int i = 0; i < sched->n_splits; i++) {
|
||||
struct ggml_backend_sched_split * split = &splits[i];
|
||||
for (int split_id = 0; split_id < sched->n_splits; split_id++) {
|
||||
struct ggml_backend_sched_split * split = &splits[split_id];
|
||||
int split_backend_id = split->backend_id;
|
||||
ggml_backend_t split_backend = sched->backends[split_backend_id];
|
||||
|
||||
// copy the input tensors to the split backend
|
||||
for (int j = 0; j < split->n_inputs; j++) {
|
||||
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
|
||||
struct ggml_tensor * input = split->inputs[j];
|
||||
for (int input_id = 0; input_id < split->n_inputs; input_id++) {
|
||||
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
|
||||
struct ggml_tensor * input = split->inputs[input_id];
|
||||
struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
|
||||
|
||||
if (input->flags & GGML_TENSOR_FLAG_INPUT) {
|
||||
@@ -1398,10 +1398,22 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
|
||||
// get the ids
|
||||
ggml_tensor * ids_tensor = node->src[2];
|
||||
ggml_backend_t ids_backend = split_backend;
|
||||
|
||||
// if the ids tensor is also an input of the split, it may not have been copied yet to the split backend
|
||||
// in that case, we use the original ids tensor
|
||||
for (int i = input_id + 1; i < split->n_inputs; i++) {
|
||||
if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) {
|
||||
ids_tensor = split->inputs[i];
|
||||
ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (ids_tensor != prev_ids_tensor) {
|
||||
ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
|
||||
ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
|
||||
ggml_backend_synchronize(split_backend);
|
||||
ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
|
||||
ggml_backend_synchronize(ids_backend);
|
||||
|
||||
// find the used experts
|
||||
used_ids.clear();
|
||||
@@ -1409,6 +1421,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {
|
||||
for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {
|
||||
int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];
|
||||
GGML_ASSERT(id >= 0 && id < n_expert);
|
||||
ggml_bitset_set(used_ids.data(), id);
|
||||
}
|
||||
}
|
||||
|
||||
+234
-114
@@ -867,6 +867,86 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
|
||||
return acl_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fills a tensor with a scalar value.
|
||||
*
|
||||
* This function fills the destination tensor `acl_dst` with the scalar value
|
||||
* `scalar`.
|
||||
*
|
||||
* @param ctx The context for the CANN backend operations.
|
||||
* @param scalar The scalar value used to fill the tensor.
|
||||
* @param acl_dst The destination tensor to be filled with the scalar value.
|
||||
*/
|
||||
static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
||||
aclTensor* acl_dst) {
|
||||
auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
|
||||
ggml_cann_release_resources(ctx, acl_scalar);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get or expand a cached float32 tensor filled with a scalar value.
|
||||
*
|
||||
* This function manages cached device memory for float32 tensors. If the current
|
||||
* cache size is insufficient for the requested tensor shape, the old memory will
|
||||
* be released and new memory will be allocated. The allocated buffer is then
|
||||
* initialized either with zeros (when @p value == 0.0f) or with the given scalar
|
||||
* value using CANN operations. Finally, an aclTensor object is created from the
|
||||
* cached memory and returned.
|
||||
*
|
||||
* @param ctx The CANN backend context that manages device memory.
|
||||
* @param buffer A pointer to the cached device buffer (will be allocated
|
||||
* or reallocated if necessary).
|
||||
* @param cache_element The current number of cached elements. This will be
|
||||
* updated when the cache is expanded.
|
||||
* @param ne The tensor shape array (number of elements in each dimension).
|
||||
* @param nb The stride size for each dimension.
|
||||
* @param dims The number of tensor dimensions.
|
||||
* @param value The scalar value used to fill the tensor (supports zero
|
||||
* initialization via memset or arbitrary values via fill_scalar).
|
||||
* @return An aclTensor pointer created from the cached buffer.
|
||||
*/
|
||||
static aclTensor* get_f32_cache_acl_tensor(
|
||||
ggml_backend_cann_context& ctx,
|
||||
void** buffer,
|
||||
int64_t &cache_element,
|
||||
int64_t* ne,
|
||||
size_t* nb,
|
||||
int64_t dims,
|
||||
float value) {
|
||||
// Calculate total number of elements
|
||||
int64_t n_element = 1;
|
||||
for (int i = 0; i < dims; i++) {
|
||||
n_element *= ne[i];
|
||||
}
|
||||
size_t size = n_element * sizeof(float);
|
||||
|
||||
// Allocate or expand cache if needed
|
||||
if (cache_element < n_element) {
|
||||
if (*buffer != nullptr) {
|
||||
aclrtFree(*buffer);
|
||||
*buffer = nullptr;
|
||||
}
|
||||
|
||||
ACL_CHECK(aclrtMalloc(buffer, size, ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
cache_element = n_element;
|
||||
|
||||
// Initialize cache
|
||||
if (value == 0.0f) {
|
||||
ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
|
||||
} else {
|
||||
int64_t pool_ne[1] = { n_element };
|
||||
size_t pool_nb[1] = { sizeof(float) };
|
||||
aclTensor* acl_value = ggml_cann_create_tensor(
|
||||
*buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
|
||||
aclnn_fill_scalar(ctx, 1, acl_value);
|
||||
ggml_cann_release_resources(ctx, acl_value);
|
||||
}
|
||||
}
|
||||
|
||||
return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
|
||||
}
|
||||
|
||||
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src = dst->src[0];
|
||||
|
||||
@@ -875,20 +955,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
|
||||
ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
|
||||
|
||||
aclTensor* acl_gamma = aclnn_values(
|
||||
ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
|
||||
ggml_cann_type_mapping(src->type), ggml_element_size(src));
|
||||
// build gamma, one...
|
||||
size_t acl_gamma_nb[GGML_MAX_DIMS];
|
||||
acl_gamma_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.f32_one_cache,
|
||||
ctx.f32_one_cache_element,
|
||||
src->ne,
|
||||
acl_gamma_nb,
|
||||
1, // dims
|
||||
1.0f // value
|
||||
);
|
||||
|
||||
// build rstd, zero...
|
||||
size_t acl_rstd_nb[GGML_MAX_DIMS];
|
||||
acl_rstd_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.f32_zero_cache,
|
||||
ctx.f32_zero_cache_element,
|
||||
src->ne,
|
||||
acl_rstd_nb,
|
||||
GGML_MAX_DIMS,
|
||||
0.0f // value
|
||||
);
|
||||
|
||||
size_t zero_tensor_n_bytes =
|
||||
src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src);
|
||||
ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes);
|
||||
aclTensor* acl_rstd =
|
||||
aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
|
||||
src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
|
||||
ggml_element_size(src));
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
|
||||
}
|
||||
@@ -903,14 +1002,13 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
|
||||
const int n_past = ((int32_t*)dst->op_params)[0];
|
||||
|
||||
size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] *
|
||||
src->ne[3] * ggml_element_size(src);
|
||||
ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
|
||||
ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), ggml_nbytes(src));
|
||||
void* buffer = one_tensor_allocator.get();
|
||||
|
||||
aclTensor* mask_tensor =
|
||||
aclnn_values(ctx, one_tensor_allocator.get(), one_tensor_n_bytes,
|
||||
src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
|
||||
ggml_element_size(src), value);
|
||||
aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type),
|
||||
ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS);
|
||||
|
||||
aclnn_fill_scalar(ctx, value, mask_tensor);
|
||||
|
||||
aclScalar* alpha = nullptr;
|
||||
float alphaValue = 1.0f;
|
||||
@@ -1159,12 +1257,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
|
||||
|
||||
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
|
||||
if(acl_dst == nullptr) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
|
||||
} else {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
|
||||
}
|
||||
}
|
||||
|
||||
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
|
||||
if(acl_dst == nullptr) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
|
||||
} else {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
|
||||
@@ -1277,23 +1383,6 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
|
||||
tmp_permute_tensor, tmp_mul_tensor, acl_dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fills a tensor with a scalar value.
|
||||
*
|
||||
* This function fills the destination tensor `acl_dst` with the scalar value
|
||||
* `scalar`.
|
||||
*
|
||||
* @param ctx The context for the CANN backend operations.
|
||||
* @param scalar The scalar value used to fill the tensor.
|
||||
* @param acl_dst The destination tensor to be filled with the scalar value.
|
||||
*/
|
||||
static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
||||
aclTensor* acl_dst) {
|
||||
auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
|
||||
ggml_cann_release_resources(ctx, acl_scalar);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Raises each element of a tensor to the power of the corresponding
|
||||
* element in another tensor.
|
||||
@@ -2140,13 +2229,54 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
|
||||
ggml_cann_release_resources(ctx, acl_index, acl_value);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Initializes and caches sine/cosine positional encoding values
|
||||
* (used in RoPE, Rotary Position Embedding) for attention layers.
|
||||
*
|
||||
* This function computes and caches the sin/cos values of
|
||||
* θ = position * theta_scale for RoPE encoding. The cache is shared
|
||||
* across attention layers, and only the first attention layer will
|
||||
* trigger initialization. The cache includes repeated sin/cos values
|
||||
* with different repeat methods depending on the @param is_neox flag.
|
||||
*
|
||||
* Steps performed by this function:
|
||||
* 1. Identify whether the target tensor belongs to Q/K in attention
|
||||
* and restrict computation to the first layer only.
|
||||
* 2. Initialize the theta scale array (arange → power → freq scaling).
|
||||
* 3. Allocate sin/cos caches if the max prompt length increases.
|
||||
* 4. Compute θ = position * theta_scale.
|
||||
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
|
||||
* 6. Expand sin/cos values by repeat or repeat_interleave depending
|
||||
* on whether @param is_neox is enabled.
|
||||
* 7. Store the computed values into persistent buffers
|
||||
* (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
|
||||
*
|
||||
* @param ctx The CANN backend context, holding memory pool,
|
||||
* stream, and persistent buffers for rope init/cache.
|
||||
* @param dst The destination ggml_tensor whose computation
|
||||
* depends on the cached RoPE values (usually Qcur/Kcur).
|
||||
* @param theta_scale Scalar exponent base for computing theta scale values.
|
||||
* @param freq_scale Frequency scaling factor, applied to theta scale.
|
||||
* @param attn_factor Attention scaling factor, applied to sin/cos.
|
||||
* @param is_neox Whether to use Neox-style repeat strategy
|
||||
* (dim expansion vs repeat_interleave).
|
||||
*/
|
||||
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
aclTensor* acl_cos_repeat_tensor,
|
||||
aclTensor* acl_sin_repeat_tensor,
|
||||
float theta_scale, float freq_scale,
|
||||
float attn_factor, bool is_neox) {
|
||||
// int sin/cos cache, cache has different repeat method depond on
|
||||
// @param.is_neox
|
||||
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
|
||||
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
|
||||
|
||||
// used for accuracy testing
|
||||
bool is_attention = is_q || is_k;
|
||||
|
||||
// just compute in first layer in attention
|
||||
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
|
||||
if(is_attention && !is_fisrt_layer) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_tensor* src0 = dst->src[0]; // input
|
||||
ggml_tensor* src1 = dst->src[1]; // position
|
||||
@@ -2172,21 +2302,16 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
|
||||
}
|
||||
|
||||
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
|
||||
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
|
||||
|
||||
// used for accuracy testing
|
||||
bool is_attention = is_q || is_k;
|
||||
|
||||
if(ctx.init_ptr == nullptr || !is_attention) {
|
||||
// init theta scale, just one time
|
||||
if(ctx.rope_init_ptr == nullptr || !is_attention) {
|
||||
// theta_scale arange, [0,1,...,ne00/2 - 1]
|
||||
if(ctx.init_ptr != nullptr){
|
||||
ACL_CHECK(aclrtFree(ctx.init_ptr));
|
||||
if(ctx.rope_init_ptr != nullptr){
|
||||
ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
|
||||
}
|
||||
ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
|
||||
aclTensor* acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
float start = 0;
|
||||
float step = 1;
|
||||
@@ -2216,67 +2341,55 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
|
||||
}
|
||||
|
||||
if(ctx.sin_ptr == nullptr) {
|
||||
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
|
||||
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
}
|
||||
// init sin_repeat && cos_repeat, one token just init in 0 layer
|
||||
if(position_length > ctx.max_prompt_length) {
|
||||
ctx.max_prompt_length = position_length;
|
||||
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
|
||||
ACL_CHECK(aclrtFree(ctx.sin_ptr));
|
||||
ACL_CHECK(aclrtFree(ctx.cos_ptr));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
|
||||
if(ctx.rope_sin_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
|
||||
ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
|
||||
}
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
}
|
||||
|
||||
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
|
||||
|
||||
if(is_fisrt_layer || !is_attention) {
|
||||
|
||||
aclTensor* acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
aclTensor* acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
|
||||
// position
|
||||
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
|
||||
src1->data, ggml_cann_type_mapping(src1->type),
|
||||
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
|
||||
// position
|
||||
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
|
||||
src1->data, ggml_cann_type_mapping(src1->type),
|
||||
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
|
||||
|
||||
// power * position
|
||||
int64_t theta_length = theta_scale_length * position_length;
|
||||
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* theta_buffer = theta_allocator.get();
|
||||
// power * position
|
||||
int64_t theta_length = theta_scale_length * position_length;
|
||||
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* theta_buffer = theta_allocator.get();
|
||||
|
||||
aclTensor* acl_theta_tensor =
|
||||
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
theta_ne, theta_nb, GGML_MAX_DIMS);
|
||||
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
|
||||
acl_theta_tensor);
|
||||
|
||||
// sin/cos
|
||||
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
|
||||
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
|
||||
|
||||
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
|
||||
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
|
||||
|
||||
// release
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
|
||||
acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
|
||||
}
|
||||
aclTensor* acl_theta_tensor =
|
||||
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
theta_ne, theta_nb, GGML_MAX_DIMS);
|
||||
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
|
||||
acl_theta_tensor);
|
||||
|
||||
// sin/cos
|
||||
ggml_cann_pool_alloc sin_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* sin_buffer = sin_allocator.get();
|
||||
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
|
||||
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
|
||||
|
||||
ggml_cann_pool_alloc cos_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* cos_buffer = cos_allocator.get();
|
||||
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
|
||||
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
|
||||
|
||||
// attn_factor
|
||||
if (attn_factor != 1) {
|
||||
@@ -2284,6 +2397,19 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
|
||||
}
|
||||
|
||||
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
|
||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||
sin_reshape_nb[0] = sizeof(float_t);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_sin_repeat_tensor =
|
||||
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclTensor* acl_cos_repeat_tensor =
|
||||
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
|
||||
// repeat
|
||||
if (is_neox) {
|
||||
int64_t repeatsArray[] = {1, 1, 1, 2};
|
||||
@@ -2299,8 +2425,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
num_repeats, output_size);
|
||||
}
|
||||
|
||||
// release
|
||||
ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
|
||||
acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
|
||||
acl_cos_repeat_tensor);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
@@ -2354,13 +2481,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
// init cos/sin cache
|
||||
ggml_cann_pool_alloc sin_allocator(
|
||||
ctx.pool(), ne00 * ne02 * sizeof(float_t));
|
||||
ggml_cann_pool_alloc cos_allocator(
|
||||
ctx.pool(), ne00 * ne02 * sizeof(float_t));
|
||||
void* sin_buffer = sin_allocator.get();
|
||||
void* cos_buffer = cos_allocator.get();
|
||||
// init ctx.rope_cos/rope_sin cache
|
||||
aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
|
||||
|
||||
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
|
||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||
@@ -2369,13 +2491,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_sin_reshape_tensor =
|
||||
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclTensor* acl_cos_reshape_tensor =
|
||||
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
|
||||
theta_scale, freq_scale, attn_factor, is_neox);
|
||||
|
||||
aclTensor* acl_src = ggml_cann_create_tensor(src0);
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
+22
-10
@@ -368,10 +368,6 @@ struct ggml_backend_cann_context {
|
||||
std::string name; /**< Name of the device. */
|
||||
std::string description; /**< Description of the device. */
|
||||
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
|
||||
void* init_ptr = nullptr;
|
||||
void* sin_ptr = nullptr;
|
||||
void* cos_ptr = nullptr;
|
||||
int64_t max_prompt_length = 65536;
|
||||
#ifdef USE_ACL_GRAPH
|
||||
/// Cached CANN ACL graph used for executing the current ggml computation graph.
|
||||
std::unique_ptr<ggml_cann_graph> cann_graph;
|
||||
@@ -379,6 +375,16 @@ struct ggml_backend_cann_context {
|
||||
cann_task_queue task_queue;
|
||||
bool async_mode;
|
||||
bool support_set_rows;
|
||||
// Rope Cache
|
||||
void* rope_init_ptr = nullptr;
|
||||
void* rope_sin_ptr = nullptr;
|
||||
void* rope_cos_ptr = nullptr;
|
||||
int64_t max_prompt_length = 0;
|
||||
// Constant Pool
|
||||
void* f32_zero_cache = nullptr;
|
||||
void* f32_one_cache = nullptr;
|
||||
int64_t f32_zero_cache_element = 0;
|
||||
int64_t f32_one_cache_element = 0;
|
||||
|
||||
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
|
||||
|
||||
@@ -418,14 +424,20 @@ struct ggml_backend_cann_context {
|
||||
ACL_CHECK(aclrtDestroyStream(streams[i]));
|
||||
}
|
||||
}
|
||||
if(init_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(init_ptr));
|
||||
if(rope_init_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(rope_init_ptr));
|
||||
}
|
||||
if(sin_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(sin_ptr));
|
||||
if(rope_sin_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(rope_sin_ptr));
|
||||
}
|
||||
if(cos_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(cos_ptr));
|
||||
if(rope_cos_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(rope_cos_ptr));
|
||||
}
|
||||
if(f32_zero_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(f32_zero_cache));
|
||||
}
|
||||
if(f32_one_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(f32_one_cache));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -150,8 +150,6 @@
|
||||
#elif defined(__s390x__)
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
|
||||
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
|
||||
@@ -23,6 +23,27 @@
|
||||
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
|
||||
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
|
||||
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
|
||||
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
|
||||
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
|
||||
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
|
||||
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
|
||||
#define B8(c,s ) B7(c,s, c), B7(c,s, s)
|
||||
|
||||
// precomputed tables for expanding 8bits to 8 bytes:
|
||||
static const __attribute__((aligned(16))) uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b ) << 4
|
||||
static const __attribute__((aligned(16))) uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
|
||||
|
||||
// permute mask for byteswapping
|
||||
static const uint8x16_t v_kperm = (const uint8x16_t){
|
||||
7, 6, 5, 4, 3, 2, 1, 0,
|
||||
15, 14, 13, 12, 11, 10, 9, 8
|
||||
};
|
||||
#endif
|
||||
|
||||
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK8_0 == 32);
|
||||
assert(k % QK8_0 == 0);
|
||||
@@ -241,6 +262,301 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(qk == QK5_0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_0 * GGML_RESTRICT x = vx;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
int ib = 0;
|
||||
float sumf = 0.0f;
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
float32x4_t v_sum0 = vec_splats(0.0f);
|
||||
float32x4_t v_sum1 = vec_splats(0.0f);
|
||||
|
||||
uint32_t qh0, qh1;
|
||||
uint64_t tmp0[4], tmp1[4];
|
||||
|
||||
const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib + 1 < nb; ib += 2) {
|
||||
const block_q5_0 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||
const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||
|
||||
memcpy(&qh0, x0->qh, sizeof(qh0));
|
||||
memcpy(&qh1, x1->qh, sizeof(qh1));
|
||||
|
||||
tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
|
||||
tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
|
||||
tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
|
||||
tmp0[3] = table_b2b_1[(qh0 >> 24) ];
|
||||
|
||||
tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
|
||||
tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
|
||||
tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
|
||||
tmp1[3] = table_b2b_1[(qh1 >> 24) ];
|
||||
|
||||
int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
|
||||
int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
|
||||
int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
|
||||
int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
|
||||
v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
|
||||
v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
|
||||
v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
|
||||
|
||||
const uint8x16_t v_x0 = vec_xl(0, (const uint8_t *)x0->qs);
|
||||
const uint8x16_t v_x1 = vec_xl(0, (const uint8_t *)x1->qs);
|
||||
|
||||
int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
|
||||
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
|
||||
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
|
||||
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
|
||||
|
||||
const int8x16_t v_x0lf = vec_sub(v_x0l, v_qh0l);
|
||||
const int8x16_t v_x0hf = vec_sub(v_x0h, v_qh0h);
|
||||
const int8x16_t v_x1lf = vec_sub(v_x1l, v_qh1l);
|
||||
const int8x16_t v_x1hf = vec_sub(v_x1h, v_qh1h);
|
||||
|
||||
const int8x16_t v_y0l = vec_xl(0, (const int8_t *)y0->qs);
|
||||
const int8x16_t v_y0h = vec_xl(QK8_0/2, (const int8_t *)y0->qs);
|
||||
const int8x16_t v_y1l = vec_xl(0, (const int8_t *)y1->qs);
|
||||
const int8x16_t v_y1h = vec_xl(QK8_0/2, (const int8_t *)y1->qs);
|
||||
|
||||
const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);
|
||||
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
|
||||
|
||||
const float32x4_t v_xy0f = vec_float(v_xy0);
|
||||
const float32x4_t v_xy1f = vec_float(v_xy1);
|
||||
|
||||
const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
|
||||
|
||||
v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);
|
||||
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
|
||||
}
|
||||
|
||||
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib < nb; ++ib) {
|
||||
const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
|
||||
const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x0->qh, sizeof(qh));
|
||||
|
||||
uint64_t tmp[4];
|
||||
tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
|
||||
tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
|
||||
tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
|
||||
tmp[3] = table_b2b_1[(qh >> 24) ];
|
||||
|
||||
int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
|
||||
int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
|
||||
v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
|
||||
|
||||
const uint8x16_t v_x = vec_xl(0, (const uint8_t *)x0->qs);
|
||||
int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
|
||||
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
|
||||
|
||||
const int8x16_t v_xlf = vec_sub(v_xl, v_qhl);
|
||||
const int8x16_t v_xhf = vec_sub(v_xh, v_qhh);
|
||||
|
||||
const int8x16_t v_yl = vec_xl(0, (const int8_t *)y0->qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_0/2, (const int8_t *)y0->qs);
|
||||
|
||||
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
|
||||
const float32x4_t v_xyf = vec_float(v_xy);
|
||||
|
||||
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f));
|
||||
|
||||
sumf += vec_hsum(v_acc);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#else
|
||||
UNUSED(nb);
|
||||
UNUSED(x);
|
||||
UNUSED(y);
|
||||
UNUSED(ib);
|
||||
UNUSED(sumf);
|
||||
ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_1;
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(qk == QK5_1);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_1 * GGML_RESTRICT x = vx;
|
||||
const block_q8_1 * GGML_RESTRICT y = vy;
|
||||
|
||||
int ib = 0;
|
||||
float sumf = 0.0f;
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
float32x4_t v_sum0 = vec_splats(0.0f);
|
||||
float32x4_t v_sum1 = vec_splats(0.0f);
|
||||
|
||||
float summs0 = 0.0f;
|
||||
float summs1 = 0.0f;
|
||||
|
||||
uint32_t qh0;
|
||||
uint32_t qh1;
|
||||
|
||||
uint64_t tmp0[4];
|
||||
uint64_t tmp1[4];
|
||||
|
||||
const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib + 1 < nb; ib += 2) {
|
||||
const block_q5_1 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||
const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||
const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||
const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||
|
||||
summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
|
||||
summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);
|
||||
|
||||
memcpy(&qh0, x0->qh, sizeof(qh0));
|
||||
memcpy(&qh1, x1->qh, sizeof(qh1));
|
||||
|
||||
tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
|
||||
tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
|
||||
tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
|
||||
tmp0[3] = table_b2b_0[(qh0 >> 24) ];
|
||||
|
||||
tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
|
||||
tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
|
||||
tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
|
||||
tmp1[3] = table_b2b_0[(qh1 >> 24) ];
|
||||
|
||||
int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
|
||||
int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
|
||||
int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
|
||||
int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
|
||||
v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
|
||||
v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
|
||||
v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
|
||||
|
||||
const uint8x16_t v_x0 = vec_xl(0, x0->qs);
|
||||
const uint8x16_t v_x1 = vec_xl(0, x1->qs);
|
||||
|
||||
const int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
|
||||
const int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
|
||||
const int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
|
||||
const int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
|
||||
|
||||
const int8x16_t v_x0lf = vec_or(v_x0l, v_qh0l);
|
||||
const int8x16_t v_x0hf = vec_or(v_x0h, v_qh0h);
|
||||
const int8x16_t v_x1lf = vec_or(v_x1l, v_qh1l);
|
||||
const int8x16_t v_x1hf = vec_or(v_x1h, v_qh1h);
|
||||
|
||||
const int8x16_t v_y0l = vec_xl(0 , y0->qs);
|
||||
const int8x16_t v_y0h = vec_xl(QK8_1/2, y0->qs);
|
||||
const int8x16_t v_y1l = vec_xl(0 , y1->qs);
|
||||
const int8x16_t v_y1h = vec_xl(QK8_1/2, y1->qs);
|
||||
|
||||
const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);
|
||||
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
|
||||
|
||||
const float32x4_t v_xy0f = vec_float(v_xy0);
|
||||
const float32x4_t v_xy1f = vec_float(v_xy1);
|
||||
|
||||
const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
|
||||
|
||||
v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);
|
||||
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
|
||||
}
|
||||
|
||||
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1) + summs0 + summs1;
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib < nb; ++ib) {
|
||||
const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
|
||||
const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
|
||||
|
||||
float summs = GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x0->qh, sizeof(qh));
|
||||
|
||||
uint64_t tmp[4];
|
||||
tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
|
||||
tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
|
||||
tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
|
||||
tmp[3] = table_b2b_0[(qh >> 24) ];
|
||||
|
||||
int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
|
||||
int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
|
||||
v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
|
||||
|
||||
const uint8x16_t v_x = vec_xl(0, x0->qs);
|
||||
const int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
|
||||
const int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
|
||||
|
||||
const int8x16_t v_xlf = vec_or(v_xl, v_qhl);
|
||||
const int8x16_t v_xhf = vec_or(v_xh, v_qhh);
|
||||
|
||||
const int8x16_t v_yl = vec_xl(0 , y0->qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_1/2, y0->qs);
|
||||
|
||||
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
|
||||
const float32x4_t v_xyf = vec_float(v_xy);
|
||||
|
||||
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc);
|
||||
|
||||
sumf += vec_hsum(v_acc) + summs;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#else
|
||||
UNUSED(nb);
|
||||
UNUSED(x);
|
||||
UNUSED(y);
|
||||
UNUSED(ib);
|
||||
UNUSED(sumf);
|
||||
ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
@@ -486,6 +486,14 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
|
||||
return v_abo + v_abe;
|
||||
}
|
||||
|
||||
/**
|
||||
* @see https://github.com/ggml-org/llama.cpp/pull/14037
|
||||
*/
|
||||
inline float vec_hsum(float32x4_t v) {
|
||||
float32x4_t v_temp = v + vec_reve(v);
|
||||
return v_temp[0] + v_temp[1];
|
||||
}
|
||||
|
||||
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
||||
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
|
||||
return acc + (vec_unpackh(p) + vec_unpackl(p));
|
||||
|
||||
@@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_conv_2d(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CONV_3D:
|
||||
{
|
||||
ggml_compute_forward_conv_3d(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
{
|
||||
ggml_compute_forward_conv_2d_dw(params, tensor);
|
||||
@@ -2252,6 +2256,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_3D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
@@ -2773,6 +2778,7 @@ struct ggml_cplan ggml_graph_plan(
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_3D:
|
||||
{
|
||||
cur = GGML_IM2COL_WORK_SIZE;
|
||||
} break;
|
||||
|
||||
@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
|
||||
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_3d
|
||||
|
||||
static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
|
||||
const ggml_tensor * kernel,
|
||||
const ggml_tensor * src,
|
||||
ggml_tensor * dst,
|
||||
ggml_type kernel_type) {
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(kernel));
|
||||
GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(kernel->type == kernel_type);
|
||||
|
||||
const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
|
||||
|
||||
const int32_t s0 = dst->op_params[0];
|
||||
const int32_t s1 = dst->op_params[1];
|
||||
const int32_t s2 = dst->op_params[2];
|
||||
const int32_t p0 = dst->op_params[3];
|
||||
const int32_t p1 = dst->op_params[4];
|
||||
const int32_t p2 = dst->op_params[5];
|
||||
const int32_t d0 = dst->op_params[6];
|
||||
const int32_t d1 = dst->op_params[7];
|
||||
const int32_t d2 = dst->op_params[8];
|
||||
const int32_t c = dst->op_params[9];
|
||||
const int32_t n = dst->op_params[10];
|
||||
const int32_t oc = dst->op_params[11];
|
||||
|
||||
const int64_t src_w = src->ne[0];
|
||||
const int64_t src_h = src->ne[1];
|
||||
const int64_t src_d = src->ne[2];
|
||||
const int64_t knl_w = kernel->ne[0];
|
||||
const int64_t knl_h = kernel->ne[1];
|
||||
const int64_t knl_d = kernel->ne[2];
|
||||
const int64_t dst_w = dst->ne[0];
|
||||
const int64_t dst_h = dst->ne[1];
|
||||
const int64_t dst_d = dst->ne[2];
|
||||
|
||||
const float * src_data = (float *) src->data;
|
||||
void * knl_data = kernel->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
|
||||
const int64_t knl_n_total = knl_n_per_channel * c;
|
||||
const int64_t patch_total = n * dst_w * dst_h * dst_d;
|
||||
|
||||
const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
|
||||
const int64_t batch_size = params->wsize / space_per_patch;
|
||||
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
|
||||
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
|
||||
|
||||
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
|
||||
|
||||
void * tmp = params->wdata;
|
||||
|
||||
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
|
||||
const int64_t patch_start_batch = batch_i * patches_per_batch;
|
||||
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
|
||||
const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
|
||||
|
||||
const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
||||
const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
|
||||
const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
|
||||
|
||||
for (int64_t p = patch_start; p < patch_end; ++p) {
|
||||
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
||||
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
||||
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
||||
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
||||
const int64_t dst_y = p_in_depth / dst_w;
|
||||
const int64_t dst_x = p_in_depth % dst_w;
|
||||
|
||||
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
|
||||
|
||||
for (int64_t ic = 0; ic < c; ++ic) {
|
||||
for (int64_t kz = 0; kz < knl_d; ++kz) {
|
||||
for (int64_t ky = 0; ky < knl_h; ++ky) {
|
||||
for (int64_t kx = 0; kx < knl_w; ++kx) {
|
||||
const int64_t sz = dst_z * s2 + kz * d2 - p2;
|
||||
const int64_t sy = dst_y * s1 + ky * d1 - p1;
|
||||
const int64_t sx = dst_x * s0 + kx * d0 - p0;
|
||||
|
||||
int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
|
||||
|
||||
float src_val;
|
||||
if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
|
||||
src_val = 0.0f;
|
||||
} else {
|
||||
const int64_t cn_idx = batch_idx * c + ic;
|
||||
const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
|
||||
src_val = *src_ptr;
|
||||
}
|
||||
|
||||
char * element_ptr = dst_row + dst_idx * traits->type_size;
|
||||
if (kernel_type == GGML_TYPE_F32) {
|
||||
*(float *)element_ptr = src_val;
|
||||
} else if (kernel_type == GGML_TYPE_F16) {
|
||||
*(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
|
||||
ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
||||
const int64_t permute_start = params->ith * permute_per_thread;
|
||||
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
|
||||
|
||||
for (int64_t i = permute_start; i < permute_end; ++i) {
|
||||
const int64_t p = patch_start_batch + i;
|
||||
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
||||
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
||||
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
||||
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
||||
const int64_t dst_y = p_in_depth / dst_w;
|
||||
const int64_t dst_x = p_in_depth % dst_w;
|
||||
|
||||
for (int64_t ioc = 0; ioc < oc; ++ioc) {
|
||||
const float value = gemm_output[i * oc + ioc];
|
||||
const int64_t ocn_idx = batch_idx * oc + ioc;
|
||||
float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
|
||||
*dst_ptr = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_conv_3d(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_transpose_2d
|
||||
|
||||
void ggml_compute_forward_conv_transpose_2d(
|
||||
|
||||
@@ -70,6 +70,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
|
||||
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -420,16 +420,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_all(int x) {
|
||||
#ifdef GGML_USE_HIP
|
||||
if (width == ggml_cuda_get_physical_warp_size()) {
|
||||
return __all_sync(0xffffffff, x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_any(int x) {
|
||||
if (width == ggml_cuda_get_physical_warp_size()) {
|
||||
return __any_sync(0xffffffff, x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
return x;
|
||||
#else
|
||||
static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
|
||||
return __all_sync(0xffffffff, x);
|
||||
#endif // GGML_USE_HIP
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
|
||||
@@ -258,7 +258,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const half val = hexp(sink - kqmax[j0/nwarps]);
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
|
||||
kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
|
||||
@@ -49,6 +49,7 @@
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml-cuda/set-rows.cuh"
|
||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -203,6 +204,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
|
||||
#endif // GGML_CUDA_FORCE_CUBLAS
|
||||
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
|
||||
|
||||
std::vector<std::pair<int, std::string>> turing_devices_without_mma;
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
int device_vmm = 0;
|
||||
|
||||
@@ -260,7 +263,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
std::string device_name(prop.name);
|
||||
if (device_name == "NVIDIA GeForce MX450") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
} else if (device_name == "NVIDIA GeForce MX550") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
}
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
}
|
||||
|
||||
if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
|
||||
GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
|
||||
for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
|
||||
GGML_LOG_INFO(
|
||||
" Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
|
||||
}
|
||||
GGML_LOG_INFO(
|
||||
"Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
|
||||
}
|
||||
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
@@ -2352,6 +2373,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_PAD:
|
||||
ggml_cuda_op_pad(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
ggml_cuda_op_pad_reflect_1d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARANGE:
|
||||
ggml_cuda_op_arange(ctx, dst);
|
||||
break;
|
||||
@@ -3481,15 +3505,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
|
||||
+177
-47
@@ -3,6 +3,140 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
|
||||
struct mmq_ids_helper_store {
|
||||
uint32_t data;
|
||||
|
||||
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
|
||||
data = (it & 0x003FFFFF) | (iex_used << 22);
|
||||
}
|
||||
|
||||
__device__ uint32_t it() const {
|
||||
return data & 0x003FFFFF;
|
||||
}
|
||||
|
||||
__device__ uint32_t iex_used() const {
|
||||
return data >> 22;
|
||||
}
|
||||
};
|
||||
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
|
||||
|
||||
// Helper function for mul_mat_id, converts ids to a more convenient format.
|
||||
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
|
||||
// ids_dst describes the same mapping but for the dst tensor.
|
||||
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
|
||||
template <int n_expert_used_template>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
||||
const int expert = blockIdx.x;
|
||||
|
||||
extern __shared__ char data_mmq_ids_helper[];
|
||||
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
|
||||
|
||||
int nex_prev = 0; // Number of columns for experts with a lower index.
|
||||
int it_compact = 0; // Running index for the compact slice of this expert.
|
||||
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
// Generic implementation:
|
||||
for (int it = 0; it < n_tokens; ++it) {
|
||||
int iex_used = -1; // The index at which the expert is used, if any.
|
||||
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
||||
const int expert_used = ids[it*si1 + iex];
|
||||
nex_prev += expert_used < expert;
|
||||
if (expert_used == expert) {
|
||||
iex_used = iex;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
||||
it_compact++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Implementation optimized for specific numbers of experts used:
|
||||
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
|
||||
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
|
||||
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
|
||||
const int it = it0 + threadIdx.x / neu_padded;
|
||||
|
||||
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
|
||||
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
|
||||
ids[it*si1 + iex] : INT_MAX;
|
||||
const int iex_used = expert_used == expert ? iex : -1;
|
||||
nex_prev += expert_used < expert;
|
||||
|
||||
// Whether the threads at this token position have used the expert:
|
||||
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
|
||||
|
||||
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
|
||||
int it_compact_add_lower = 0;
|
||||
#pragma unroll
|
||||
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
|
||||
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
|
||||
if (threadIdx.x >= offset) {
|
||||
it_compact_add_lower += tmp;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
|
||||
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
|
||||
}
|
||||
}
|
||||
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
|
||||
|
||||
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
|
||||
const mmq_ids_helper_store store_it = store[itc];
|
||||
const int it = store_it.it();
|
||||
const int iex_used = store_it.iex_used();
|
||||
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
|
||||
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
|
||||
}
|
||||
|
||||
if (threadIdx.x != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[expert] = nex_prev;
|
||||
|
||||
if (expert < gridDim.x - 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[gridDim.x] = nex_prev + it_compact;
|
||||
}
|
||||
|
||||
template <int n_expert_used_template>
|
||||
static void launch_mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
|
||||
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
|
||||
|
||||
const dim3 num_blocks(n_experts, 1, 1);
|
||||
const dim3 block_size(warp_size, 1, 1);
|
||||
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
|
||||
GGML_ASSERT(nbytes_shared <= smpbo);
|
||||
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
|
||||
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
||||
switch (args.type_x) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
|
||||
ne00, ne01, ne1, s01, ne11, s1,
|
||||
ne02, ne12, s02, s12, s2,
|
||||
ne03, ne13, s03, s13, s3,
|
||||
use_stream_k};
|
||||
use_stream_k, ne1};
|
||||
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
|
||||
return;
|
||||
}
|
||||
@@ -148,54 +282,50 @@ void ggml_cuda_mul_mat_q(
|
||||
|
||||
const int64_t n_expert_used = ids->ne[0];
|
||||
const int64_t ne_get_rows = ne12 * n_expert_used;
|
||||
GGML_ASSERT(ne1 == n_expert_used);
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
std::vector<int32_t> ids_src1_host;
|
||||
ids_src1_host.reserve(ne_get_rows);
|
||||
std::vector<int32_t> ids_dst_host;
|
||||
ids_dst_host.reserve(ne_get_rows);
|
||||
std::vector<int32_t> tokens_per_expert_host(ne02);
|
||||
std::vector<int32_t> expert_bounds_host(ne02 + 1);
|
||||
ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
|
||||
ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
|
||||
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
|
||||
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
|
||||
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
{
|
||||
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
|
||||
const int si1 = ids->nb[1] / ggml_element_size(ids);
|
||||
const int sis1 = nb12 / nb11;
|
||||
|
||||
for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
|
||||
for (int64_t iex = 0; iex < n_expert_used; ++iex) {
|
||||
const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
|
||||
assert(expert_to_use >= 0 && expert_to_use < ne02);
|
||||
if (expert_to_use == i02) {
|
||||
ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
|
||||
ids_dst_host.push_back(i12*ne1 + iex);
|
||||
tokens_per_expert_host[i02]++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
switch (n_expert_used) {
|
||||
case 2:
|
||||
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 16:
|
||||
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
default:
|
||||
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
int32_t cumsum = 0;
|
||||
for (int64_t i = 0; i < ne02; ++i) {
|
||||
expert_bounds_host[i] = cumsum;
|
||||
cumsum += tokens_per_expert_host[i];
|
||||
}
|
||||
expert_bounds_host[ne02] = cumsum;
|
||||
|
||||
std::vector<int32_t> ids_buf_host;
|
||||
ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
|
||||
ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
|
||||
ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
|
||||
ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
|
||||
ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
const int32_t * ids_src1_dev = ids_buf_dev.ptr;
|
||||
const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
|
||||
const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
|
||||
|
||||
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
|
||||
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
|
||||
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
|
||||
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[2] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
|
||||
|
||||
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
|
||||
const mmq_args args = {
|
||||
src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
|
||||
src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
|
||||
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
|
||||
ne02, ne02, s02, s12, s2,
|
||||
ne03, ne13, s03, s13, s3,
|
||||
use_stream_k};
|
||||
use_stream_k, ne12};
|
||||
|
||||
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
|
||||
}
|
||||
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
|
||||
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
|
||||
1, 1, 0, 0, 0,
|
||||
1, 1, 0, 0, 0,
|
||||
use_stream_k};
|
||||
use_stream_k, src1_ncols};
|
||||
|
||||
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
|
||||
|
||||
|
||||
+21
-13
@@ -3138,7 +3138,8 @@ static __global__ void mul_mat_q(
|
||||
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
|
||||
// Skip unused template specializations for faster compilation:
|
||||
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
||||
@@ -3152,7 +3153,7 @@ static __global__ void mul_mat_q(
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
|
||||
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
||||
|
||||
// Initialize the ids for writing back data with just the index.
|
||||
@@ -3376,7 +3377,8 @@ template <ggml_type type, int mmq_x, bool need_check>
|
||||
static __global__ void mul_mat_q_stream_k_fixup(
|
||||
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
|
||||
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
|
||||
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
||||
@@ -3387,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
||||
|
||||
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
|
||||
|
||||
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
@@ -3528,7 +3530,7 @@ struct mmq_args {
|
||||
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
|
||||
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
|
||||
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
|
||||
bool use_stream_k;
|
||||
bool use_stream_k; int64_t ncols_max;
|
||||
};
|
||||
|
||||
template<ggml_type type>
|
||||
@@ -3558,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
|
||||
|
||||
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
|
||||
const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
|
||||
const int ntzw = args.nchannels_y * args.nsamples_y;
|
||||
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
|
||||
|
||||
@@ -3574,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -3601,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
@@ -3609,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
@@ -3624,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3649,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||
continue;
|
||||
}
|
||||
|
||||
const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
|
||||
const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
|
||||
|
||||
if (ntiles_x < ntiles_x_best) {
|
||||
mmq_x_best = mmq_x;
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
#include "pad_reflect_1d.cuh"
|
||||
|
||||
static __global__ void pad_reflect_1d_kernel_f32(
|
||||
const void * __restrict__ src0,
|
||||
void * __restrict__ dst,
|
||||
const int64_t ne0,
|
||||
const int64_t ne00,
|
||||
const int64_t ne01,
|
||||
const int64_t ne02,
|
||||
const int64_t ne03,
|
||||
const int64_t nb00,
|
||||
const int64_t nb01,
|
||||
const int64_t nb02,
|
||||
const int64_t nb03,
|
||||
const int64_t nb0,
|
||||
const int64_t nb1,
|
||||
const int64_t nb2,
|
||||
const int64_t nb3,
|
||||
const int p0,
|
||||
const int p1) {
|
||||
|
||||
const int64_t i3 = blockIdx.z;
|
||||
const int64_t i2 = blockIdx.y;
|
||||
const int64_t i1 = blockIdx.x;
|
||||
|
||||
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
|
||||
return;
|
||||
}
|
||||
|
||||
const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
|
||||
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
|
||||
|
||||
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
|
||||
float value;
|
||||
|
||||
if (i0 < p0) {
|
||||
// Left padding - reflect
|
||||
value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
|
||||
} else if (i0 < ne0 - p1) {
|
||||
// Middle - copy
|
||||
value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
|
||||
} else {
|
||||
// Right padding - reflect
|
||||
int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
|
||||
value = *(const float *)(src0_ptr + src_idx * nb00);
|
||||
}
|
||||
|
||||
*(float *)(dst_ptr + i0 * nb0) = value;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int32_t * opts = (const int32_t *) dst->op_params;
|
||||
const int p0 = opts[0];
|
||||
const int p1 = opts[1];
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
GGML_ASSERT(ne0 == ne00 + p0 + p1);
|
||||
|
||||
const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
|
||||
const dim3 grid_dims(ne01, ne02, ne03);
|
||||
|
||||
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
|
||||
src0->data, dst->data,
|
||||
ne0, ne00, ne01, ne02, ne03,
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
p0, p1
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
|
||||
return ((const int *) x)[i32]; // assume at least 4 byte alignment
|
||||
}
|
||||
|
||||
// q4 contains 8 indices with 4 bit each.
|
||||
// This function selects those bytes from table that are at those indices and returns them as int2.
|
||||
// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
|
||||
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
|
||||
#if defined(GGML_USE_HIP)
|
||||
// Load the 16-byte table into four 32-bit unsigned integers.
|
||||
const uint32_t *values = (const uint32_t *)table;
|
||||
|
||||
const uint32_t q_even = q4;
|
||||
const uint32_t q_odd = (q4 >> 4);
|
||||
|
||||
// Perform lookups in the lower half of the table (indices 0-7).
|
||||
uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
|
||||
uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
|
||||
|
||||
// Perform lookups in the upper half of the table (indices 8-15).
|
||||
uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
|
||||
uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
|
||||
|
||||
// Select between the low and high results based on the MSB of each index nibble.
|
||||
uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
|
||||
uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
|
||||
uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
|
||||
uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
|
||||
|
||||
return make_int2(res_x, res_y);
|
||||
#elif !defined(GGML_USE_MUSA)
|
||||
// CUDA does not have an instruction for selecting bytes with 4 bit indices.
|
||||
// However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
|
||||
const uint32_t * table32 = (const uint32_t *) table;
|
||||
|
||||
// __byte_perm selects bytes based on the lower 16 bits in its third argument.
|
||||
// Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
|
||||
// To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
|
||||
// Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
|
||||
uint32_t tmp[2];
|
||||
const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 2; ++i) {
|
||||
const uint32_t shift = 16 * i;
|
||||
|
||||
const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
|
||||
const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
|
||||
tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
|
||||
}
|
||||
|
||||
// tmp contains the bytes from tyble in the same order as the 4 bit indices in q4.
|
||||
// However, for the result we need ints with all even/odd 4 bit indices in q4.
|
||||
// Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
|
||||
return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
|
||||
#else
|
||||
// Generic implementation.
|
||||
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
||||
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
||||
const char4 val0_8 = make_char4(
|
||||
@@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
|
||||
|
||||
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
||||
#endif
|
||||
}
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
|
||||
Vendored
+3
@@ -22,7 +22,10 @@
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
||||
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
#define __all_sync(mask, var) __all(var)
|
||||
#define __any_sync(mask, var) __any(var)
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasDestroy hipblasDestroy
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
|
||||
@@ -320,40 +320,31 @@ typedef struct {
|
||||
} ggml_metal_kargs_mul_mv_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne02;
|
||||
int32_t ne10;
|
||||
int32_t ne11; // n_expert_used (bcast)
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int32_t neh11; // n_tokens
|
||||
uint64_t nbh11;
|
||||
int32_t ne21; // n_tokens
|
||||
int32_t ne20; // n_expert_used
|
||||
uint64_t nb21;
|
||||
} ggml_metal_kargs_mul_mm_id_map0;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne20; // n_expert_used
|
||||
int32_t neh0;
|
||||
int32_t neh1;
|
||||
uint64_t nbh1;
|
||||
uint64_t nbh2;
|
||||
int32_t ne0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
} ggml_metal_kargs_mul_mm_id_map1;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t neh12;
|
||||
uint64_t nbh10;
|
||||
uint64_t nbh11;
|
||||
uint64_t nbh12;
|
||||
uint64_t nbh13;
|
||||
int32_t neh0;
|
||||
int32_t neh1;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ne20;
|
||||
int32_t ne21;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
} ggml_metal_kargs_mul_mm_id;
|
||||
|
||||
+126
-109
@@ -93,35 +93,37 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
||||
if (ctx->mtl_device == nil) {
|
||||
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
||||
|
||||
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
if (ctx->mtl_device) {
|
||||
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
|
||||
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
|
||||
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
||||
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
||||
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
||||
#endif
|
||||
|
||||
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
ctx->use_bfloat = ctx->has_bfloat;
|
||||
ctx->use_bfloat = ctx->has_bfloat;
|
||||
#else
|
||||
ctx->use_bfloat = false;
|
||||
ctx->use_bfloat = false;
|
||||
#endif
|
||||
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
||||
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
||||
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
|
||||
ctx->debug_fusion = val ? atoi(val) : 0;
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
|
||||
ctx->debug_fusion = val ? atoi(val) : 0;
|
||||
}
|
||||
|
||||
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
|
||||
|
||||
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
||||
|
||||
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
||||
}
|
||||
|
||||
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
|
||||
|
||||
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
||||
|
||||
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
||||
}
|
||||
|
||||
ctx->mtl_device_ref_count++;
|
||||
@@ -396,8 +398,12 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
|
||||
@@ -443,6 +449,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
||||
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
||||
@@ -452,6 +459,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
||||
@@ -461,6 +469,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
||||
@@ -470,6 +479,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
||||
@@ -479,6 +489,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
||||
@@ -488,6 +499,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
||||
@@ -497,6 +509,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
||||
@@ -506,6 +519,13 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
|
||||
@@ -1412,8 +1432,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
|
||||
@@ -1459,6 +1483,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
|
||||
@@ -1468,6 +1493,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
||||
@@ -1477,6 +1503,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
||||
@@ -1486,6 +1513,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
||||
@@ -1495,6 +1523,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
||||
@@ -1504,6 +1533,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
||||
@@ -1513,6 +1543,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
||||
@@ -1522,6 +1553,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40, flash_attn_ext_vec_f16_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40, flash_attn_ext_vec_bf16_h40, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40, flash_attn_ext_vec_q4_0_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40, flash_attn_ext_vec_q4_1_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40, flash_attn_ext_vec_q5_0_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40, flash_attn_ext_vec_q5_1_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40, flash_attn_ext_vec_q8_0_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
|
||||
@@ -1846,7 +1884,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_ROPE:
|
||||
return true;
|
||||
case GGML_OP_IM2COL:
|
||||
return op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
|
||||
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
|
||||
case GGML_OP_POOL_1D:
|
||||
return false;
|
||||
case GGML_OP_UPSCALE:
|
||||
@@ -3878,38 +3916,6 @@ static int ggml_metal_encode_node(
|
||||
default: break;
|
||||
}
|
||||
|
||||
const int64_t neh10 = ne10; // n_embd
|
||||
const int64_t neh11 = ne21; // n_tokens
|
||||
const int64_t neh12 = ne02; // n_expert
|
||||
|
||||
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
|
||||
const uint64_t nbh11 = nbh10*neh10;
|
||||
const uint64_t nbh12 = nbh11*neh11;
|
||||
const uint64_t nbh13 = nbh12*neh12;
|
||||
|
||||
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
|
||||
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
|
||||
if (!h_src1) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int64_t neh0 = ne0;
|
||||
const int64_t neh1 = ne21;
|
||||
const int64_t neh2 = ne02;
|
||||
|
||||
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
|
||||
const uint64_t nbh1 = nbh0*neh0;
|
||||
const uint64_t nbh2 = nbh1*neh1;
|
||||
//const uint64_t nbh3 = nbh2*neh2;
|
||||
|
||||
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
|
||||
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
|
||||
if (!h_dst) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// tokens per expert
|
||||
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
|
||||
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
|
||||
@@ -3919,8 +3925,8 @@ static int ggml_metal_encode_node(
|
||||
}
|
||||
|
||||
// id map
|
||||
// [n_expert_used, n_tokens]
|
||||
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
|
||||
// [n_tokens, n_expert]
|
||||
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
|
||||
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
|
||||
if (!h_ids) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
|
||||
@@ -3928,32 +3934,45 @@ static int ggml_metal_encode_node(
|
||||
}
|
||||
|
||||
{
|
||||
const int nth = MIN(1024, ne10/4);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id_map0 args = {
|
||||
ne02,
|
||||
ne10,
|
||||
ne11, // n_expert_used (bcast)
|
||||
ne11, // n_expert_used (bcast)
|
||||
nb11,
|
||||
nb12,
|
||||
neh11, // n_tokens
|
||||
nbh11,
|
||||
ne20, // n_expert_used
|
||||
ne21, // n_tokens
|
||||
ne20, // n_expert_used
|
||||
nb21,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
|
||||
pipeline = nil;
|
||||
|
||||
switch (ne20) {
|
||||
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
|
||||
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
|
||||
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
|
||||
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
|
||||
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
|
||||
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
|
||||
const size_t smem = ne02*ne20*sizeof(uint16_t);
|
||||
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer: h_src1 offset:0 atIndex:3];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:4];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:5];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:2];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:3];
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
|
||||
}
|
||||
|
||||
{
|
||||
@@ -3992,13 +4011,15 @@ static int ggml_metal_encode_node(
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.neh12 =*/ neh12,
|
||||
/*.nbh10 =*/ nbh10,
|
||||
/*.nbh11 =*/ nbh11,
|
||||
/*.nbh12 =*/ nbh12,
|
||||
/*.nbh13 =*/ nbh13,
|
||||
/*.neh0 =*/ neh0,
|
||||
/*.neh1 =*/ neh1,
|
||||
/*.ne11 =*/ ne11, // n_expert_used (bcast)
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne20 =*/ ne20, // n_expert_used
|
||||
/*.ne21 =*/ ne21, // n_tokens
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.r2 =*/ r2,
|
||||
/*.r3 =*/ r3,
|
||||
};
|
||||
@@ -4006,42 +4027,14 @@ static int ggml_metal_encode_node(
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer: h_src1 offset:0 atIndex:2];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:3];
|
||||
[encoder setBuffer: h_dst offset:0 atIndex:4];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:4];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
}
|
||||
|
||||
{
|
||||
GGML_ASSERT(ne0 % 4 == 0);
|
||||
|
||||
const int nth = MIN(1024, ne0/4);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id_map1 args = {
|
||||
ne20, // n_expert_used
|
||||
neh0,
|
||||
neh1,
|
||||
nbh1,
|
||||
nbh2,
|
||||
ne0,
|
||||
nb1,
|
||||
nb2,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer: h_dst offset:0 atIndex:1];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
} else {
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
@@ -4701,7 +4694,6 @@ static int ggml_metal_encode_node(
|
||||
} break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||
@@ -5130,6 +5122,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
@@ -5154,6 +5147,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
|
||||
@@ -5178,6 +5172,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
||||
@@ -5202,6 +5197,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
||||
@@ -5226,6 +5222,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
||||
@@ -5250,6 +5247,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
||||
@@ -5274,6 +5272,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
||||
@@ -5301,6 +5300,24 @@ static int ggml_metal_encode_node(
|
||||
use_vec_kernel = true;
|
||||
|
||||
switch (ne00) {
|
||||
case 40:
|
||||
{
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||
GGML_ABORT("add template specialization for this type");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case 64:
|
||||
{
|
||||
switch (src1->type) {
|
||||
|
||||
@@ -974,9 +974,16 @@ kernel void kernel_mul(
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
|
||||
if (args.ne10 == 1) {
|
||||
const float x = *((device float *)(src1_ptr));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1000,9 +1007,16 @@ kernel void kernel_div(
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
||||
if (args.ne10 == 1) {
|
||||
const float x = 1.0f / *((device float *)(src1_ptr));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4663,6 +4677,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
||||
@@ -4674,6 +4689,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
||||
@@ -4685,6 +4701,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
||||
@@ -4695,6 +4712,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
||||
@@ -4705,6 +4723,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
||||
@@ -4715,6 +4734,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
||||
@@ -4725,6 +4745,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
||||
@@ -5115,6 +5136,16 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 40, 40, 8>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 40, 40, 8>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 40, 40, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 40, 40, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 40, 40, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 40, 40, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 40, 40, 8>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
|
||||
@@ -7474,97 +7505,81 @@ kernel void kernel_mul_mm(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T4>
|
||||
template<short ne20> // n_expert_used
|
||||
kernel void kernel_mul_mm_id_map0(
|
||||
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * hsrc1,
|
||||
device char * htpe,
|
||||
device char * hids,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ide = tgpig[0]; // expert id
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const short ide = tpitg; // expert id
|
||||
|
||||
int n_all = 0;
|
||||
uint32_t n_all = 0;
|
||||
|
||||
device int32_t * ids_i32 = (device int32_t *) (hids);
|
||||
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
|
||||
|
||||
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
|
||||
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
|
||||
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
|
||||
if (i21 + tpitg < args.ne21) {
|
||||
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
|
||||
|
||||
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
|
||||
if (src2_i32[i20] != ide) {
|
||||
continue;
|
||||
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
|
||||
|
||||
#pragma unroll(ne20)
|
||||
for (short i20 = 0; i20 < ne20; i20++) {
|
||||
sids[i20] = src2_i32[i20];
|
||||
}
|
||||
|
||||
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
|
||||
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
|
||||
|
||||
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
|
||||
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
|
||||
}
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
|
||||
}
|
||||
|
||||
++n_all;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short t = 0; t < ntg; t++) {
|
||||
if (i21 + t >= args.ne21) {
|
||||
break;
|
||||
}
|
||||
|
||||
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
|
||||
|
||||
short sel = 0;
|
||||
#pragma unroll(ne20)
|
||||
for (short i20 = 0; i20 < ne20; i20++) {
|
||||
sel += (sids[i20] == ide)*(i20 + 1);
|
||||
}
|
||||
|
||||
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
|
||||
|
||||
n_all += sel > 0;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
device int32_t * tpe_i32 = (device int32_t *) (htpe);
|
||||
tpe_i32[ide] = n_all;
|
||||
}
|
||||
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
|
||||
tpe_u32[ide] = n_all;
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
|
||||
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_mul_mm_id_map1(
|
||||
constant ggml_metal_kargs_mul_mm_id_map1 & args,
|
||||
device const char * hdst,
|
||||
device const char * hids,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i20 = tgpig[0]; // used expert
|
||||
const int i21 = tgpig[1]; // token
|
||||
|
||||
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
||||
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
|
||||
|
||||
const int id = ids_i32[i21*args.ne20 + i20];
|
||||
|
||||
const int ide = id / args.neh1;
|
||||
const int idt = id % args.neh1;
|
||||
|
||||
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
|
||||
dst_f32x4[i0] = hdst_f32x4[i0];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
||||
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * tpe,
|
||||
device const char * htpe,
|
||||
device const char * hids,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||
@@ -7572,19 +7587,20 @@ kernel void kernel_mul_mm_id(
|
||||
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
const int im = tgpig.z;
|
||||
const int im = tgpig.z; // expert
|
||||
|
||||
device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
|
||||
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
|
||||
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
||||
|
||||
const int neh1 = tpe_i32[im];
|
||||
const int32_t neh1 = tpe_u32[im];
|
||||
|
||||
if (r1*BLOCK_SIZE_N >= neh1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
@@ -7600,20 +7616,23 @@ kernel void kernel_mul_mm_id(
|
||||
|
||||
short il = (tiitg % THREAD_PER_ROW);
|
||||
|
||||
const int i12 = im%args.neh12;
|
||||
const int i13 = im/args.neh12;
|
||||
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
|
||||
|
||||
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const short i11 = (id % args.ne20) % args.ne11;
|
||||
const short i12 = (id / args.ne20);
|
||||
const short i13 = 0;
|
||||
|
||||
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
|
||||
const short offset1 = il/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0
|
||||
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
||||
|
||||
device const half * y = (device const half *)(src1
|
||||
+ args.nbh13*i13
|
||||
+ args.nbh12*i12
|
||||
+ args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
|
||||
+ args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
device const float * y = (device const float *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*i11
|
||||
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
@@ -7629,7 +7648,7 @@ kernel void kernel_mul_mm_id(
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
}
|
||||
|
||||
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
|
||||
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
@@ -7665,43 +7684,38 @@ kernel void kernel_mul_mm_id(
|
||||
}
|
||||
}
|
||||
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
|
||||
device float * C = (device float *) dst +
|
||||
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
||||
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
|
||||
}
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
||||
|
||||
#pragma unroll(8)
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short j = sgitg; j < n_cols; j += 4) {
|
||||
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
|
||||
|
||||
const short ide = id % args.ne20;
|
||||
const short idt = id / args.ne20;
|
||||
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = tiisg;
|
||||
for (; i < n_rows/4; i += 32) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = 0;
|
||||
for (; i < n_rows/4; i++) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
i *= 4;
|
||||
for (; i < n_rows; i++) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
i = (4*(n_rows/4)) + tiisg;
|
||||
for (; i < n_rows; i += 32) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2647,8 +2647,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM:
|
||||
return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_REPEAT:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
|
||||
case GGML_OP_PAD:
|
||||
|
||||
@@ -4391,10 +4391,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_ARGSORT:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,20 +1,34 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#if ADD_RMS
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_binary_head.comp"
|
||||
|
||||
const uint num_threads = 256;
|
||||
|
||||
layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
|
||||
|
||||
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#if ADD_RMS
|
||||
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
|
||||
shared FLOAT_TYPE sumsh[num_threads];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
uint idx = get_idx();
|
||||
uint orig_idx = idx;
|
||||
|
||||
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
|
||||
const uint num_iter = 2;
|
||||
|
||||
FLOAT_TYPE sum_sq = 0;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
if (idx >= p.ne) {
|
||||
continue;
|
||||
@@ -22,8 +36,34 @@ void main() {
|
||||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
|
||||
sum_sq += sum*sum;
|
||||
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
||||
#if ADD_RMS
|
||||
if (p.param3 != 0) {
|
||||
// reduce the sum within each subgroup, then across subgroups
|
||||
const uint NumSubgroups = num_threads / gl_SubgroupSize;
|
||||
sum_sq = subgroupAdd(sum_sq);
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
|
||||
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
|
||||
sum_sq += sumsh[gl_SubgroupID + s];
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
||||
partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -9,6 +9,10 @@ layout (constant_id = 4) const uint32_t HSV = 32;
|
||||
layout (constant_id = 5) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 6) const uint32_t D_split = 16;
|
||||
|
||||
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
|
||||
const uint32_t HSK_pad = (HSK + 15) & ~15;
|
||||
const uint32_t HSV_pad = (HSV + 15) & ~15;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
@@ -46,14 +46,14 @@ const uint32_t MatBc = 16;
|
||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
||||
|
||||
const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
|
||||
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
|
||||
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
|
||||
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
|
||||
shared ACC_TYPE sfsh[Bc * sfshstride];
|
||||
|
||||
const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
|
||||
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 ksh[Bc * kshstride];
|
||||
|
||||
shared float slope[Br];
|
||||
@@ -74,6 +74,21 @@ void main() {
|
||||
|
||||
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||
|
||||
// Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
|
||||
if ((HSK % 16) != 0) {
|
||||
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
|
||||
if (i + tid < Br * qstride) {
|
||||
Qf[i + tid] = f16vec4(0);
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
|
||||
if (i + tid < Bc * kshstride) {
|
||||
ksh[i + tid] = f16vec4(0);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
@@ -151,14 +166,14 @@ void main() {
|
||||
}
|
||||
barrier();
|
||||
|
||||
// K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
|
||||
// K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
|
||||
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||
// This is written transposed in order to allow for N being 8 if implementations need it
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
|
||||
for (uint32_t d = 0; d < HSK / 16; ++d) {
|
||||
for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
|
||||
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||
|
||||
@@ -104,16 +104,16 @@ void main() {
|
||||
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
||||
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
|
||||
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
|
||||
|
||||
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
|
||||
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
|
||||
Qf16 *= float16_t(p.scale);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
|
||||
|
||||
@@ -140,10 +140,10 @@ void main() {
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
|
||||
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
|
||||
S = coopMatMulAdd(Qf16, K_T, S);
|
||||
|
||||
if (p.logit_softcap != 0.0f) {
|
||||
@@ -208,31 +208,31 @@ void main() {
|
||||
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
|
||||
rowsum = coopMatMulAdd(P_A, One, rowsum);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
|
||||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
|
||||
|
||||
L = eM*L + rowsum;
|
||||
|
||||
// This is the "diagonal" matrix in the paper, but since we do componentwise
|
||||
// multiply rather than matrix multiply it has the diagonal element smeared
|
||||
// across the row
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
|
||||
|
||||
// resize eM by using smear/reduce
|
||||
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
// multiply with fp16 accumulation, then add to O.
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
|
||||
PV = coopMatMulAdd(P_A, V, PV);
|
||||
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
|
||||
}
|
||||
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
|
||||
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
@@ -243,16 +243,16 @@ void main() {
|
||||
return;
|
||||
}
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
|
||||
|
||||
// resize L by using smear/reduce
|
||||
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
|
||||
coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
|
||||
|
||||
// resize M by using smear/reduce
|
||||
coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
@@ -285,7 +285,7 @@ void main() {
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
|
||||
if (p.gqa_ratio > 1) {
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
} else {
|
||||
@@ -295,6 +295,6 @@ void main() {
|
||||
// permute dimensions
|
||||
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
|
||||
|
||||
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
|
||||
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#endif
|
||||
@@ -103,16 +106,79 @@ layout (constant_id = 10) const uint WARP = 32;
|
||||
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
uint _ne1;
|
||||
#ifdef COOPMAT
|
||||
shared uint _ne1_sh;
|
||||
#endif
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[BN];
|
||||
uint _ne1;
|
||||
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
shared uvec4 ballots_sh[NUM_WARPS];
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif // MUL_MAT_ID_USE_SUBGROUPS
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
@@ -177,51 +243,20 @@ void main() {
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#ifdef COOPMAT
|
||||
// Spread the search across all elements in the first subgroup
|
||||
if (gl_SubgroupID == 0) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
_ne1 = _ne1_sh;
|
||||
#else
|
||||
_ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
row_ids[_ne1] = u16vec2(ii0, ii1);
|
||||
if (_ne1 >= ic * BN) {
|
||||
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
@@ -767,7 +802,7 @@ void main() {
|
||||
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
|
||||
#if LOAD_VEC_B == 8
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||
const u16vec2 row_idx = row_ids[loadc_b + l];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
#else
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
@@ -783,7 +818,7 @@ void main() {
|
||||
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
||||
#elif LOAD_VEC_B == 4
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||
const u16vec2 row_idx = row_ids[loadc_b + l];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
#else
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
@@ -802,7 +837,7 @@ void main() {
|
||||
#else
|
||||
const uint row_i = ic * BN + loadc_b + l;
|
||||
if (row_i < _ne1 && block + loadr_b < end_k) {
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[loadc_b + l];
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
||||
} else {
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
||||
@@ -873,7 +908,7 @@ void main() {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
|
||||
if (dr + cm_row * TM + store_r < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
@@ -923,7 +958,7 @@ void main() {
|
||||
const uint row_i = dc_warp + cc;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
#ifdef MUL_MAT_ID
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#endif
|
||||
|
||||
#include "types.comp"
|
||||
#include "utils.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
@@ -92,14 +93,15 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
|
||||
shared u16vec4 row_ids[4096];
|
||||
shared u16vec4 row_ids[BN];
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
B_TYPE b[];
|
||||
};
|
||||
|
||||
uint _ne1;
|
||||
shared uint _ne1_sh;
|
||||
layout (constant_id = 5) const uint subgroup_size = 32;
|
||||
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
|
||||
|
||||
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
@@ -109,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
|
||||
return B_TYPE(0.0);
|
||||
}
|
||||
|
||||
const u16vec4 row_idx = row_ids[row_i];
|
||||
const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
|
||||
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
|
||||
|
||||
return ret;
|
||||
@@ -121,13 +123,74 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
|
||||
uint dc = ic * BN + c;
|
||||
|
||||
if (dr < p.M && dc < _ne1) {
|
||||
uint row_i = dc;
|
||||
uint row_i = c;
|
||||
const u16vec4 row_idx = row_ids[row_i];
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
@@ -157,45 +220,12 @@ void main() {
|
||||
const uint ic = gl_WorkGroupID.y;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
// Spread the search across all elements in the first subgroup
|
||||
if (gl_SubgroupID == 0) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
_ne1 = _ne1_sh;
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
#endif
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_nonuniform_qualifier : enable
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#if ADD_RMS
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "rte.comp"
|
||||
#include "types.comp"
|
||||
@@ -14,11 +18,18 @@ layout (push_constant) uniform parameter2
|
||||
uint ne20; uint ne21; uint ne22; uint ne23;
|
||||
|
||||
// strides for srcs+dst
|
||||
uint nb[8][4];
|
||||
uint nb[12][4];
|
||||
|
||||
uint rms_partials;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
|
||||
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
|
||||
// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
|
||||
// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
|
||||
// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
|
||||
layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
|
||||
layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
|
||||
|
||||
layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
|
||||
|
||||
layout(constant_id = 0) const uint num_srcs = 2;
|
||||
|
||||
@@ -42,14 +53,22 @@ const uint num_threads = 256;
|
||||
|
||||
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#if ADD_RMS
|
||||
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
|
||||
shared FLOAT_TYPE sumsh[num_threads];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
uint idx = get_idx();
|
||||
uint orig_idx = idx;
|
||||
|
||||
uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
|
||||
|
||||
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
|
||||
const uint num_iter = 2;
|
||||
|
||||
FLOAT_TYPE sum_sq = 0;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
if (idx >= ne) {
|
||||
continue;
|
||||
@@ -61,8 +80,32 @@ void main() {
|
||||
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
|
||||
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
|
||||
}
|
||||
sum_sq += sum*sum;
|
||||
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
||||
#if ADD_RMS
|
||||
if (p.rms_partials != 0) {
|
||||
// reduce the sum within each subgroup, then across subgroups
|
||||
const uint NumSubgroups = num_threads / gl_SubgroupSize;
|
||||
sum_sq = subgroupAdd(sum_sq);
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
|
||||
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
|
||||
sum_sq += sumsh[gl_SubgroupID + s];
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
||||
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
void rms_norm(uint num_iters) {
|
||||
const uint ncols = p.ne00;
|
||||
const uint nrows = gl_NumWorkGroups.x;
|
||||
const uint nchannels = gl_NumWorkGroups.y;
|
||||
@@ -30,38 +30,76 @@ void main() {
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
|
||||
sum[tid] += xi * xi;
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
FLOAT_TYPE xi = FLOAT_TYPE(0);
|
||||
if (col < ncols) {
|
||||
xi = FLOAT_TYPE(data_a[a_offset + col]);
|
||||
}
|
||||
sum += xi * xi;
|
||||
}
|
||||
|
||||
sumsh[tid] = sum;
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
sum[tid] += sum[tid + s];
|
||||
sum += sumsh[tid + s];
|
||||
sumsh[tid] = sum;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
sum = sumsh[0];
|
||||
|
||||
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||
|
||||
if (do_multiply) {
|
||||
if (ncols > p.ne10) {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
// instantiate the rms_norm function for several different
|
||||
// dimensions, to allow loop unrolling
|
||||
uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
if (num_blocks > 32) {
|
||||
rms_norm(num_blocks);
|
||||
} else if (num_blocks > 16) {
|
||||
rms_norm(32);
|
||||
} else if (num_blocks > 8) {
|
||||
rms_norm(16);
|
||||
} else if (num_blocks > 4) {
|
||||
rms_norm(8);
|
||||
} else if (num_blocks == 4) {
|
||||
rms_norm(4);
|
||||
} else if (num_blocks == 3) {
|
||||
rms_norm(3);
|
||||
} else if (num_blocks == 2) {
|
||||
rms_norm(2);
|
||||
} else if (num_blocks == 1) {
|
||||
rms_norm(1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_binary_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
|
||||
#define BLOCK_SIZE 128
|
||||
|
||||
layout (constant_id = 1) const bool do_multiply = false;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};
|
||||
|
||||
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint ncols = p.ne00;
|
||||
const uint nrows = gl_NumWorkGroups.x;
|
||||
const uint nchannels = gl_NumWorkGroups.y;
|
||||
|
||||
const uint row = 0;
|
||||
const uint channel = gl_WorkGroupID.y;
|
||||
const uint samp = gl_WorkGroupID.z;
|
||||
// The work is split across multiple workgroups in the x dimension. Each invocation
|
||||
// processes one element
|
||||
const uint tid = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint stride_row = p.nb01;
|
||||
const uint stride_channel = p.nb02;
|
||||
const uint stride_sample = p.nb03;
|
||||
|
||||
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
uint32_t num_partials = p.param3;
|
||||
for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {
|
||||
sum += partial_sums[i];
|
||||
}
|
||||
sum = subgroupAdd(sum);
|
||||
|
||||
uint col = tid;
|
||||
if (col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||
|
||||
if (do_multiply) {
|
||||
if (ncols > p.ne10) {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
||||
} else {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||
}
|
||||
} else {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
@@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint n_cols;
|
||||
uint ne01, ne02;
|
||||
uint nb01, nb02, nb03;
|
||||
uint nb11, nb12, nb13;
|
||||
float weight;
|
||||
uint misalign_offsets;
|
||||
uint ne0_12mp, ne0_12L;
|
||||
uint ne0_1mp, ne0_1L;
|
||||
} p;
|
||||
|
||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||
|
||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||
uint fastdiv(uint n, uint mp, uint L) {
|
||||
uint msbs, lsbs;
|
||||
// msbs = mulhi(n, mp)
|
||||
umulExtended(n, mp, msbs, lsbs);
|
||||
return (msbs + n) >> L;
|
||||
}
|
||||
|
||||
|
||||
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint col = gl_LocalInvocationID.x;
|
||||
const float weight = p.weight;
|
||||
|
||||
tmp[col] = FLOAT_TYPE(0.0f);
|
||||
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
|
||||
const uint i03_offset = i03 * p.ne01*p.ne02;
|
||||
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
|
||||
const uint i01 = row - i03_offset - i02*p.ne01;
|
||||
|
||||
for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
|
||||
tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
|
||||
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
|
||||
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
|
||||
|
||||
tmp[col] = FLOAT_TYPE(0.0);
|
||||
|
||||
for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
|
||||
tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);
|
||||
}
|
||||
|
||||
barrier();
|
||||
@@ -32,6 +65,6 @@ void main() {
|
||||
}
|
||||
|
||||
if (col == 0) {
|
||||
data_d[row] = D_TYPE(tmp[0]);
|
||||
data_d[dst_idx] = D_TYPE(tmp[0] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,6 +68,12 @@ const std::vector<std::string> type_names = {
|
||||
"bf16",
|
||||
};
|
||||
|
||||
enum MatMulIdType {
|
||||
NONE,
|
||||
DEFAULT,
|
||||
SUBGROUP,
|
||||
};
|
||||
|
||||
namespace {
|
||||
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
#ifdef _WIN32
|
||||
@@ -293,7 +299,7 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
|
||||
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
|
||||
}
|
||||
|
||||
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
|
||||
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
|
||||
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
|
||||
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
||||
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
||||
@@ -303,9 +309,13 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
};
|
||||
std::string shader_name = "matmul";
|
||||
|
||||
if (matmul_id) {
|
||||
if (matmul_id_type == MatMulIdType::DEFAULT) {
|
||||
base_dict["MUL_MAT_ID"] = "1";
|
||||
shader_name = "matmul_id";
|
||||
} else if (matmul_id_type == MatMulIdType::SUBGROUP) {
|
||||
base_dict["MUL_MAT_ID"] = "1";
|
||||
base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
|
||||
shader_name = "matmul_id_subgroup";
|
||||
}
|
||||
|
||||
if (fp16) {
|
||||
@@ -389,7 +399,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
||||
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
#endif
|
||||
@@ -401,26 +411,28 @@ void process_shaders() {
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
|
||||
|
||||
// matmul
|
||||
for (const auto& matmul_id : {false, true}) {
|
||||
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
|
||||
// No coopmats
|
||||
// fp32
|
||||
matmul_shaders(false, matmul_id, false, false, false);
|
||||
matmul_shaders(false, matmul_id_type, false, false, false);
|
||||
|
||||
// fp16, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id, false, false, false);
|
||||
matmul_shaders(true, matmul_id, false, false, true);
|
||||
matmul_shaders(true, matmul_id_type, false, false, false);
|
||||
matmul_shaders(true, matmul_id_type, false, false, true);
|
||||
|
||||
if (matmul_id_type != MatMulIdType::DEFAULT) {
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
// Coopmat, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id, true, false, false);
|
||||
matmul_shaders(true, matmul_id, true, false, true);
|
||||
// Coopmat, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id_type, true, false, false);
|
||||
matmul_shaders(true, matmul_id_type, true, false, true);
|
||||
#endif
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
// Coopmat2, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id, false, true, false);
|
||||
matmul_shaders(true, matmul_id, false, true, true);
|
||||
// Coopmat2, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id_type, false, true, false);
|
||||
matmul_shaders(true, matmul_id_type, false, true, true);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// flash attention
|
||||
@@ -503,6 +515,7 @@ void process_shaders() {
|
||||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -538,13 +551,15 @@ void process_shaders() {
|
||||
s += std::string(dst_f16 ? "_f16" : "_f32");
|
||||
return s;
|
||||
};
|
||||
for (std::string op : {"add", "sub", "mul", "div"}) {
|
||||
for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
|
||||
for (auto src0_f16 : {false, true}) {
|
||||
for (auto src1_f16 : {false, true}) {
|
||||
for (auto dst_f16 : {false, true}) {
|
||||
for (auto rte : {false, true}) {
|
||||
auto source = op == "add_rms" ? std::string("add") : op;
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
|
||||
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
auto add_rms = op == "add_rms" ? "1" : "0";
|
||||
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -680,12 +695,15 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
|
||||
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
@@ -743,7 +761,7 @@ void write_output_files() {
|
||||
}
|
||||
|
||||
std::string suffixes[2] = {"_f32", "_f16"};
|
||||
for (const char *op : {"add", "sub", "mul", "div"}) {
|
||||
for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
|
||||
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
|
||||
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
|
||||
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
|
||||
|
||||
@@ -20,8 +20,8 @@ add_custom_command(
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
|
||||
COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
|
||||
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
|
||||
--input "${SHADER_DIR}"
|
||||
--output "${SHADER_HEADER}"
|
||||
--input_dir "${SHADER_DIR}"
|
||||
--output_file "${SHADER_HEADER}"
|
||||
DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
@@ -118,13 +118,11 @@ struct webgpu_context_struct {
|
||||
|
||||
std::recursive_mutex mutex;
|
||||
|
||||
bool device_init = false;
|
||||
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
webgpu_buf_pool set_rows_error_buf_pool;
|
||||
|
||||
wgpu::ComputePipeline memset_pipeline;
|
||||
wgpu::ComputePipeline mul_mat_pipeline;
|
||||
wgpu::ComputePipeline mul_mat_pipeline[30][2];
|
||||
wgpu::ComputePipeline set_rows_pipeline;
|
||||
wgpu::ComputePipeline cpy_pipeline;
|
||||
|
||||
@@ -238,7 +236,7 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
|
||||
}
|
||||
}),
|
||||
UINT64_MAX);
|
||||
@@ -278,7 +276,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
|
||||
}
|
||||
// Free the staged buffers
|
||||
ctx->param_buf_pool.free_bufs(staged_param_bufs);
|
||||
@@ -294,7 +292,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
|
||||
} else {
|
||||
const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
|
||||
if (*error_data) {
|
||||
@@ -331,6 +329,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
|
||||
// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
|
||||
// debug statements in the shader, and then call this function after encoding the commands and submitting them.
|
||||
static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
|
||||
ggml_backend_webgpu_submit_queue(ctx);
|
||||
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
@@ -421,15 +420,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor * tensor) {
|
||||
return webgpu_tensor_offset(tensor) + tensor->view_offs;
|
||||
}
|
||||
|
||||
static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor * tensor) {
|
||||
ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
|
||||
return ctx->buffer;
|
||||
}
|
||||
|
||||
/** End WebGPU Actions */
|
||||
|
||||
/** GGML Backend Interface */
|
||||
@@ -447,19 +437,36 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
|
||||
return webgpu_tensor_offset(tensor) + tensor->view_offs;
|
||||
}
|
||||
|
||||
static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
|
||||
ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
|
||||
return ctx->buffer;
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
|
||||
size_t offset = ggml_webgpu_tensor_offset(t);
|
||||
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
|
||||
size_t offset = ggml_webgpu_tensor_offset(t);
|
||||
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
|
||||
return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
|
||||
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
|
||||
// assumes power of 2 offset alignment
|
||||
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
// align to minimum offset alignment
|
||||
src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
|
||||
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
std::vector<uint32_t> params = { ne,
|
||||
(uint32_t) (src_misalignment / ggml_type_size(src->type)),
|
||||
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[0] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
@@ -477,15 +484,13 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(src),
|
||||
.offset = src_offset,
|
||||
.size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
|
||||
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(dst),
|
||||
.offset = dst_offset,
|
||||
.size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
|
||||
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
||||
@@ -504,21 +509,9 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
||||
error_bufs.host_buf.Unmap();
|
||||
}
|
||||
|
||||
size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
|
||||
// assumes power of 2 offset alignment
|
||||
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
// align to minimum offset alignment
|
||||
src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx);
|
||||
size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
|
||||
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
|
||||
std::vector<uint32_t> params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)),
|
||||
(uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
|
||||
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
|
||||
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
||||
@@ -540,18 +533,18 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(src),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(src),
|
||||
.size = ggml_nbytes(src) },
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(idx),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(idx),
|
||||
.size = ggml_nbytes(idx) },
|
||||
.buffer = ggml_webgpu_tensor_buf(idx),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, idx),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, idx) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(dst),
|
||||
.size = ggml_nbytes(dst) },
|
||||
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
||||
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
||||
@@ -565,15 +558,18 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
||||
|
||||
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) dst->ne[1], // number of rows in result (M)
|
||||
(uint32_t) dst->ne[0], // number of columns in result (N)
|
||||
(uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 1
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 1
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 2
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 2
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 3
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 3
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
|
||||
(uint32_t) src0->ne[2], // batch size in dimension 2
|
||||
(uint32_t) src0->ne[3], // batch size in dimension 3
|
||||
(uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
|
||||
@@ -582,22 +578,22 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(src0),
|
||||
.size = ggml_nbytes(src0) },
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(src1),
|
||||
.size = ggml_nbytes(src1) },
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(dst),
|
||||
.size = ggml_nbytes(dst) }
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
||||
};
|
||||
|
||||
uint32_t wg_x =
|
||||
(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x);
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
|
||||
}
|
||||
|
||||
// Returns true if node has enqueued work into the queue, false otherwise
|
||||
@@ -827,7 +823,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
|
||||
wgpu::Buffer buf;
|
||||
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
|
||||
buf,
|
||||
size,
|
||||
(size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
|
||||
"allocated_buffer");
|
||||
|
||||
@@ -907,7 +903,94 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f32_f32,
|
||||
"mul_mat_f32_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
|
||||
wgsl_mul_mat_f16_f16,
|
||||
"mul_mat_f16_f16");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f16_f32,
|
||||
"mul_mat_f16_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_0_f32,
|
||||
"mul_mat_q4_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_1_f32,
|
||||
"mul_mat_q4_1_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_0_f32,
|
||||
"mul_mat_q5_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_1_f32,
|
||||
"mul_mat_q5_1_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q8_0_f32,
|
||||
"mul_mat_q8_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q2_k_f32,
|
||||
"mul_mat_q2_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q3_k_f32,
|
||||
"mul_mat_q3_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_k_f32,
|
||||
"mul_mat_q4_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_k_f32,
|
||||
"mul_mat_q5_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q6_k_f32,
|
||||
"mul_mat_q6_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_xxs_f32,
|
||||
"mul_mat_iq2_xxs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_xs_f32,
|
||||
"mul_mat_iq2_xs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_s_f32,
|
||||
"mul_mat_iq2_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq3_xxs_f32,
|
||||
"mul_mat_iq3_xxs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq3_s_f32,
|
||||
"mul_mat_iq3_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq1_s_f32,
|
||||
"mul_mat_iq1_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq1_m_f32,
|
||||
"mul_mat_iq1_m_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_nl_f32,
|
||||
"mul_mat_iq4_nl_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_xs_f32,
|
||||
"mul_mat_iq4_xs_f32");
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||
@@ -933,79 +1016,6 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
|
||||
ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
||||
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
|
||||
|
||||
// Multiple threads may try to initialize the device
|
||||
std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
|
||||
if (!webgpu_ctx->device_init) {
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
|
||||
wgpu::FeatureName::ImplicitDeviceSynchronization };
|
||||
wgpu::DeviceDescriptor dev_desc;
|
||||
dev_desc.requiredLimits = &webgpu_ctx->limits;
|
||||
dev_desc.requiredFeatures = required_features.data();
|
||||
dev_desc.requiredFeatureCount = required_features.size();
|
||||
dev_desc.SetDeviceLostCallback(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
|
||||
});
|
||||
dev_desc.SetUncapturedErrorCallback(
|
||||
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
|
||||
});
|
||||
webgpu_ctx->instance.WaitAny(
|
||||
webgpu_ctx->adapter.RequestDevice(
|
||||
&dev_desc,
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
|
||||
return;
|
||||
}
|
||||
webgpu_ctx->device = std::move(device);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
GGML_ASSERT(webgpu_ctx->device != nullptr);
|
||||
|
||||
// Initialize (compute) queue
|
||||
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
|
||||
|
||||
// Create buffer pool for shader parameters
|
||||
webgpu_ctx->param_buf_pool.init(webgpu_ctx->device,
|
||||
WEBGPU_NUM_PARAM_BUFS,
|
||||
WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
||||
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device,
|
||||
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
||||
|
||||
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_set_rows_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
// Initialize debug buffers
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device,
|
||||
webgpu_ctx->debug_host_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
|
||||
"debug_host_buf");
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device,
|
||||
webgpu_ctx->debug_dev_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
|
||||
"debug_dev_buf");
|
||||
#endif
|
||||
webgpu_ctx->device_init = true;
|
||||
}
|
||||
|
||||
static ggml_backend_webgpu_context backend_ctx;
|
||||
backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
|
||||
backend_ctx.webgpu_ctx = webgpu_ctx;
|
||||
@@ -1053,10 +1063,45 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
return true;
|
||||
case GGML_OP_CPY | GGML_OP_SET_ROWS:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_MUL_MAT:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
{
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_TYPE_F32:
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -1123,20 +1168,87 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
wgpu::AdapterInfo info{};
|
||||
ctx->adapter.GetInfo(&info);
|
||||
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
|
||||
wgpu::FeatureName::ImplicitDeviceSynchronization };
|
||||
wgpu::DeviceDescriptor dev_desc;
|
||||
dev_desc.requiredLimits = &ctx->limits;
|
||||
dev_desc.requiredFeatures = required_features.data();
|
||||
dev_desc.requiredFeatureCount = required_features.size();
|
||||
dev_desc.SetDeviceLostCallback(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
|
||||
});
|
||||
dev_desc.SetUncapturedErrorCallback(
|
||||
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
|
||||
});
|
||||
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
||||
&dev_desc,
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
||||
return;
|
||||
}
|
||||
ctx->device = std::move(device);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
GGML_ASSERT(ctx->device != nullptr);
|
||||
|
||||
// Initialize (compute) queue
|
||||
ctx->queue = ctx->device.GetQueue();
|
||||
|
||||
// Create buffer pool for shader parameters
|
||||
ctx->param_buf_pool.init(ctx->device,
|
||||
WEBGPU_NUM_PARAM_BUFS,
|
||||
WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
||||
ctx->set_rows_error_buf_pool.init(ctx->device,
|
||||
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
||||
|
||||
ggml_webgpu_init_memset_pipeline(ctx);
|
||||
ggml_webgpu_init_mul_mat_pipeline(ctx);
|
||||
ggml_webgpu_init_set_rows_pipeline(ctx);
|
||||
ggml_webgpu_init_cpy_pipeline(ctx);
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
// Initialize debug buffers
|
||||
ggml_webgpu_create_buffer(ctx->device,
|
||||
ctx->debug_host_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
|
||||
"debug_host_buf");
|
||||
ggml_webgpu_create_buffer(ctx->device,
|
||||
ctx->debug_dev_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
|
||||
"debug_dev_buf");
|
||||
#endif
|
||||
|
||||
static ggml_backend_webgpu_device_context device_ctx;
|
||||
device_ctx.webgpu_ctx = ctx;
|
||||
device_ctx.device_name = GGML_WEBGPU_NAME;
|
||||
device_ctx.device_desc = std::string(info.description.data);
|
||||
device_ctx.device_desc = info.description;
|
||||
|
||||
GGML_LOG_INFO(
|
||||
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
|
||||
"device_desc: %s\n",
|
||||
info.vendorID,
|
||||
info.vendor.data,
|
||||
info.architecture.data,
|
||||
std::string(info.vendor).c_str(),
|
||||
std::string(info.architecture).c_str(),
|
||||
info.deviceID,
|
||||
info.device.data,
|
||||
info.description.data);
|
||||
std::string(info.device).c_str(),
|
||||
std::string(info.description).c_str());
|
||||
|
||||
// See GGML Backend Device Interface section
|
||||
static ggml_backend_device device = {
|
||||
|
||||
@@ -1,35 +1,85 @@
|
||||
import os
|
||||
import re
|
||||
import ast
|
||||
import argparse
|
||||
|
||||
|
||||
def escape_triple_quotes(wgsl):
|
||||
# Simple defense in case of embedded """
|
||||
return wgsl.replace('"""', '\\"""')
|
||||
def extract_block(text, name):
|
||||
pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError(f"Missing block: {name}")
|
||||
return match.group(1).strip()
|
||||
|
||||
|
||||
def to_cpp_string_literal(varname, content):
|
||||
return f'const char* wgsl_{varname} = R"({content})";\n'
|
||||
def parse_decls(decls_text):
|
||||
decls = {}
|
||||
for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
|
||||
decls[name.strip()] = code.strip()
|
||||
return decls
|
||||
|
||||
|
||||
def replace_placeholders(shader_text, replacements):
|
||||
for key, val in replacements.items():
|
||||
# Match {{KEY}} literally, where KEY is escaped
|
||||
pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
|
||||
shader_text = re.sub(pattern, str(val), shader_text)
|
||||
return shader_text
|
||||
|
||||
|
||||
def write_shader(shader_name, shader_code, output_dir, outfile):
|
||||
if output_dir:
|
||||
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
|
||||
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
|
||||
f_out.write(shader_code)
|
||||
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
|
||||
|
||||
|
||||
def generate_variants(shader_path, output_dir, outfile):
|
||||
shader_base_name = shader_path.split("/")[-1].split(".")[0]
|
||||
|
||||
with open(shader_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
try:
|
||||
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
|
||||
except ValueError:
|
||||
write_shader(shader_base_name, text, output_dir, outfile)
|
||||
else:
|
||||
decls_map = parse_decls(extract_block(text, "DECLS"))
|
||||
shader_template = extract_block(text, "SHADER")
|
||||
|
||||
for variant in variants:
|
||||
decls = variant["DECLS"]
|
||||
decls_code = ""
|
||||
for key in decls:
|
||||
if key not in decls_map:
|
||||
raise ValueError(f"DECLS key '{key}' not found.")
|
||||
decls_code += decls_map[key] + "\n\n"
|
||||
|
||||
shader_variant = replace_placeholders(shader_template, variant["REPLS"])
|
||||
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
|
||||
|
||||
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
|
||||
write_shader(output_name, final_shader, output_dir, outfile)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', required=True)
|
||||
parser.add_argument('--output', required=True)
|
||||
parser.add_argument("--input_dir", required=True)
|
||||
parser.add_argument("--output_file", required=True)
|
||||
parser.add_argument("--output_dir")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.output, 'w', encoding='utf-8') as out:
|
||||
out.write("// Auto-generated shader embedding \n\n")
|
||||
for fname in sorted(os.listdir(args.input)):
|
||||
if not fname.endswith('.wgsl'):
|
||||
continue
|
||||
shader_path = os.path.join(args.input, fname)
|
||||
varname = os.path.splitext(fname)[0]
|
||||
with open(shader_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
content = escape_triple_quotes(content)
|
||||
out.write(to_cpp_string_literal(varname, content))
|
||||
out.write('\n')
|
||||
if args.output_dir:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
with open(args.output_file, "w", encoding="utf-8") as out:
|
||||
out.write("// Auto-generated shader embedding\n\n")
|
||||
for fname in sorted(os.listdir(args.input_dir)):
|
||||
if fname.endswith(".wgsl"):
|
||||
generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -19,20 +19,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let start = params.offset;
|
||||
let end = params.offset + params.size;
|
||||
|
||||
for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) {
|
||||
for (var j: u32 = 0u; j < bytes_per_thread; j += 4) {
|
||||
let byte_index = start + i + j;
|
||||
if (byte_index + 4u <= end) {
|
||||
output_buffer[(byte_index >> 2u)] = params.value;
|
||||
if (byte_index + 4 <= end) {
|
||||
output_buffer[byte_index >> 2] = params.value;
|
||||
} else {
|
||||
// Handle tail (unaligned)
|
||||
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let idx = byte_index + k;
|
||||
if (idx < end) {
|
||||
let word_idx = idx >> 2u;
|
||||
let byte_offset = (idx & 3u) * 8u;
|
||||
let mask = ~(0xffu << byte_offset);
|
||||
let word_idx = idx >> 2;
|
||||
let bit_offset = (idx & 3) * 8u;
|
||||
let mask = ~(0xffu << bit_offset);
|
||||
let existing = output_buffer[word_idx];
|
||||
output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset);
|
||||
output_buffer[word_idx] = (existing & mask) | (params.value & (0xffu << bit_offset));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,56 +0,0 @@
|
||||
struct MulMatParams {
|
||||
m: u32,
|
||||
n: u32,
|
||||
k: u32,
|
||||
// all strides are in elements
|
||||
stride_01: u32,
|
||||
stride_11: u32,
|
||||
stride_02: u32,
|
||||
stride_12: u32,
|
||||
stride_03: u32,
|
||||
stride_13: u32,
|
||||
|
||||
bs02: u32,
|
||||
bs03: u32,
|
||||
broadcast2: u32,
|
||||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // N rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
@compute @workgroup_size(64)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
|
||||
if (global_id.x >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
let dst2_stride = params.m * params.n;
|
||||
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
|
||||
|
||||
let dst3_idx = global_id.x / dst3_stride;
|
||||
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
|
||||
let src13_idx = dst3_idx; // src1 is not broadcast
|
||||
let dst3_rem = global_id.x % dst3_stride;
|
||||
|
||||
let dst2_idx = dst3_rem / dst2_stride;
|
||||
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
|
||||
let src12_idx = dst2_idx; // src1 is not broadcast
|
||||
|
||||
let dst2_rem = dst3_rem % dst2_stride;
|
||||
|
||||
let row = dst2_rem / params.n; // output row
|
||||
let col = dst2_rem % params.n; // output column
|
||||
|
||||
var sum = 0.0;
|
||||
for (var i: u32 = 0u; i < params.k; i = i + 1u) {
|
||||
let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i;
|
||||
let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i;
|
||||
sum = sum + src0[src0_idx] * src1[src1_idx];
|
||||
}
|
||||
dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
|
||||
}
|
||||
+54
-2
@@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"IM2COL",
|
||||
"IM2COL_BACK",
|
||||
"CONV_2D",
|
||||
"CONV_3D",
|
||||
"CONV_2D_DW",
|
||||
"CONV_TRANSPOSE_2D",
|
||||
"POOL_1D",
|
||||
@@ -1017,7 +1018,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1077,6 +1078,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"im2col(x)",
|
||||
"im2col_back(x)",
|
||||
"conv_2d(x)",
|
||||
"conv_3d(x)",
|
||||
"conv_2d_dw(x)",
|
||||
"conv_transpose_2d(x)",
|
||||
"pool_1d(x)",
|
||||
@@ -1119,7 +1121,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"glu(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -4480,6 +4482,56 @@ struct ggml_tensor * ggml_conv_2d_direct(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_conv_3d
|
||||
|
||||
struct ggml_tensor * ggml_conv_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int s0,
|
||||
int s1,
|
||||
int s2,
|
||||
int p0,
|
||||
int p1,
|
||||
int p2,
|
||||
int d0,
|
||||
int d1,
|
||||
int d2,
|
||||
int c,
|
||||
int n,
|
||||
int oc) {
|
||||
|
||||
GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
|
||||
GGML_ASSERT(b->ne[3] == (int64_t) c * n);
|
||||
|
||||
int64_t ne[4];
|
||||
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
|
||||
ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
|
||||
ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
|
||||
ne[3] = (int64_t) oc * n;
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, s0);
|
||||
ggml_set_op_params_i32(result, 1, s1);
|
||||
ggml_set_op_params_i32(result, 2, s2);
|
||||
ggml_set_op_params_i32(result, 3, p0);
|
||||
ggml_set_op_params_i32(result, 4, p1);
|
||||
ggml_set_op_params_i32(result, 5, p2);
|
||||
ggml_set_op_params_i32(result, 6, d0);
|
||||
ggml_set_op_params_i32(result, 7, d1);
|
||||
ggml_set_op_params_i32(result, 8, d2);
|
||||
ggml_set_op_params_i32(result, 9, c);
|
||||
ggml_set_op_params_i32(result, 10, n);
|
||||
ggml_set_op_params_i32(result, 11, oc);
|
||||
|
||||
result->op = GGML_OP_CONV_3D;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_conv_transpose_2d_p0
|
||||
|
||||
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
|
||||
|
||||
@@ -385,6 +385,7 @@ class MODEL_ARCH(IntEnum):
|
||||
DREAM = auto()
|
||||
SMALLTHINKER = auto()
|
||||
LLADA = auto()
|
||||
SEED_OSS = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
@@ -717,6 +718,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.DREAM: "dream",
|
||||
MODEL_ARCH.SMALLTHINKER: "smallthinker",
|
||||
MODEL_ARCH.LLADA: "llada",
|
||||
MODEL_ARCH.SEED_OSS: "seed_oss",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
@@ -1973,6 +1975,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.SEED_OSS: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
],
|
||||
MODEL_ARCH.OLMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
@@ -2590,6 +2606,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
],
|
||||
MODEL_ARCH.SMALLTHINKER: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
||||
@@ -427,7 +427,6 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # smallthinker
|
||||
"model.transformer.blocks.{bid}.ff_proj", # llada
|
||||
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
|
||||
),
|
||||
|
||||
+1
-106
@@ -312,7 +312,7 @@ extern "C" {
|
||||
float yarn_beta_fast; // YaRN low correction dim
|
||||
float yarn_beta_slow; // YaRN high correction dim
|
||||
uint32_t yarn_orig_ctx; // YaRN original context size
|
||||
float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
|
||||
float defrag_thold; // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default)
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
@@ -663,111 +663,6 @@ extern "C" {
|
||||
// Check if the memory supports shifting
|
||||
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
|
||||
|
||||
//
|
||||
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
|
||||
//
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||
DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
|
||||
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
|
||||
|
||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||
DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
|
||||
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
|
||||
|
||||
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_clear(
|
||||
struct llama_context * ctx),
|
||||
"Use llama_memory_clear() instead");
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"Use llama_memory_seq_rm() instead");
|
||||
|
||||
// Copy all tokens that belong to the specified sequence to another sequence
|
||||
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"Use llama_memory_seq_cp() instead");
|
||||
|
||||
// Removes all tokens that do not belong to the specified sequence
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_keep() instead");
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta),
|
||||
"Use llama_memory_seq_add() instead");
|
||||
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d),
|
||||
"Use llama_memory_seq_div() instead");
|
||||
|
||||
// Returns the smallest position present in the KV cache for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_pos_min() instead");
|
||||
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_pos_max() instead");
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
|
||||
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
|
||||
|
||||
// Check if the context supports KV cache shifting
|
||||
DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
|
||||
"use llama_memory_can_shift() instead");
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
|
||||
"simply remove this call, updates are applied lazily on the next llama_decode()");
|
||||
|
||||
//
|
||||
// State / sessions
|
||||
//
|
||||
|
||||
@@ -28,7 +28,6 @@ LLAMA_BENCH_DB_FIELDS = [
|
||||
"model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
|
||||
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
|
||||
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
|
||||
"defrag_thold",
|
||||
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
|
||||
"test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
|
||||
]
|
||||
@@ -38,7 +37,6 @@ LLAMA_BENCH_DB_TYPES = [
|
||||
"TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT",
|
||||
"REAL",
|
||||
"INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "REAL", "REAL",
|
||||
]
|
||||
|
||||
@@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_DREAM, "dream" },
|
||||
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
|
||||
{ LLM_ARCH_LLADA, "llada" },
|
||||
{ LLM_ARCH_SEED_OSS, "seed_oss" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -2010,6 +2011,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -2067,6 +2069,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_SEED_OSS,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
||||
@@ -97,6 +97,7 @@ enum llm_arch {
|
||||
LLM_ARCH_DREAM,
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
LLM_ARCH_LLADA,
|
||||
LLM_ARCH_SEED_OSS,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
+13
-2
@@ -16,10 +16,10 @@
|
||||
static std::string trim(const std::string & str) {
|
||||
size_t start = 0;
|
||||
size_t end = str.size();
|
||||
while (start < end && isspace(str[start])) {
|
||||
while (start < end && isspace(static_cast<unsigned char>(str[start]))) {
|
||||
start += 1;
|
||||
}
|
||||
while (end > start && isspace(str[end - 1])) {
|
||||
while (end > start && isspace(static_cast<unsigned char>(str[end - 1]))) {
|
||||
end -= 1;
|
||||
}
|
||||
return str.substr(start, end - start);
|
||||
@@ -69,6 +69,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
|
||||
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
|
||||
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
|
||||
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
@@ -201,6 +202,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
|
||||
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
|
||||
return LLM_CHAT_TEMPLATE_KIMI_K2;
|
||||
} else if (tmpl_contains("<seed:bos>")) {
|
||||
return LLM_CHAT_TEMPLATE_SEED_OSS;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
@@ -752,6 +755,14 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|im_assistant|>assistant<|im_middle|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) {
|
||||
for (auto message: chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<seed:bos>" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << "<seed:eos>";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<seed:bos>assistant\n";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
||||
@@ -49,6 +49,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_OPENAI_MOE,
|
||||
LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
|
||||
LLM_CHAT_TEMPLATE_KIMI_K2,
|
||||
LLM_CHAT_TEMPLATE_SEED_OSS,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
+17
-197
@@ -39,7 +39,6 @@ llama_context::llama_context(
|
||||
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
@@ -93,7 +92,7 @@ llama_context::llama_context(
|
||||
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
||||
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
||||
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
|
||||
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
|
||||
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
||||
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
||||
cparams.n_batch = GGML_KQ_MASK_PAD;
|
||||
@@ -281,7 +280,7 @@ llama_context::llama_context(
|
||||
}
|
||||
|
||||
// reserve worst-case graph
|
||||
if (!hparams.vocab_only && memory) {
|
||||
if (!hparams.vocab_only) {
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
@@ -293,11 +292,13 @@ llama_context::llama_context(
|
||||
int n_splits_tg = -1;
|
||||
int n_nodes_tg = -1;
|
||||
|
||||
// simulate full KV cache
|
||||
|
||||
const auto mctx = memory->init_full();
|
||||
if (!mctx) {
|
||||
throw std::runtime_error("failed to initialize KV cache");
|
||||
llama_memory_context_ptr mctx;
|
||||
if (memory) {
|
||||
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
|
||||
mctx = memory->init_full();
|
||||
if (!mctx) {
|
||||
throw std::runtime_error("failed to initialize memory module");
|
||||
}
|
||||
}
|
||||
|
||||
cross.v_embd.clear();
|
||||
@@ -439,26 +440,12 @@ llama_memory_t llama_context::get_memory() const {
|
||||
return memory.get();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_context::kv_self_defrag_sched() {
|
||||
if (!memory) {
|
||||
return;
|
||||
}
|
||||
|
||||
memory_force_optimize = true;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_context::kv_self_update(bool optimize) {
|
||||
bool llama_context::memory_update(bool optimize) {
|
||||
if (!memory) {
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
// TODO: remove in the future
|
||||
optimize |= memory_force_optimize;
|
||||
memory_force_optimize = false;
|
||||
|
||||
const auto mctx = memory->init_update(this, optimize);
|
||||
switch (mctx->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
@@ -992,8 +979,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
|
||||
bool did_optimize = false;
|
||||
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update(false);
|
||||
// handle any pending shifts/copies
|
||||
memory_update(false);
|
||||
|
||||
llama_memory_context_ptr mctx;
|
||||
|
||||
@@ -1018,7 +1005,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
if (!did_optimize) {
|
||||
did_optimize = true;
|
||||
|
||||
if (kv_self_update(true)) {
|
||||
if (memory_update(true)) {
|
||||
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
|
||||
|
||||
continue;
|
||||
@@ -1071,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
||||
|
||||
if (!res) {
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
|
||||
llama_pos pos_min[LLAMA_MAX_SEQ];
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
||||
@@ -1088,7 +1075,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
|
||||
LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
|
||||
|
||||
memory->seq_rm(s, pos_min[s], -1);
|
||||
}
|
||||
@@ -1872,7 +1859,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||
}
|
||||
|
||||
if (memory != nullptr) {
|
||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
|
||||
memory->state_write(io);
|
||||
}
|
||||
|
||||
@@ -1958,7 +1945,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||
}
|
||||
|
||||
if (memory) {
|
||||
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
||||
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
|
||||
|
||||
memory->state_read(io);
|
||||
}
|
||||
@@ -2338,11 +2325,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
|
||||
return &ctx->get_model();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_update(llama_context * ctx) {
|
||||
ctx->kv_self_update(false);
|
||||
}
|
||||
|
||||
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
|
||||
return ctx->pooling_type();
|
||||
}
|
||||
@@ -2560,168 +2542,6 @@ bool llama_memory_can_shift(llama_memory_t mem) {
|
||||
return mem->get_can_shift();
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache
|
||||
//
|
||||
|
||||
// deprecated
|
||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
const auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t res = 0;
|
||||
|
||||
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
|
||||
const llama_pos p0 = kv->seq_pos_min(s);
|
||||
const llama_pos p1 = kv->seq_pos_max(s);
|
||||
|
||||
if (p0 >= 0) {
|
||||
res += (p1 - p0) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
// note: this is the same as above - will be removed anyway, so it's ok
|
||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
const auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t res = 0;
|
||||
|
||||
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
|
||||
const llama_pos p0 = kv->seq_pos_min(s);
|
||||
const llama_pos p1 = kv->seq_pos_max(s);
|
||||
|
||||
if (p0 >= 0) {
|
||||
res += (p1 - p0) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_clear(llama_context * ctx) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_clear(kv, true);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_self_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return llama_memory_seq_rm(kv, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_seq_keep(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_seq_add(kv, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_seq_div(kv, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return llama_memory_seq_pos_min(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return llama_memory_seq_pos_max(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_defrag(llama_context * ctx) {
|
||||
// force defrag
|
||||
ctx->kv_self_defrag_sched();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return llama_memory_can_shift(kv);
|
||||
}
|
||||
|
||||
// llama state API
|
||||
|
||||
// deprecated
|
||||
|
||||
+2
-7
@@ -46,10 +46,8 @@ struct llama_context {
|
||||
|
||||
llama_memory_t get_memory() const;
|
||||
|
||||
// return true of the KV cache was updated
|
||||
// TODO: remove
|
||||
bool kv_self_update(bool optimize);
|
||||
void kv_self_defrag_sched();
|
||||
// return true if the memory was updated
|
||||
bool memory_update(bool optimize);
|
||||
|
||||
enum llama_pooling_type pooling_type() const;
|
||||
|
||||
@@ -230,9 +228,6 @@ private:
|
||||
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
|
||||
// TODO: temporary, until the llama_kv_self_defrag() API is removed
|
||||
bool memory_force_optimize = false;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
float * logits = nullptr;
|
||||
|
||||
@@ -24,7 +24,6 @@ struct llama_cparams {
|
||||
float yarn_attn_factor;
|
||||
float yarn_beta_fast;
|
||||
float yarn_beta_slow;
|
||||
float defrag_thold;
|
||||
|
||||
bool embeddings;
|
||||
bool causal_attn;
|
||||
|
||||
+9
-31
@@ -1223,8 +1223,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * v_mla,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale) const {
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
@@ -1360,6 +1360,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
@@ -1381,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
@@ -1443,6 +1444,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
@@ -1469,7 +1471,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
@@ -1495,33 +1497,8 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
return build_attn_with_sinks(
|
||||
inp,
|
||||
wo,
|
||||
wo_b,
|
||||
q_cur,
|
||||
k_cur,
|
||||
v_cur,
|
||||
kq_b,
|
||||
v_mla,
|
||||
nullptr,
|
||||
kq_scale,
|
||||
il);
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn_with_sinks(
|
||||
llm_graph_input_attn_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
@@ -1561,7 +1538,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks(
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
@@ -1600,6 +1577,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
@@ -1615,7 +1593,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
|
||||
+12
-22
@@ -680,14 +680,14 @@ struct llm_graph_context {
|
||||
//
|
||||
|
||||
ggml_tensor * build_attn_mha(
|
||||
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale) const;
|
||||
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale) const;
|
||||
|
||||
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
|
||||
|
||||
@@ -699,6 +699,7 @@ struct llm_graph_context {
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
@@ -713,6 +714,7 @@ struct llm_graph_context {
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
@@ -728,21 +730,8 @@ struct llm_graph_context {
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
// TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
|
||||
ggml_tensor * build_attn_with_sinks(
|
||||
llm_graph_input_attn_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
@@ -756,6 +745,7 @@ struct llm_graph_context {
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
bool llama_hparams::has_kv(uint32_t il) const {
|
||||
if (n_layer_kv_from_start >= 0) {
|
||||
if (il < (uint32_t) n_layer_kv_from_start) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// by default, all layers have kv
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_layer_kv() const {
|
||||
uint32_t res = 0;
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (has_kv(il)) {
|
||||
res++;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ struct llama_hparams {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_embd_features = 0;
|
||||
uint32_t n_layer;
|
||||
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
|
||||
uint32_t n_rot;
|
||||
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
||||
@@ -221,6 +222,11 @@ struct llama_hparams {
|
||||
uint32_t n_pos_per_embd() const;
|
||||
|
||||
bool is_swa(uint32_t il) const;
|
||||
|
||||
bool has_kv(uint32_t il) const;
|
||||
|
||||
// number of layers for which has_kv() returns true
|
||||
uint32_t n_layer_kv() const;
|
||||
};
|
||||
|
||||
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
||||
|
||||
@@ -22,9 +22,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
|
||||
llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
uint32_t n_pad,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
|
||||
|
||||
// chain filters
|
||||
const layer_filter_cb filter_base = [&](int32_t il) {
|
||||
if (filter && !filter(il)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !model.hparams.is_swa(il);
|
||||
};
|
||||
|
||||
const layer_filter_cb filter_swa = [&](int32_t il) {
|
||||
if (filter && !filter(il)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return model.hparams.is_swa(il);
|
||||
};
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
@@ -41,16 +58,16 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
model, type_k, type_v,
|
||||
v_trans, offload, unified, size_base, n_seq_max, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
model, type_k, type_v,
|
||||
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
|
||||
}
|
||||
|
||||
void llama_kv_cache_iswa::clear(bool data) {
|
||||
|
||||
@@ -20,11 +20,13 @@ public:
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool ,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad);
|
||||
uint32_t n_pad,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse);
|
||||
|
||||
~llama_kv_cache_iswa() = default;
|
||||
|
||||
|
||||
+42
-393
@@ -17,32 +17,25 @@
|
||||
//
|
||||
|
||||
llama_kv_cache::llama_kv_cache(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type) :
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse) :
|
||||
model(model), hparams(model.hparams), v_trans(v_trans),
|
||||
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
||||
|
||||
GGML_ASSERT(kv_size % n_pad == 0);
|
||||
|
||||
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
|
||||
auto n_layer_cache = hparams.n_layer;
|
||||
if (model.arch == LLM_ARCH_GEMMA3N) {
|
||||
n_layer_cache = 20;
|
||||
}
|
||||
if (model.arch == LLM_ARCH_GLM4_MOE) {
|
||||
// GLM-4.5: Only process up to last layer, skip final NextN layer
|
||||
n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
}
|
||||
const uint32_t n_layer_kv = hparams.n_layer_kv();
|
||||
|
||||
// create a context for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
@@ -50,7 +43,7 @@ llama_kv_cache::llama_kv_cache(
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
|
||||
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
|
||||
__func__, hparams.n_embd_v_gqa_max());
|
||||
}
|
||||
|
||||
for (uint32_t il = 0; il < n_layer_cache; il++) {
|
||||
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
||||
if (!hparams.has_kv(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (filter && !filter(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
|
||||
layers.push_back({ il, k, v, k_stream, v_stream, });
|
||||
}
|
||||
|
||||
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
|
||||
if (model.arch == LLM_ARCH_GEMMA3N) {
|
||||
LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
|
||||
if (reuse) {
|
||||
LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
|
||||
|
||||
for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
|
||||
if (filter && !filter(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
|
||||
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
||||
const int32_t il_reuse = reuse(il);
|
||||
|
||||
if (il_reuse < 0) {
|
||||
LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
|
||||
if (filter && !filter(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
|
||||
continue;
|
||||
}
|
||||
|
||||
GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
|
||||
|
||||
map_layer_ids[il] = map_layer_ids[il_reuse];
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
|
||||
LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -525,39 +527,11 @@ llama_memory_context_ptr llama_kv_cache::init_full() {
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
|
||||
GGML_UNUSED(optimize);
|
||||
|
||||
bool do_shift = get_has_shift();
|
||||
|
||||
defrag_info dinfo;
|
||||
|
||||
// see if we need to defrag
|
||||
if (n_stream == 1) {
|
||||
// note : for now do not consider defrag for n_stream > 1
|
||||
const auto & cells = v_cells[seq_to_stream[0]];
|
||||
|
||||
bool do_defrag = optimize;
|
||||
|
||||
const auto thold = lctx->get_cparams().defrag_thold;
|
||||
|
||||
if (!do_defrag && thold > 0.0f) {
|
||||
const auto n_kv = cells.used_max_p1();
|
||||
|
||||
// - do not defrag small contexts (i.e. < 2048 tokens)
|
||||
// - count the padding towards the number of used tokens
|
||||
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
|
||||
|
||||
if (fragmentation > thold) {
|
||||
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
||||
|
||||
do_defrag = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (do_defrag) {
|
||||
dinfo = defrag_prepare(lctx->graph_max_nodes());
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
|
||||
return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
|
||||
}
|
||||
|
||||
llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
@@ -629,7 +603,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
|
||||
return res;
|
||||
}
|
||||
|
||||
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
|
||||
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
|
||||
bool updated = false;
|
||||
|
||||
auto * sched = lctx->get_sched();
|
||||
@@ -699,53 +673,6 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_in
|
||||
}
|
||||
}
|
||||
|
||||
if (!dinfo.empty()) {
|
||||
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
||||
|
||||
// note: for now do not consider defrag for n_stream > 1
|
||||
auto & cells = v_cells[seq_to_stream[0]];
|
||||
auto & head = v_heads[seq_to_stream[0]];
|
||||
|
||||
// apply moves:
|
||||
{
|
||||
const auto n_kv = dinfo.ids.size();
|
||||
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
assert(dinfo.ids[i] <= n_kv);
|
||||
|
||||
if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cells.mv(i, dinfo.ids[i]);
|
||||
}
|
||||
|
||||
// reset the head so we can find the first free slot during the next ubatch
|
||||
head = 0;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched);
|
||||
|
||||
auto * res = lctx->get_gf_res_reserve();
|
||||
|
||||
res->reset();
|
||||
|
||||
auto * gf = build_graph_defrag(res, lctx, dinfo);
|
||||
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
|
||||
res->set_inputs(nullptr);
|
||||
|
||||
if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
|
||||
updated = true;
|
||||
}
|
||||
|
||||
return updated;
|
||||
}
|
||||
|
||||
@@ -1525,283 +1452,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
return gf;
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_kv_cache::build_graph_defrag(
|
||||
llm_graph_result * res,
|
||||
llama_context * lctx,
|
||||
const defrag_info & dinfo) const {
|
||||
auto * ctx = res->get_ctx();
|
||||
auto * gf = res->get_gf();
|
||||
|
||||
GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
|
||||
|
||||
const auto & cells = v_cells[0];
|
||||
|
||||
const auto & ids = dinfo.ids;
|
||||
|
||||
const auto & cparams = lctx->get_cparams();
|
||||
|
||||
#if 0
|
||||
// CPU defrag
|
||||
//
|
||||
// TODO: optimizations are possible:
|
||||
// - multiple threads
|
||||
// - avoid copying to the host memory when already there
|
||||
//
|
||||
// likely not worth the effort, as we have ggml_graph based defrag
|
||||
//
|
||||
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
const uint32_t kv_size = size;
|
||||
|
||||
std::vector<uint8_t> buf_k;
|
||||
std::vector<uint8_t> buf_v;
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
||||
const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
|
||||
|
||||
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
||||
const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
|
||||
|
||||
buf_k.resize(k_size);
|
||||
buf_v.resize(v_size);
|
||||
|
||||
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
|
||||
ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
|
||||
|
||||
// batch move [i, i+nm) to [id, id+nm)
|
||||
// note: cells can move only to a lower index
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
const uint32_t id = ids[i];
|
||||
|
||||
if (i == id || id == n_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t nm = 1;
|
||||
|
||||
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
||||
nm++;
|
||||
}
|
||||
|
||||
// move keys
|
||||
{
|
||||
const int64_t os = i*k_size_row;
|
||||
const int64_t od = id*k_size_row;
|
||||
|
||||
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
||||
}
|
||||
|
||||
// move values (note: they are transposed)
|
||||
{
|
||||
const int64_t os = i;
|
||||
const int64_t od = id;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
||||
}
|
||||
}
|
||||
|
||||
i += nm - 1;
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
|
||||
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
|
||||
}
|
||||
#else
|
||||
for (uint32_t i = 0; i < ids.size(); ++i) {
|
||||
const uint32_t id = ids[i];
|
||||
|
||||
if (i == id || id == ids.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t nm = 1;
|
||||
|
||||
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
||||
nm++;
|
||||
}
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
|
||||
n_embd_k_gqa, nm,
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa*i));
|
||||
|
||||
ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
|
||||
n_embd_k_gqa, nm,
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa*id));
|
||||
|
||||
ggml_tensor * view_v_src;
|
||||
ggml_tensor * view_v_dst;
|
||||
|
||||
if (cparams.flash_attn) {
|
||||
// NOTE: the V cache is not transposed when using flash attention
|
||||
view_v_src = ggml_view_2d(ctx, layer.v,
|
||||
n_embd_v_gqa, nm,
|
||||
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
||||
ggml_row_size(layer.v->type, n_embd_v_gqa*i));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx, layer.v,
|
||||
n_embd_v_gqa, nm,
|
||||
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
||||
ggml_row_size(layer.v->type, n_embd_v_gqa*id));
|
||||
} else {
|
||||
view_v_src = ggml_view_2d(ctx, layer.v,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(layer.v->type, cells.size()),
|
||||
ggml_row_size(layer.v->type, i));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx, layer.v,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(layer.v->type, cells.size()),
|
||||
ggml_row_size(layer.v->type, id));
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
|
||||
}
|
||||
|
||||
i += nm - 1;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
||||
#endif
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
llama_kv_cache::defrag_info llama_kv_cache::defrag_prepare(int32_t n_max_nodes) const {
|
||||
GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
|
||||
|
||||
const auto & cells = v_cells[0];
|
||||
|
||||
const uint32_t n_layer = layers.size();
|
||||
|
||||
const uint32_t n_kv = cells.used_max_p1();
|
||||
const uint32_t n_used = cells.get_used();
|
||||
|
||||
assert(n_used <= n_kv);
|
||||
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
// number of cells moved
|
||||
uint32_t n_moves = 0;
|
||||
|
||||
// each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
|
||||
// - source view, destination view, copy operation
|
||||
// - x2 for keys and values
|
||||
//const uint32_t max_moves = max_nodes()/(6*n_layer);
|
||||
// TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
|
||||
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
||||
|
||||
// determine which KV cells to move where
|
||||
defrag_info res;
|
||||
auto & ids = res.ids;
|
||||
|
||||
ids.resize(n_kv, n_kv);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
||||
if (!cells.is_empty(i0)) {
|
||||
ids[i0] = i0;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// found a hole - fill it with data from the end of the cache
|
||||
|
||||
uint32_t nh = 1;
|
||||
|
||||
// determine the size of the hole
|
||||
while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
|
||||
nh++;
|
||||
}
|
||||
|
||||
uint32_t nf = 0;
|
||||
uint32_t is = n_kv - 1;
|
||||
|
||||
// starting from the end, find nh non-empty cells
|
||||
for (; is > i0; --is) {
|
||||
if (cells.is_empty(is) || ids[is] != n_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// non-empty cell which is not yet moved
|
||||
nf++;
|
||||
|
||||
if (nf == nh) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// this can only happen if `n_used` is not accurate, which would be a bug
|
||||
GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
|
||||
|
||||
nf = 0;
|
||||
|
||||
uint32_t i1 = is;
|
||||
|
||||
// are we moving a continuous block of memory?
|
||||
bool cont = false;
|
||||
|
||||
// should we stop searching for the next move?
|
||||
bool stop = false;
|
||||
|
||||
// go back and move the nf cells to the hole
|
||||
for (; i1 < n_kv; ++i1) {
|
||||
if (cells.is_empty(i1) || ids[i1] != n_kv) {
|
||||
if (n_moves == max_moves) {
|
||||
stop = true;
|
||||
break;
|
||||
}
|
||||
|
||||
cont = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
// this cell goes to (i0 + nf)
|
||||
ids[i1] = i0 + nf;
|
||||
|
||||
if (!cont) {
|
||||
n_moves++;
|
||||
cont = true;
|
||||
}
|
||||
|
||||
nf++;
|
||||
|
||||
if (nf == nh) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (stop || n_moves == max_moves) {
|
||||
break;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
||||
|
||||
i0 += nh - 1;
|
||||
}
|
||||
|
||||
if (n_moves == 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
@@ -2300,9 +1950,8 @@ llama_kv_cache_context::llama_kv_cache_context(
|
||||
llama_kv_cache * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo,
|
||||
stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
|
||||
if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
|
||||
stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
|
||||
if (!do_shift && this->sc_info.empty()) {
|
||||
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
||||
}
|
||||
}
|
||||
@@ -2330,7 +1979,7 @@ bool llama_kv_cache_context::apply() {
|
||||
|
||||
// no ubatches -> this is a KV cache update
|
||||
if (ubatches.empty()) {
|
||||
kv->update(lctx, do_shift, dinfo, sc_info);
|
||||
kv->update(lctx, do_shift, sc_info);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
+14
-39
@@ -21,20 +21,6 @@ class llama_kv_cache : public llama_memory_i {
|
||||
public:
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
struct defrag_info {
|
||||
bool empty() const {
|
||||
return ids.empty();
|
||||
}
|
||||
|
||||
// contains information about which cell moves where:
|
||||
// - cell i moves to ids[i]
|
||||
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
|
||||
std::vector<uint32_t> ids;
|
||||
};
|
||||
|
||||
struct stream_copy_info {
|
||||
bool empty() const {
|
||||
assert(ssrc.size() == sdst.size());
|
||||
@@ -93,18 +79,19 @@ public:
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
|
||||
llama_kv_cache(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type);
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse);
|
||||
|
||||
~llama_kv_cache() = default;
|
||||
|
||||
@@ -173,7 +160,7 @@ public:
|
||||
// return empty vector on failure
|
||||
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
|
||||
bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
|
||||
|
||||
// find a slot of kv cells that can hold the ubatch
|
||||
// if cont == true, then the slot must be continuous
|
||||
@@ -254,9 +241,6 @@ private:
|
||||
// model layer id -> KV cache layer id
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
// return non-empty vector if cells have been moved
|
||||
defrag_info defrag_prepare(int32_t n_max_nodes) const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
@@ -277,11 +261,6 @@ private:
|
||||
llm_graph_result * res,
|
||||
llama_context * lctx) const;
|
||||
|
||||
ggml_cgraph * build_graph_defrag(
|
||||
llm_graph_result * res,
|
||||
llama_context * lctx,
|
||||
const defrag_info & dinfo) const;
|
||||
|
||||
struct cell_ranges_t {
|
||||
uint32_t strm;
|
||||
|
||||
@@ -299,7 +278,6 @@ class llama_kv_cache_context : public llama_memory_context_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache::defrag_info;
|
||||
using stream_copy_info = llama_kv_cache::stream_copy_info;
|
||||
|
||||
// used for errors
|
||||
@@ -314,7 +292,6 @@ public:
|
||||
llama_kv_cache * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo,
|
||||
stream_copy_info sc_info);
|
||||
|
||||
// used to create a batch procesing context from a batch
|
||||
@@ -374,8 +351,6 @@ private:
|
||||
|
||||
bool do_shift = false;
|
||||
|
||||
defrag_info dinfo;
|
||||
|
||||
stream_copy_info sc_info;
|
||||
|
||||
//
|
||||
|
||||
+14
-14
@@ -77,24 +77,24 @@ public:
|
||||
}
|
||||
|
||||
// move cell isrc to idst (used during defrag)
|
||||
void mv(uint32_t isrc, uint32_t idst) {
|
||||
assert(isrc < pos.size());
|
||||
assert(idst < pos.size());
|
||||
//void mv(uint32_t isrc, uint32_t idst) {
|
||||
// assert(isrc < pos.size());
|
||||
// assert(idst < pos.size());
|
||||
|
||||
assert(pos[idst] == -1);
|
||||
assert(pos[isrc] != -1);
|
||||
// assert(pos[idst] == -1);
|
||||
// assert(pos[isrc] != -1);
|
||||
|
||||
pos [idst] = pos [isrc];
|
||||
shift[idst] = shift[isrc];
|
||||
seq [idst] = seq [isrc];
|
||||
// pos [idst] = pos [isrc];
|
||||
// shift[idst] = shift[isrc];
|
||||
// seq [idst] = seq [isrc];
|
||||
|
||||
pos [isrc] = -1;
|
||||
shift[isrc] = 0;
|
||||
seq [isrc].reset();
|
||||
// pos [isrc] = -1;
|
||||
// shift[isrc] = 0;
|
||||
// seq [isrc].reset();
|
||||
|
||||
used.erase (isrc);
|
||||
used.insert(idst);
|
||||
}
|
||||
// used.erase (isrc);
|
||||
// used.insert(idst);
|
||||
//}
|
||||
|
||||
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
|
||||
llama_kv_cells cp(uint32_t i, uint32_t n) const {
|
||||
|
||||
+29
-28
@@ -9,32 +9,29 @@
|
||||
//
|
||||
|
||||
llama_memory_hybrid::llama_memory_hybrid(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
layer_filter_cb && filter_attn,
|
||||
layer_filter_cb && filter_recr) :
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
const layer_filter_cb & filter_attn,
|
||||
const layer_filter_cb & filter_recr) :
|
||||
hparams(model.hparams),
|
||||
mem_attn(new llama_kv_cache(
|
||||
model,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||
: filter_attn,
|
||||
type_k,
|
||||
type_v,
|
||||
v_trans,
|
||||
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
|
||||
n_seq_max,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
swa_type,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||
: filter_attn,
|
||||
nullptr
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
model,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr,
|
||||
type_r,
|
||||
type_s,
|
||||
offload,
|
||||
rs_size,
|
||||
n_seq_max
|
||||
n_seq_max,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr
|
||||
)) {}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
|
||||
+18
-22
@@ -18,31 +18,27 @@
|
||||
|
||||
class llama_memory_hybrid : public llama_memory_i {
|
||||
public:
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
llama_memory_hybrid(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
layer_filter_cb && filter_attn = nullptr,
|
||||
layer_filter_cb && filter_recr = nullptr);
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
const layer_filter_cb & filter_attn = nullptr,
|
||||
const layer_filter_cb & filter_recr = nullptr);
|
||||
|
||||
~llama_memory_hybrid() = default;
|
||||
|
||||
|
||||
@@ -16,13 +16,13 @@
|
||||
//
|
||||
|
||||
llama_memory_recurrent::llama_memory_recurrent(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
||||
const llama_model & model,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max,
|
||||
const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
||||
const int32_t n_layer = hparams.n_layer;
|
||||
|
||||
head = 0;
|
||||
|
||||
@@ -15,18 +15,14 @@
|
||||
// see the implementation of llama_kv_cache_context_i for an example how to do it
|
||||
class llama_memory_recurrent : public llama_memory_i {
|
||||
public:
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
llama_memory_recurrent(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max);
|
||||
const llama_model & model,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max,
|
||||
const layer_filter_cb & filter);
|
||||
|
||||
~llama_memory_recurrent() = default;
|
||||
|
||||
|
||||
+9
-1
@@ -3,6 +3,7 @@
|
||||
#include "llama.h"
|
||||
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
struct llama_ubatch;
|
||||
|
||||
@@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
|
||||
// general concept of LLM memory
|
||||
// the KV cache is a type of LLM memory, but there can be other types
|
||||
struct llama_memory_i {
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
// this callback is used to specify which layers should reuse memory from other layers
|
||||
// return negative value to indicate that the layer il should not reuse memory
|
||||
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
|
||||
|
||||
virtual ~llama_memory_i() = default;
|
||||
|
||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||
@@ -77,7 +85,7 @@ struct llama_memory_i {
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
virtual llama_memory_context_ptr init_full() = 0;
|
||||
|
||||
// prepare for any pending memory updates, such as shifts, defrags, etc.
|
||||
// prepare for any pending memory updates, such as shifts, copies, etc.
|
||||
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
|
||||
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
||||
|
||||
|
||||
+303
-94
@@ -83,6 +83,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_32B: return "32B";
|
||||
case LLM_TYPE_34B: return "34B";
|
||||
case LLM_TYPE_35B: return "35B";
|
||||
case LLM_TYPE_36B: return "36B";
|
||||
case LLM_TYPE_40B: return "40B";
|
||||
case LLM_TYPE_65B: return "65B";
|
||||
case LLM_TYPE_70B: return "70B";
|
||||
@@ -1114,6 +1115,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.set_swa_pattern(5);
|
||||
|
||||
hparams.n_layer_kv_from_start = 20;
|
||||
hparams.rope_freq_base_train_swa = 10000.0f;
|
||||
hparams.rope_freq_scale_train_swa = 1.0f;
|
||||
hparams.f_attention_scale = 1.0f;
|
||||
@@ -1288,6 +1290,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
case 64: type = LLM_TYPE_36B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_OLMOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@@ -1465,12 +1475,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
// Expert gating function (GLM-4.5 uses sigmoid)
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
||||
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
|
||||
}
|
||||
|
||||
// NextN/MTP parameters
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
|
||||
case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
|
||||
@@ -3967,6 +3980,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
{
|
||||
const uint32_t head_dim = hparams.n_embd_head_k;
|
||||
const int64_t n_qo_dim = n_head * head_dim;
|
||||
const int64_t n_kv_dim = n_head_kv * head_dim;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0);
|
||||
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
|
||||
case LLM_ARCH_OLMOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -5474,8 +5524,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_LFM2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
@@ -6050,7 +6105,7 @@ struct llm_build_llama : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
@@ -6224,7 +6279,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
@@ -6401,7 +6456,7 @@ struct llm_build_deci : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -6533,7 +6588,7 @@ struct llm_build_baichuan : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -6648,7 +6703,7 @@ struct llm_build_xverse : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -6771,7 +6826,7 @@ struct llm_build_falcon : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -6901,7 +6956,7 @@ struct llm_build_grok : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -7050,7 +7105,7 @@ struct llm_build_dbrx : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -7164,7 +7219,7 @@ struct llm_build_starcoder : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -7263,7 +7318,7 @@ struct llm_build_refact : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -7426,7 +7481,7 @@ struct llm_build_bert : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@@ -7571,7 +7626,7 @@ struct llm_build_neo_bert : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, nullptr,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@@ -7671,7 +7726,7 @@ struct llm_build_bloom : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -7819,7 +7874,7 @@ struct llm_build_mpt : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -7965,7 +8020,7 @@ struct llm_build_stablelm : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -8086,7 +8141,7 @@ struct llm_build_qwen : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -8206,7 +8261,7 @@ struct llm_build_qwen2 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -8320,8 +8375,9 @@ struct llm_build_dream : public llm_graph_context {
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr,
|
||||
nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -8420,8 +8476,9 @@ struct llm_build_llada : public llm_graph_context {
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr,
|
||||
1.0f / sqrtf(float(n_embd_head)), il);
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -8534,7 +8591,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -8661,7 +8718,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -8814,7 +8871,7 @@ struct llm_build_qwen3 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -8935,7 +8992,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -9075,7 +9132,7 @@ struct llm_build_phi2 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -9212,7 +9269,7 @@ struct llm_build_phi3 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -9346,7 +9403,7 @@ struct llm_build_plamo : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -9454,7 +9511,7 @@ struct llm_build_gpt2 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -9568,7 +9625,7 @@ struct llm_build_codeshell : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -9697,7 +9754,7 @@ struct llm_build_orion : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -9824,7 +9881,7 @@ struct llm_build_internlm2 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -10012,7 +10069,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
|
||||
q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -10142,7 +10199,7 @@ struct llm_build_gemma : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -10257,7 +10314,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -10399,7 +10456,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -10471,7 +10528,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
const int64_t n_embd_altup;
|
||||
const int64_t n_altup;
|
||||
const int i_altup_act;
|
||||
const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
|
||||
const int n_layer_sparsity = 10; // number of layers using activation sparsity
|
||||
const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
|
||||
|
||||
@@ -10521,8 +10577,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
|
||||
const bool has_kv = (il < n_layer_kv);
|
||||
|
||||
const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
||||
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
||||
|
||||
@@ -10542,7 +10596,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
|
||||
|
||||
// self-attention
|
||||
if (has_kv) {
|
||||
if (hparams.has_kv(il)) {
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
@@ -10580,9 +10634,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
|
||||
} else {
|
||||
// no KV layers
|
||||
// reuse KV cache of earlier layers
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
@@ -10598,7 +10652,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
|
||||
Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
|
||||
}
|
||||
|
||||
cur = build_norm(cur,
|
||||
@@ -10963,7 +11017,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -11390,7 +11444,9 @@ struct llm_build_jamba : public llm_graph_context_mamba {
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// No RoPE :)
|
||||
cur = build_attn(inp_hybrid->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
cur = build_attn(inp_hybrid->get_attn(),
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -11548,7 +11604,7 @@ struct llm_build_command_r : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -11683,7 +11739,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -11814,7 +11870,7 @@ struct llm_build_olmo : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, nullptr,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -11934,7 +11990,7 @@ struct llm_build_olmo2 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -12067,7 +12123,7 @@ struct llm_build_olmoe : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -12200,7 +12256,7 @@ struct llm_build_openelm : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -12312,7 +12368,7 @@ struct llm_build_gptneox : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -12462,7 +12518,7 @@ struct llm_build_arctic : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -12617,7 +12673,7 @@ struct llm_build_deepseek : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -12845,7 +12901,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
|
||||
// note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
|
||||
} else {
|
||||
ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
|
||||
cb(kv, "kv", il);
|
||||
@@ -12879,7 +12935,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
|
||||
// note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13046,7 +13102,7 @@ struct llm_build_bitnet : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
NULL, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].attn_sub_norm, NULL,
|
||||
@@ -13169,7 +13225,7 @@ struct llm_build_t5_enc : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo_enc, nullptr,
|
||||
Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
|
||||
Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@@ -13275,7 +13331,7 @@ struct llm_build_t5_dec : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn_self,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
|
||||
Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@@ -13307,7 +13363,7 @@ struct llm_build_t5_dec : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn_cross,
|
||||
model.layers[il].wo_cross, nullptr,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
//ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
@@ -13439,7 +13495,7 @@ struct llm_build_jais : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -13571,7 +13627,7 @@ struct llm_build_chatglm : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -13704,7 +13760,7 @@ struct llm_build_glm4 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -13853,7 +13909,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
@@ -14007,7 +14063,7 @@ struct llm_build_nemotron : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -14138,7 +14194,7 @@ struct llm_build_exaone : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -14269,7 +14325,7 @@ struct llm_build_exaone4 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
@@ -15204,7 +15260,7 @@ struct llm_build_granite : public llm_graph_context {
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
return cur;
|
||||
}
|
||||
@@ -15423,7 +15479,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
return cur;
|
||||
}
|
||||
@@ -15608,7 +15664,7 @@ struct llm_build_chameleon : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, nullptr,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -15964,7 +16020,7 @@ struct llm_build_plm : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
|
||||
q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -16087,7 +16143,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -16227,7 +16283,7 @@ struct llm_build_dots1 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -16382,7 +16438,7 @@ struct llm_build_ernie4_5 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
@@ -16515,7 +16571,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
@@ -16668,7 +16724,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
|
||||
|
||||
ggml_tensor * attn_out = build_attn(inp->get_attn(),
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(attn_out, "attn_out", il);
|
||||
|
||||
cur = build_norm(inpL,
|
||||
@@ -16878,7 +16934,9 @@ private:
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
|
||||
cur = build_attn(inp,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
|
||||
}
|
||||
|
||||
cb(cur, "attn_out", il);
|
||||
@@ -17125,7 +17183,7 @@ struct llm_build_arcee : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
@@ -17270,7 +17328,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
@@ -17430,7 +17488,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
@@ -17560,7 +17618,7 @@ struct llm_build_smollm3 : public llm_graph_context {
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
@@ -17682,9 +17740,9 @@ struct llm_build_openai_moe_iswa : public llm_graph_context {
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn_with_sinks(inp_attn,
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il);
|
||||
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
@@ -17781,8 +17839,7 @@ struct llm_build_lfm2 : public llm_graph_context {
|
||||
cb(cur, "model.embedding_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head is tied with embeddings
|
||||
cur = build_lora_mm(model.tok_embd, cur);
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cb(cur, "lm_head", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
@@ -17847,7 +17904,7 @@ struct llm_build_lfm2 : public llm_graph_context {
|
||||
);
|
||||
|
||||
cur = build_attn(inp_attn, model.layers[il].wo, NULL,
|
||||
q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
q, k, v, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
|
||||
cb(cur, "model.layers.{}.self_attn.out_proj", il);
|
||||
|
||||
@@ -17924,6 +17981,137 @@ struct llm_build_lfm2 : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_seed_oss : public llm_graph_context {
|
||||
llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].attn_post_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool iswa>
|
||||
struct llm_build_smallthinker : public llm_graph_context{
|
||||
llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
|
||||
@@ -17991,7 +18179,7 @@ struct llm_build_smallthinker : public llm_graph_context{
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -18069,12 +18257,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
if (llm_arch_is_recurrent(arch)) {
|
||||
res = new llama_memory_recurrent(
|
||||
*this,
|
||||
nullptr,
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F32,
|
||||
cparams.offload_kqv,
|
||||
std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
cparams.n_seq_max);
|
||||
cparams.n_seq_max,
|
||||
nullptr);
|
||||
} else if (llm_arch_is_hybrid(arch)) {
|
||||
const auto padding = llama_kv_cache::get_padding(cparams);
|
||||
|
||||
@@ -18115,6 +18303,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA3N) {
|
||||
reuse = [&](int32_t il) {
|
||||
if (il >= (int32_t) hparams.n_layer_kv_from_start) {
|
||||
return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
|
||||
}
|
||||
|
||||
return -1;
|
||||
};
|
||||
}
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(hparams.is_swa_any());
|
||||
|
||||
@@ -18129,13 +18329,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
padding);
|
||||
padding,
|
||||
nullptr,
|
||||
reuse);
|
||||
} else {
|
||||
GGML_ASSERT(!hparams.is_swa_any());
|
||||
|
||||
res = new llama_kv_cache(
|
||||
*this,
|
||||
nullptr,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
@@ -18145,7 +18346,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
cparams.n_seq_max,
|
||||
padding,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type);
|
||||
hparams.swa_type,
|
||||
nullptr,
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18462,6 +18665,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
{
|
||||
llm = std::make_unique<llm_build_bailingmoe>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
{
|
||||
llm = std::make_unique<llm_build_seed_oss>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_DOTS1:
|
||||
{
|
||||
llm = std::make_unique<llm_build_dots1>(*this, params);
|
||||
@@ -18520,6 +18727,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
return llm->res->get_gf();
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
@@ -18714,6 +18922,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_SMALLTHINKER:
|
||||
case LLM_ARCH_GLM4_MOE:
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
||||
@@ -76,6 +76,7 @@ enum llm_type {
|
||||
LLM_TYPE_32B,
|
||||
LLM_TYPE_34B,
|
||||
LLM_TYPE_35B,
|
||||
LLM_TYPE_36B,
|
||||
LLM_TYPE_40B,
|
||||
LLM_TYPE_65B,
|
||||
LLM_TYPE_70B,
|
||||
|
||||
+178
-8
@@ -2209,6 +2209,26 @@ struct test_count_equal : public test_case {
|
||||
double max_nmse_err() override {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
std::random_device rd;
|
||||
std::default_random_engine rng(rd());
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_F32) {
|
||||
// initialize with unique values to avoid ties
|
||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||
std::vector<float> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = i;
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
|
||||
}
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_REPEAT
|
||||
@@ -2858,6 +2878,7 @@ struct test_rms_norm_mul_add : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float eps;
|
||||
const bool broadcast;
|
||||
const bool multi_add; // test a sequence of adds feeding into rms_norm
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
@@ -2867,13 +2888,13 @@ struct test_rms_norm_mul_add : public test_case {
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, eps, broadcast);
|
||||
return VARS_TO_STR5(type, ne, eps, broadcast, multi_add);
|
||||
}
|
||||
|
||||
test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 5, 4, 3},
|
||||
float eps = 1e-6f, bool broadcast = false)
|
||||
: type(type), ne(ne), eps(eps), broadcast(broadcast) {}
|
||||
float eps = 1e-6f, bool broadcast = false, bool multi_add = false)
|
||||
: type(type), ne(ne), eps(eps), broadcast(broadcast), multi_add(multi_add) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
|
||||
@@ -2891,6 +2912,9 @@ struct test_rms_norm_mul_add : public test_case {
|
||||
|
||||
// Use a, b and c early, so we don't end up with an OP_NONE between rms_norm and mul
|
||||
a = ggml_add(ctx, ggml_add(ctx, a, b), c);
|
||||
if (multi_add) {
|
||||
a = ggml_add(ctx, ggml_add(ctx, a, b), c);
|
||||
}
|
||||
ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
@@ -4091,6 +4115,75 @@ struct test_conv_2d_dw : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_CONV_3D
|
||||
struct test_conv_3d : public test_case {
|
||||
// Logical 5D dimensions
|
||||
const int64_t N, IC, ID, IH, IW;
|
||||
const int64_t OC, KD, KH, KW;
|
||||
// Conv params
|
||||
const int s0, s1, s2;
|
||||
const int p0, p1, p2;
|
||||
const int d0, d1, d2;
|
||||
// Types
|
||||
const ggml_type type_kernel;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "CONV_3D";
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR11(N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1) + "," +
|
||||
VARS_TO_STR8(s2, p0, p1, p2, d0, d1, d2, type_kernel);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 5e-4;
|
||||
}
|
||||
|
||||
uint64_t op_flops(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
|
||||
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
||||
};
|
||||
const int64_t OD = calc_conv_output_size(ID, KD, s2, p2, d2);
|
||||
const int64_t OH = calc_conv_output_size(IH, KH, s1, p1, d1);
|
||||
const int64_t OW = calc_conv_output_size(IW, KW, s0, p0, d0);
|
||||
|
||||
return (uint64_t)N * OC * OD * OH * OW * (2 * IC * KD * KH * KW - 1);
|
||||
}
|
||||
|
||||
test_conv_3d(
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW,
|
||||
int64_t OC, int64_t KD, int64_t KH, int64_t KW,
|
||||
int s0, int s1, int s2,
|
||||
int p0, int p1, int p2,
|
||||
int d0, int d1, int d2,
|
||||
ggml_type type_kernel
|
||||
) : N(N), IC(IC), ID(ID), IH(IH), IW(IW),
|
||||
OC(OC), KD(KD), KH(KH), KW(KW),
|
||||
s0(s0), s1(s1), s2(s2),
|
||||
p0(p0), p1(p1), p2(p2),
|
||||
d0(d0), d1(d1), d2(d2),
|
||||
type_kernel(type_kernel) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
// GGML input tensor is packed as [W, H, D, C*N]
|
||||
const int64_t ne_input[] = {IW, IH, ID, IC * N};
|
||||
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input);
|
||||
ggml_set_name(input, "input");
|
||||
|
||||
// GGML kernel tensor is packed as [KW, KH, KD, IC*OC]
|
||||
const int64_t ne_kernel[] = {KW, KH, KD, IC * OC};
|
||||
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel);
|
||||
ggml_set_name(kernel, "kernel");
|
||||
|
||||
ggml_tensor * out = ggml_conv_3d(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_CONCAT
|
||||
struct test_concat : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -4231,20 +4324,32 @@ struct test_sum : public test_case {
|
||||
struct test_sum_rows : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const bool permute;
|
||||
const bool slice;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
return VARS_TO_STR4(type, ne, permute, slice);
|
||||
}
|
||||
|
||||
test_sum_rows(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 5, 4, 3})
|
||||
: type(type), ne(ne) {}
|
||||
std::array<int64_t, 4> ne = {10, 5, 4, 3},
|
||||
bool permute = false, bool slice = false)
|
||||
: type(type), ne(ne), permute(permute), slice(slice) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
if (slice) {
|
||||
a = ggml_view_4d(ctx, a,
|
||||
ne[0], ne[1], ne[2] / 2, ne[3] - 1,
|
||||
a->nb[1], a->nb[2] * 2, a->nb[3], /*offset=*/a->nb[3]);
|
||||
}
|
||||
if (permute) {
|
||||
a = ggml_permute(ctx, a, 0, 2, 3, 1);
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_sum_rows(ctx, a);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
@@ -5528,6 +5633,61 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
|
||||
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
|
||||
|
||||
// CONV_3D
|
||||
auto calc_conv_output_size_3d = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
|
||||
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
||||
};
|
||||
|
||||
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
for (int N : {1, 2}) {
|
||||
for (int IC : {1, 3}) {
|
||||
for (int OC : {1, 4}) {
|
||||
for (int s0 : {1, 2}) {
|
||||
for (int p1 : {0, 1}) {
|
||||
for (int d2 : {1, 2}) {
|
||||
int64_t IW = 20, IH = 22, ID = 18;
|
||||
int64_t KW = 3, KH = 3, KD = 3;
|
||||
int s1 = s0, s2 = s0;
|
||||
int p0 = p1, p2 = p1;
|
||||
int d0 = d2, d1 = d2;
|
||||
|
||||
if (calc_conv_output_size_3d(IW, KW, s0, p0, d0) <= 0 ||
|
||||
calc_conv_output_size_3d(IH, KH, s1, p1, d1) <= 0 ||
|
||||
calc_conv_output_size_3d(ID, KD, s2, p2, d2) <= 0) {
|
||||
continue;
|
||||
}
|
||||
test_cases.emplace_back(new test_conv_3d(
|
||||
N, IC, ID, IH, IW,
|
||||
OC, KD, KH, KW,
|
||||
s0, s1, s2, p0, p1, p2, d0, d1, d2,
|
||||
kernel_type));
|
||||
|
||||
// Asymmetric kernel and params
|
||||
int64_t asym_KW = 5, asym_KH = 1, asym_KD = 3;
|
||||
int asym_s0 = 2, asym_s1 = 1, asym_s2 = 1;
|
||||
int asym_p0 = 2, asym_p1 = 0, asym_p2 = 1;
|
||||
int asym_d0 = 1, asym_d1 = 1, asym_d2 = 2;
|
||||
|
||||
if (calc_conv_output_size_3d(IW, asym_KW, asym_s0, asym_p0, asym_d0) <= 0 ||
|
||||
calc_conv_output_size_3d(IH, asym_KH, asym_s1, asym_p1, asym_d1) <= 0 ||
|
||||
calc_conv_output_size_3d(ID, asym_KD, asym_s2, asym_p2, asym_d2) <= 0) {
|
||||
continue;
|
||||
}
|
||||
test_cases.emplace_back(new test_conv_3d(
|
||||
N, IC, ID, IH, IW,
|
||||
OC, asym_KD, asym_KH, asym_KW,
|
||||
asym_s0, asym_s1, asym_s2, asym_p0, asym_p1, asym_p2, asym_d0, asym_d1, asym_d2,
|
||||
kernel_type));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Case with kernel size 1
|
||||
test_cases.emplace_back(new test_conv_3d(1, 4, 8, 8, 8, 8, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, kernel_type));
|
||||
}
|
||||
|
||||
for(uint32_t Cout : {1, 9}){
|
||||
for(uint32_t Cin : {1, 7}){
|
||||
for(uint32_t K : {1, 3, 1337}){
|
||||
@@ -5706,6 +5866,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
||||
}
|
||||
for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
|
||||
for (bool multi_add : {false, true}) {
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add));
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
||||
|
||||
@@ -5852,6 +6017,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
// test large experts*tokens
|
||||
for (bool b : {false, true}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 50, 200, 64));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
|
||||
@@ -6071,6 +6238,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
|
||||
test_cases.emplace_back(new test_sum());
|
||||
test_cases.emplace_back(new test_sum_rows());
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
|
||||
test_cases.emplace_back(new test_mean());
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
@@ -6091,8 +6261,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_timestep_embedding());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
|
||||
for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
|
||||
for (int hsk : { 40, 64, 80, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 40, 64, 80, 128, 192, 256, 512 }) {
|
||||
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
|
||||
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
|
||||
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
|
||||
|
||||
@@ -290,6 +290,14 @@ int main(void) {
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "",
|
||||
},
|
||||
{
|
||||
/* .name= */ "ByteDance-Seed/Seed-OSS-36B-Instruct",
|
||||
/* .template_str */ "{# <seed:bos> #}{%- for message in messages %}{%- if message.role in [\"user\", \"system\"] %}{{ bos_token + message.role + \"\\n\" + message.content + eos_token }}{%- elif message.role == \"assistant\" %}{{ bos_token + message.role }}{%- if message.content is defined and message.content is string and message.content|trim|length > 0 %}{{ \"\\n\" + message.content|trim + eos_token }}{%- endif %}{%- else %}{{ bos_token + message.role + \"\\n\" + message.content + eos_token }}{%- endif %}{%- endfor %}{%- if add_generation_prompt %}{{ bos_token + \"assistant\\n\" }}{%- endif %}",
|
||||
/* .expected_output= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
|
||||
/* .expected_output_jinja= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
|
||||
/* .bos_token= */ "<seed:bos>",
|
||||
/* .eos_token= */ "<seed:eos>",
|
||||
}
|
||||
};
|
||||
std::vector<char> formatted_chat(1024);
|
||||
int32_t res;
|
||||
|
||||
+20
-12
@@ -358,7 +358,7 @@ static std::pair<int, int> test_forward_backward(
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
|
||||
const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);
|
||||
const bool subtest_ok = ndata == 0 && almost_equal(loss, 0.0, 1e-6) && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);
|
||||
helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass);
|
||||
}
|
||||
|
||||
@@ -381,10 +381,12 @@ static std::pair<int, int> test_forward_backward(
|
||||
{
|
||||
float weights;
|
||||
ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
|
||||
const bool subtest_ok = weights == ndata/2;
|
||||
const bool subtest_ok = almost_equal(weights, ndata/2, 1e-10);
|
||||
helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass);
|
||||
}
|
||||
{
|
||||
constexpr double atol = 1e-10;
|
||||
|
||||
int64_t ndata;
|
||||
ggml_opt_result_ndata(cd.result, &ndata);
|
||||
bool subtest_ok = ndata == 6;
|
||||
@@ -392,7 +394,7 @@ static std::pair<int, int> test_forward_backward(
|
||||
double loss;
|
||||
double loss_unc;
|
||||
ggml_opt_result_loss(cd.result, &loss, &loss_unc);
|
||||
subtest_ok = subtest_ok && loss == 33.0 && almost_equal(loss_unc, sqrt(3.5), 1e-10);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, 33.0, atol) && almost_equal(loss_unc, sqrt(3.5), atol);
|
||||
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
@@ -437,7 +439,7 @@ static std::pair<int, int> test_forward_backward(
|
||||
{
|
||||
float weights;
|
||||
ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
|
||||
const bool subtest_ok = weights == -ndata * .5;
|
||||
const bool subtest_ok = almost_equal(weights, -ndata * 0.5, 1e-10);
|
||||
helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
|
||||
}
|
||||
{
|
||||
@@ -448,7 +450,7 @@ static std::pair<int, int> test_forward_backward(
|
||||
double loss;
|
||||
double loss_unc;
|
||||
ggml_opt_result_loss(cd.result, &loss, &loss_unc);
|
||||
subtest_ok = subtest_ok && loss == 18.0 && (shuffle || loss_unc == 0.0);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, 18.0, 1e-10) && (shuffle || loss_unc == 0.0);
|
||||
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
@@ -550,10 +552,12 @@ static std::pair<int, int> test_idata_split(
|
||||
if (adamw) {
|
||||
float weights;
|
||||
ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
|
||||
const bool subtest_ok = weights == ndata/2 - epoch*idata_split;
|
||||
const bool subtest_ok = almost_equal(weights, ndata/2 - epoch*idata_split, 1e-10);
|
||||
helper_after_test_idata_split(optim, __func__, high_level, epoch, "weights", subtest_ok, ntest, npass);
|
||||
}
|
||||
if (adamw) {
|
||||
constexpr double atol = 1e-10;
|
||||
|
||||
int64_t ndata_result;
|
||||
ggml_opt_result_ndata(cd.result, &ndata_result);
|
||||
bool subtest_ok = ndata_result == idata_split;
|
||||
@@ -561,7 +565,7 @@ static std::pair<int, int> test_idata_split(
|
||||
double loss;
|
||||
double loss_unc;
|
||||
ggml_opt_result_loss(cd.result, &loss, &loss_unc);
|
||||
subtest_ok = subtest_ok && loss == 28.0 - epoch*16.0 && loss_unc == 0.0;
|
||||
subtest_ok = subtest_ok && almost_equal(loss, 28.0 - epoch*16.0, atol) && almost_equal(loss_unc, 0.0, atol);
|
||||
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
@@ -571,6 +575,8 @@ static std::pair<int, int> test_idata_split(
|
||||
helper_after_test_idata_split(optim, __func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass);
|
||||
}
|
||||
if (adamw) {
|
||||
constexpr double atol = 1e-10;
|
||||
|
||||
int64_t ndata_result;
|
||||
ggml_opt_result_ndata(cd.result2, &ndata_result);
|
||||
bool subtest_ok = ndata_result == ndata - idata_split;
|
||||
@@ -578,7 +584,7 @@ static std::pair<int, int> test_idata_split(
|
||||
double loss;
|
||||
double loss_unc;
|
||||
ggml_opt_result_loss(cd.result2, &loss, &loss_unc);
|
||||
subtest_ok = subtest_ok && loss == 15.0 - epoch*8 && almost_equal(loss_unc, sqrt(0.5), 1e-10);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, 15.0 - epoch*8, atol) && almost_equal(loss_unc, sqrt(0.5), atol);
|
||||
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
@@ -687,22 +693,24 @@ static std::pair<int, int> test_gradient_accumulation(
|
||||
}
|
||||
bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
||||
if (adamw) {
|
||||
constexpr double atol = 1e-6;
|
||||
float weights;
|
||||
ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
|
||||
const bool subtest_ok = weights == (ndata/2) - epoch;
|
||||
const bool subtest_ok = almost_equal(weights, (ndata/2) - epoch, atol);
|
||||
helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass);
|
||||
}
|
||||
{
|
||||
constexpr double atol = 1e-6;
|
||||
int64_t ndata_result;
|
||||
ggml_opt_result_ndata(cd.result, &ndata_result);
|
||||
bool subtest_ok = ndata_result == ndata/nbatch_physical;
|
||||
bool subtest_ok = almost_equal(ndata_result, ndata/nbatch_physical, atol);
|
||||
|
||||
double loss;
|
||||
ggml_opt_result_loss(cd.result, &loss, /*loss_unc =*/ nullptr);
|
||||
if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {
|
||||
subtest_ok = subtest_ok && loss == (39.0 - epoch*6.0);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0), atol);
|
||||
} else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
|
||||
subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, 1e-6);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, atol);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
|
||||
const int tg = n_tg[i_tg];
|
||||
const int pl = n_pl[i_pl];
|
||||
|
||||
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
|
||||
const int n_ctx_req = is_pp_shared ? (params.kv_unified ? pp : pl*pp) + pl*tg : pl*(pp + tg);
|
||||
|
||||
if (n_ctx_req > n_kv_max) {
|
||||
continue;
|
||||
@@ -147,13 +147,24 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto t_pp_end = ggml_time_us();
|
||||
|
||||
if (is_pp_shared) {
|
||||
for (int32_t i = 1; i < pl; ++i) {
|
||||
llama_memory_seq_cp(mem, 0, i, -1, -1);
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_pp_end = ggml_time_us();
|
||||
if (!params.kv_unified) {
|
||||
// run one dummy token to apply the memory copy
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
llama_memory_seq_rm(mem, 0, pp, -1);
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_tg_start = ggml_time_us();
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ test parameters:
|
||||
-ub, --ubatch-size <n> (default: 512)
|
||||
-ctk, --cache-type-k <t> (default: f16)
|
||||
-ctv, --cache-type-v <t> (default: f16)
|
||||
-dt, --defrag-thold <f> (default: -1)
|
||||
-t, --threads <n> (default: system dependent)
|
||||
-C, --cpu-mask <hex,hex> (default: 0x0)
|
||||
--cpu-strict <0|1> (default: 0)
|
||||
|
||||
@@ -245,7 +245,6 @@ struct cmd_params {
|
||||
std::vector<int> n_ubatch;
|
||||
std::vector<ggml_type> type_k;
|
||||
std::vector<ggml_type> type_v;
|
||||
std::vector<float> defrag_thold;
|
||||
std::vector<int> n_threads;
|
||||
std::vector<std::string> cpu_mask;
|
||||
std::vector<bool> cpu_strict;
|
||||
@@ -282,7 +281,6 @@ static const cmd_params cmd_params_defaults = {
|
||||
/* n_ubatch */ { 512 },
|
||||
/* type_k */ { GGML_TYPE_F16 },
|
||||
/* type_v */ { GGML_TYPE_F16 },
|
||||
/* defrag_thold */ { -1.0f },
|
||||
/* n_threads */ { cpu_get_num_math() },
|
||||
/* cpu_mask */ { "0x0" },
|
||||
/* cpu_strict */ { false },
|
||||
@@ -346,8 +344,6 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||
join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
|
||||
printf(" -ctv, --cache-type-v <t> (default: %s)\n",
|
||||
join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
|
||||
printf(" -dt, --defrag-thold <f> (default: %s)\n",
|
||||
join(cmd_params_defaults.defrag_thold, ",").c_str());
|
||||
printf(" -t, --threads <n> (default: %s)\n",
|
||||
join(cmd_params_defaults.n_threads, ",").c_str());
|
||||
printf(" -C, --cpu-mask <hex,hex> (default: %s)\n",
|
||||
@@ -533,13 +529,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
params.type_v.insert(params.type_v.end(), types.begin(), types.end());
|
||||
} else if (arg == "-dt" || arg == "--defrag-thold") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
auto p = string_split<float>(argv[i], split_delim);
|
||||
params.defrag_thold.insert(params.defrag_thold.end(), p.begin(), p.end());
|
||||
} else if (arg == "-t" || arg == "--threads") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@@ -849,9 +838,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
if (params.type_v.empty()) {
|
||||
params.type_v = cmd_params_defaults.type_v;
|
||||
}
|
||||
if (params.defrag_thold.empty()) {
|
||||
params.defrag_thold = cmd_params_defaults.defrag_thold;
|
||||
}
|
||||
if (params.n_gpu_layers.empty()) {
|
||||
params.n_gpu_layers = cmd_params_defaults.n_gpu_layers;
|
||||
}
|
||||
@@ -910,7 +896,6 @@ struct cmd_params_instance {
|
||||
int n_ubatch;
|
||||
ggml_type type_k;
|
||||
ggml_type type_v;
|
||||
float defrag_thold;
|
||||
int n_threads;
|
||||
std::string cpu_mask;
|
||||
bool cpu_strict;
|
||||
@@ -1007,7 +992,6 @@ struct cmd_params_instance {
|
||||
cparams.n_ubatch = n_ubatch;
|
||||
cparams.type_k = type_k;
|
||||
cparams.type_v = type_v;
|
||||
cparams.defrag_thold = defrag_thold;
|
||||
cparams.offload_kqv = !no_kv_offload;
|
||||
cparams.flash_attn = flash_attn;
|
||||
cparams.embeddings = embeddings;
|
||||
@@ -1037,7 +1021,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
for (const auto & nub : params.n_ubatch)
|
||||
for (const auto & tk : params.type_k)
|
||||
for (const auto & tv : params.type_v)
|
||||
for (const auto & defrag_thold : params.defrag_thold)
|
||||
for (const auto & nkvo : params.no_kv_offload)
|
||||
for (const auto & fa : params.flash_attn)
|
||||
for (const auto & nt : params.n_threads)
|
||||
@@ -1058,7 +1041,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .n_ubatch = */ nub,
|
||||
/* .type_k = */ tk,
|
||||
/* .type_v = */ tv,
|
||||
/* .defrag_thold = */ defrag_thold,
|
||||
/* .n_threads = */ nt,
|
||||
/* .cpu_mask = */ cm,
|
||||
/* .cpu_strict = */ cs,
|
||||
@@ -1091,7 +1073,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .n_ubatch = */ nub,
|
||||
/* .type_k = */ tk,
|
||||
/* .type_v = */ tv,
|
||||
/* .defrag_thold = */ defrag_thold,
|
||||
/* .n_threads = */ nt,
|
||||
/* .cpu_mask = */ cm,
|
||||
/* .cpu_strict = */ cs,
|
||||
@@ -1124,7 +1105,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .n_ubatch = */ nub,
|
||||
/* .type_k = */ tk,
|
||||
/* .type_v = */ tv,
|
||||
/* .defrag_thold = */ defrag_thold,
|
||||
/* .n_threads = */ nt,
|
||||
/* .cpu_mask = */ cm,
|
||||
/* .cpu_strict = */ cs,
|
||||
@@ -1166,7 +1146,6 @@ struct test {
|
||||
int poll;
|
||||
ggml_type type_k;
|
||||
ggml_type type_v;
|
||||
float defrag_thold;
|
||||
int n_gpu_layers;
|
||||
llama_split_mode split_mode;
|
||||
int main_gpu;
|
||||
@@ -1201,7 +1180,6 @@ struct test {
|
||||
poll = inst.poll;
|
||||
type_k = inst.type_k;
|
||||
type_v = inst.type_v;
|
||||
defrag_thold = inst.defrag_thold;
|
||||
n_gpu_layers = inst.n_gpu_layers;
|
||||
split_mode = inst.split_mode;
|
||||
main_gpu = inst.main_gpu;
|
||||
@@ -1257,7 +1235,6 @@ struct test {
|
||||
"model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
|
||||
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
|
||||
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
|
||||
"defrag_thold",
|
||||
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth", "test_time",
|
||||
"avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
|
||||
};
|
||||
@@ -1277,7 +1254,7 @@ struct test {
|
||||
field == "use_mmap" || field == "embeddings") {
|
||||
return BOOL;
|
||||
}
|
||||
if (field == "avg_ts" || field == "stddev_ts" || field == "defrag_thold") {
|
||||
if (field == "avg_ts" || field == "stddev_ts") {
|
||||
return FLOAT;
|
||||
}
|
||||
return STRING;
|
||||
@@ -1344,7 +1321,6 @@ struct test {
|
||||
std::to_string(flash_attn),
|
||||
tensor_split_str,
|
||||
tensor_buft_overrides_str,
|
||||
std::to_string(defrag_thold),
|
||||
std::to_string(use_mmap),
|
||||
std::to_string(embeddings),
|
||||
std::to_string(no_op_offload),
|
||||
@@ -1611,9 +1587,6 @@ struct markdown_printer : public printer {
|
||||
if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
|
||||
fields.emplace_back("type_v");
|
||||
}
|
||||
if (params.defrag_thold.size() > 1 || params.defrag_thold != cmd_params_defaults.defrag_thold) {
|
||||
fields.emplace_back("defrag_thold");
|
||||
}
|
||||
if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
|
||||
fields.emplace_back("main_gpu");
|
||||
}
|
||||
|
||||
+6
-1
@@ -2202,6 +2202,8 @@ struct clip_model_loader {
|
||||
hparams.minicpmv_query_num = 64;
|
||||
} else if (hparams.minicpmv_version == 5) {
|
||||
hparams.minicpmv_query_num = 64;
|
||||
} else if (hparams.minicpmv_version == 6) {
|
||||
hparams.minicpmv_query_num = 64;
|
||||
} else {
|
||||
hparams.minicpmv_query_num = 96;
|
||||
}
|
||||
@@ -3513,7 +3515,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
const int height = img->ny;
|
||||
const int total_factor = params.patch_size * params.proj_scale_factor;
|
||||
constexpr int min_image_tokens = 64;
|
||||
constexpr int max_image_tokens = 256;
|
||||
constexpr int max_image_tokens = 1024;
|
||||
const float min_pixels = min_image_tokens * total_factor * total_factor;
|
||||
const float max_pixels = max_image_tokens * total_factor * total_factor;
|
||||
|
||||
@@ -3685,6 +3687,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
} else if (params.minicpmv_version == 5) {
|
||||
// MiniCPM-V 4.0
|
||||
n_patches = 64;
|
||||
} else if (params.minicpmv_version == 6) {
|
||||
// MiniCPM-V 4.5
|
||||
n_patches = 64;
|
||||
} else {
|
||||
GGML_ABORT("Unknown minicpmv version");
|
||||
}
|
||||
|
||||
@@ -607,6 +607,9 @@ else:
|
||||
elif minicpmv_version == 5:
|
||||
emb_dim = 2560
|
||||
block_count = 27
|
||||
elif minicpmv_version == 6:
|
||||
emb_dim = 4096
|
||||
block_count = 27
|
||||
|
||||
default_vision_config = {
|
||||
"hidden_size": 1152,
|
||||
@@ -630,6 +633,10 @@ elif minicpmv_version == 5:
|
||||
default_vision_config["model_type"] = "siglip_vision_model"
|
||||
vision_config = SiglipVisionConfig(**default_vision_config)
|
||||
model = SiglipVisionTransformer(vision_config)
|
||||
elif minicpmv_version == 6:
|
||||
default_vision_config["model_type"] = "siglip_vision_model"
|
||||
vision_config = SiglipVisionConfig(**default_vision_config)
|
||||
model = SiglipVisionTransformer(vision_config)
|
||||
|
||||
processor = None
|
||||
# if model.attn_pool is not None:
|
||||
|
||||
+1
-1
@@ -207,7 +207,7 @@ struct mtmd_context {
|
||||
tok_row_end_trail = false; // no trailing end-of-row token
|
||||
ov_img_first = true;
|
||||
|
||||
} else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5) {
|
||||
} else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5 || minicpmv_version == 6) {
|
||||
// minicpmv 2.6 format:
|
||||
// <image> (overview) </image><slice> (slice) </slice><slice> (slice) </slice>\n ...
|
||||
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6;
|
||||
|
||||
+12
-7
@@ -66,7 +66,7 @@ The project is under active development, and we are [looking for feedback and co
|
||||
| `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
|
||||
| `-ctk, --cache-type-k TYPE` | KV cache data type for K<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
|
||||
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
|
||||
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: 0.1, < 0 - disabled)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
|
||||
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
|
||||
| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
|
||||
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
|
||||
| `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock)<br/>(env: LLAMA_ARG_NO_MMAP) |
|
||||
@@ -226,6 +226,10 @@ services:
|
||||
### Multimodal support
|
||||
|
||||
Multimodal support was added in [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) and is currently an experimental feature.
|
||||
It is currently available in the following endpoints:
|
||||
- The OAI-compatible chat endpoint.
|
||||
- The non-OAI-compatible completions endpoint.
|
||||
- The non-OAI-compatible embeddings endpoint.
|
||||
|
||||
For more details, please refer to [multimodal documentation](../../docs/multimodal.md)
|
||||
|
||||
@@ -400,12 +404,15 @@ These input shapes and data type are allowed for `prompt`:
|
||||
- Single string: `"string"`
|
||||
- Single sequence of tokens: `[12, 34, 56]`
|
||||
- Mixed tokens and strings: `[12, 34, "string", 56, 78]`
|
||||
- A JSON object which optionally contains multimodal data: `{ "prompt_string": "string", "multimodal_data": ["base64"] }`
|
||||
|
||||
Multiple prompts are also supported. In this case, the completion result will be an array.
|
||||
|
||||
- Only strings: `["string1", "string2"]`
|
||||
- Strings and sequences of tokens: `["string1", [12, 34, 56]]`
|
||||
- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
|
||||
- Strings, JSON objects, and sequences of tokens: `["string1", [12, 34, 56], { "prompt_string": "string", "multimodal_data": ["base64"]}]`
|
||||
- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string", { "prompt_string": "string" }]`
|
||||
|
||||
Note for `multimodal_data` in JSON object prompts. This should be an array of strings, containing base64 encoded multimodal data such as images and audio. There must be an identical number of MTMD media markers in the string prompt element which act as placeholders for the data provided to this parameter. The multimodal data files will be substituted in order. The marker string (e.g. `<__media__>`) can be found by calling `mtmd_default_marker()` defined in [the MTMD C API](https://github.com/ggml-org/llama.cpp/blob/5fd160bbd9d70b94b5b11b0001fd7f477005e4a0/tools/mtmd/mtmd.h#L87). A client *must not* specify this field unless the server has the multimodal capability. Clients should check `/models` or `/v1/models` for the `multimodal` capability before a multimodal request.
|
||||
|
||||
`temperature`: Adjust the randomness of the generated text. Default: `0.8`
|
||||
|
||||
@@ -477,8 +484,6 @@ These words will not be included in the completion, so make sure to add them to
|
||||
|
||||
`t_max_predict_ms`: Set a time limit in milliseconds for the prediction (a.k.a. text-generation) phase. The timeout will trigger if the generation takes more than the specified time (measured since the first token was generated) and if a new-line character has already been generated. Useful for FIM applications. Default: `0`, which is disabled.
|
||||
|
||||
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
|
||||
|
||||
`id_slot`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot. Default: `-1`
|
||||
|
||||
`cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `true`
|
||||
@@ -638,12 +643,12 @@ Returns a JSON object with a field `prompt` containing a string of the input mes
|
||||
|
||||
The same as [the embedding example](../embedding) does.
|
||||
|
||||
This endpoint also supports multimodal embeddings. See the documentation for the `/completions` endpoint for details on how to send a multimodal prompt.
|
||||
|
||||
*Options:*
|
||||
|
||||
`content`: Set the text to process.
|
||||
|
||||
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
|
||||
|
||||
`embd_normalize`: Normalization for pooled embeddings. Can be one of the following values:
|
||||
```
|
||||
-1: No normalization
|
||||
|
||||
@@ -274,7 +274,6 @@ def start_server_background(args):
|
||||
server_args.extend(['--batch-size', args.batch_size])
|
||||
server_args.extend(['--ubatch-size', args.ubatch_size])
|
||||
server_args.extend(['--n-predict', args.max_tokens * 2])
|
||||
server_args.extend(['--defrag-thold', "0.1"])
|
||||
server_args.append('--cont-batching')
|
||||
server_args.append('--metrics')
|
||||
server_args.append('--flash-attn')
|
||||
|
||||
+20
-57
@@ -4309,6 +4309,7 @@ int main(int argc, char ** argv) {
|
||||
};
|
||||
|
||||
const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
bool has_mtmd = ctx_server.mctx != nullptr;
|
||||
json data = {
|
||||
{
|
||||
"template", common_chat_templates_source(ctx_server.chat_templates.get()),
|
||||
@@ -4330,7 +4331,7 @@ int main(int argc, char ** argv) {
|
||||
{"quantization_level", ""}
|
||||
}},
|
||||
{"model_info", ""},
|
||||
{"capabilities", {"completion"}}
|
||||
{"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}
|
||||
};
|
||||
|
||||
res_ok(res, data);
|
||||
@@ -4356,56 +4357,15 @@ int main(int argc, char ** argv) {
|
||||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
|
||||
// process files
|
||||
mtmd::bitmaps bitmaps;
|
||||
const bool has_mtmd = ctx_server.mctx != nullptr;
|
||||
{
|
||||
if (!has_mtmd && !files.empty()) {
|
||||
throw std::runtime_error("This server does not support multimodal");
|
||||
}
|
||||
for (auto & file : files) {
|
||||
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, file.data(), file.size()));
|
||||
if (!bmp.ptr) {
|
||||
throw std::runtime_error("Failed to load image or audio file");
|
||||
}
|
||||
// calculate bitmap hash (for KV caching)
|
||||
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
|
||||
bmp.set_id(hash.c_str());
|
||||
bitmaps.entries.push_back(std::move(bmp));
|
||||
}
|
||||
}
|
||||
|
||||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
|
||||
if (oaicompat && has_mtmd) {
|
||||
// multimodal
|
||||
std::string prompt_str = prompt.get<std::string>();
|
||||
mtmd_input_text inp_txt = {
|
||||
prompt_str.c_str(),
|
||||
/* add_special */ true,
|
||||
/* parse_special */ true,
|
||||
};
|
||||
mtmd::input_chunks chunks(mtmd_input_chunks_init());
|
||||
auto bitmaps_c_ptr = bitmaps.c_ptr();
|
||||
int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
|
||||
chunks.ptr.get(),
|
||||
&inp_txt,
|
||||
bitmaps_c_ptr.data(),
|
||||
bitmaps_c_ptr.size());
|
||||
if (tokenized != 0) {
|
||||
throw std::runtime_error("Failed to tokenize prompt");
|
||||
}
|
||||
|
||||
server_tokens tmp(chunks, true);
|
||||
inputs.push_back(std::move(tmp));
|
||||
if (oaicompat && ctx_server.mctx != nullptr) {
|
||||
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
|
||||
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
|
||||
} else {
|
||||
// non-multimodal version
|
||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
for (auto & p : tokenized_prompts) {
|
||||
auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
|
||||
inputs.push_back(std::move(tmp));
|
||||
}
|
||||
// Everything else, including multimodal completions.
|
||||
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
|
||||
tasks.reserve(inputs.size());
|
||||
@@ -4574,7 +4534,7 @@ int main(int argc, char ** argv) {
|
||||
data["input_extra"] = input_extra; // default to empty array if it's not exist
|
||||
|
||||
std::string prompt = json_value(data, "prompt", std::string());
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
|
||||
std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
|
||||
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||
data["prompt"] = format_infill(
|
||||
ctx_server.vocab,
|
||||
@@ -4585,7 +4545,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.params_base.n_predict,
|
||||
ctx_server.slots[0].n_ctx, // TODO: there should be a better way
|
||||
ctx_server.params_base.spm_infill,
|
||||
tokenized_prompts[0]
|
||||
tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
|
||||
);
|
||||
|
||||
std::vector<raw_buffer> files; // dummy
|
||||
@@ -4634,7 +4594,7 @@ int main(int argc, char ** argv) {
|
||||
if (current_state == SERVER_STATE_READY) {
|
||||
model_meta = ctx_server.model_meta();
|
||||
}
|
||||
|
||||
bool has_mtmd = ctx_server.mctx != nullptr;
|
||||
json models = {
|
||||
{"models", {
|
||||
{
|
||||
@@ -4646,7 +4606,7 @@ int main(int argc, char ** argv) {
|
||||
{"type", "model"},
|
||||
{"description", ""},
|
||||
{"tags", {""}},
|
||||
{"capabilities", {"completion"}},
|
||||
{"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})},
|
||||
{"parameters", ""},
|
||||
{"details", {
|
||||
{"parent_model", ""},
|
||||
@@ -4763,7 +4723,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
for (const auto & tokens : tokenized_prompts) {
|
||||
// this check is necessary for models that do not add BOS token to the input
|
||||
if (tokens.empty()) {
|
||||
@@ -4791,7 +4751,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
|
||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
@@ -4878,7 +4838,10 @@ int main(int argc, char ** argv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
|
||||
std::vector<server_tokens> tokenized_queries = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true);
|
||||
if (tokenized_queries.size() != 1) {
|
||||
res_error(res, format_error_response("\"query\" must contain only a single prompt", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
@@ -4886,14 +4849,14 @@ int main(int argc, char ** argv) {
|
||||
std::unordered_set<int> task_ids;
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
|
||||
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
|
||||
tasks.reserve(tokenized_docs.size());
|
||||
for (size_t i = 0; i < tokenized_docs.size(); i++) {
|
||||
auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
|
||||
auto tmp = format_rerank(ctx_server.vocab, tokenized_queries[0], tokenized_docs[i]);
|
||||
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
|
||||
task.prompt_tokens = std::move(tmp);
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
JSON_MULTIMODAL_KEY = "multimodal_data"
|
||||
JSON_PROMPT_STRING_KEY = "prompt_string"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
@@ -231,6 +233,28 @@ def test_nocache_long_input_prompt():
|
||||
})
|
||||
assert res.status_code == 400
|
||||
|
||||
def test_json_prompt_no_mtmd():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is" },
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
|
||||
def test_json_prompt_mtm_error_when_not_supported():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is <__media__>", JSON_MULTIMODAL_KEY: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
# MTMD is disabled on this model, so this should fail.
|
||||
assert res.status_code != 200
|
||||
|
||||
def test_completion_with_tokens_input():
|
||||
global server
|
||||
@@ -269,6 +293,20 @@ def test_completion_with_tokens_input():
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed JSON and tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [
|
||||
tokens,
|
||||
{
|
||||
JSON_PROMPT_STRING_KEY: "I believe the meaning of life is",
|
||||
},
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed string and tokens in one sequence
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user