forked from wylab/llama.cpp
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 34bdbbd7c2 | |||
| 74f52f77f2 | |||
| f7207b0415 | |||
| 4d917cd4f6 | |||
| 886b97a5d6 | |||
| 111f8d06f0 | |||
| 5eff6ec9b1 | |||
| dfd9b5f6c7 | |||
| 5a6bc6b1a6 | |||
| 6b64f74b55 | |||
| 0d5a470223 | |||
| b0ba31f525 | |||
| 7da9fed0d6 | |||
| c247d06f38 | |||
| 043fb27d38 | |||
| b730706a49 | |||
| c9a24fb932 | |||
| a9c6ffcbfa | |||
| e78cf0d4b1 | |||
| 710dfc465a | |||
| 611f419cff | |||
| b1afcab804 | |||
| 9ef536907d | |||
| 21dc4ddaf2 | |||
| 289bf4113e | |||
| b55f06e1aa | |||
| 0a9b43e507 | |||
| 330c3d2d21 | |||
| e92734d51b | |||
| 45363632cb | |||
| 32732f2459 | |||
| 92f7f0a53c |
@@ -2,14 +2,30 @@ ARG UBUNTU_VERSION=24.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
# Install build tools
|
||||
RUN apt update && apt install -y git build-essential cmake wget
|
||||
# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
|
||||
|
||||
# Install Vulkan SDK and cURL
|
||||
RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
|
||||
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
|
||||
apt update -y && \
|
||||
apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
|
||||
# Install build tools
|
||||
RUN apt update && apt install -y git build-essential cmake wget xz-utils
|
||||
|
||||
# Install Vulkan SDK
|
||||
ARG VULKAN_VERSION=1.4.321.1
|
||||
RUN ARCH=$(uname -m) && \
|
||||
wget -qO /tmp/vulkan-sdk.tar.xz https://sdk.lunarg.com/sdk/download/${VULKAN_VERSION}/linux/vulkan-sdk-linux-${ARCH}-${VULKAN_VERSION}.tar.xz && \
|
||||
mkdir -p /opt/vulkan && \
|
||||
tar -xf /tmp/vulkan-sdk.tar.xz -C /tmp --strip-components=1 && \
|
||||
mv /tmp/${ARCH}/* /opt/vulkan/ && \
|
||||
rm -rf /tmp/*
|
||||
|
||||
# Install cURL and Vulkan SDK dependencies
|
||||
RUN apt install -y libcurl4-openssl-dev curl \
|
||||
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev
|
||||
|
||||
# Set environment variables
|
||||
ENV VULKAN_SDK=/opt/vulkan
|
||||
ENV PATH=$VULKAN_SDK/bin:$PATH
|
||||
ENV LD_LIBRARY_PATH=$VULKAN_SDK/lib:$LD_LIBRARY_PATH
|
||||
ENV CMAKE_PREFIX_PATH=$VULKAN_SDK:$CMAKE_PREFIX_PATH
|
||||
ENV PKG_CONFIG_PATH=$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH
|
||||
|
||||
# Build it
|
||||
WORKDIR /app
|
||||
|
||||
+21
-1
@@ -1361,6 +1361,26 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
"<|end|>",
|
||||
};
|
||||
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
data.grammar_lazy = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
|
||||
auto not_end = builder.add_rule("not-end",
|
||||
"[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]");
|
||||
auto analysis = builder.add_rule("analysis",
|
||||
"\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\"");
|
||||
auto constraint = builder.add_rule("constraint", "\"<|constrain|>\"? [a-zA-Z0-9_-]+");
|
||||
auto final = builder.add_rule("final",
|
||||
"\"<|channel|>final\" ( \" \" " + constraint + " )? \"<|message|>\" " +
|
||||
builder.add_schema("response", schema)
|
||||
);
|
||||
|
||||
builder.add_rule("root", "( " + analysis + " \"<|start|>assistant\" )? " + final);
|
||||
});
|
||||
}
|
||||
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
@@ -2121,7 +2141,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
}
|
||||
|
||||
// GPT-OSS
|
||||
if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
|
||||
if (src.find("<|channel|>") != std::string::npos) {
|
||||
return common_chat_params_init_gpt_oss(tmpl, params);
|
||||
}
|
||||
|
||||
|
||||
+71
-69
@@ -1216,6 +1216,55 @@ class TextModel(ModelBase):
|
||||
raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
|
||||
self.gguf_writer.add_pooling_type(pooling_type)
|
||||
|
||||
def _set_vocab_interns1(self):
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab))
|
||||
assert max(vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
else:
|
||||
token: str = reverse_vocab[i]
|
||||
if token in added_vocab:
|
||||
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
if added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
tokens.append(token)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab._set_special_token("bos", 151643)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
|
||||
class MmprojModel(ModelBase):
|
||||
model_type = ModelType.MMPROJ
|
||||
@@ -2932,7 +2981,8 @@ class Qwen2Model(TextModel):
|
||||
if "language_model." in name:
|
||||
name = name.replace("language_model.", "") # for InternVL
|
||||
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower"):
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower") \
|
||||
or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
|
||||
# skip vision and audio tensors
|
||||
return []
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
@@ -3109,7 +3159,7 @@ class LLaDAModel(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Ernie4_5_ForCausalLM")
|
||||
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
|
||||
class Ernie4_5Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.ERNIE4_5
|
||||
|
||||
@@ -3604,6 +3654,19 @@ class Qwen2MoeModel(TextModel):
|
||||
class Qwen3Model(Qwen2Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
|
||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||
|
||||
def set_vocab(self):
|
||||
# deal with intern-s1-mini
|
||||
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
|
||||
self._set_vocab_interns1()
|
||||
return
|
||||
|
||||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3MoeForCausalLM")
|
||||
class Qwen3MoeModel(Qwen2MoeModel):
|
||||
@@ -3620,73 +3683,7 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
||||
self._set_vocab_interns1()
|
||||
return
|
||||
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def _set_vocab_interns1(self):
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab))
|
||||
assert max(vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
else:
|
||||
token: str = reverse_vocab[i]
|
||||
if token in added_vocab:
|
||||
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
if added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
tokens.append(token)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
|
||||
additional_special_tokens = []
|
||||
if special_tokens_map_file.is_file():
|
||||
with open(special_tokens_map_file, encoding = 'utf-8') as f:
|
||||
additional_special_tokens = json.load(f).get('additional_special_tokens', [])
|
||||
tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
|
||||
if tokenizer_cfg_file.is_file():
|
||||
with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
|
||||
added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
|
||||
token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
|
||||
for token in additional_special_tokens:
|
||||
if token in token2ids_map:
|
||||
special_vocab._set_special_token(token, token2ids_map[token])
|
||||
special_vocab._set_special_token('eos', 151645)
|
||||
special_vocab._set_special_token("bos", 151643)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("GPT2LMHeadModel")
|
||||
@@ -5854,6 +5851,11 @@ class OlmoModel(TextModel):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("SeedOssForCausalLM")
|
||||
class SeedOssModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.SEED_OSS
|
||||
|
||||
|
||||
@ModelBase.register("Olmo2ForCausalLM")
|
||||
class Olmo2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.OLMO2
|
||||
|
||||
@@ -144,6 +144,15 @@ perplexity-run:
|
||||
hf-create-model:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}"
|
||||
|
||||
hf-create-model-dry-run:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -d
|
||||
|
||||
hf-create-model-embedding:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e
|
||||
|
||||
hf-create-model-embedding-dry-run:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e -d
|
||||
|
||||
hf-create-model-private:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -p
|
||||
|
||||
|
||||
@@ -285,13 +285,21 @@ For the following targets a `HF_TOKEN` environment variable is required.
|
||||
This will create a new model repsository on Hugging Face with the specified
|
||||
model name.
|
||||
```console
|
||||
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev"
|
||||
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
|
||||
Repository ID: danbev/TestModel-GGUF
|
||||
Repository created: https://huggingface.co/danbev/TestModel-GGUF
|
||||
```
|
||||
Note that we append a `-GGUF` suffix to the model name to ensure a consistent
|
||||
naming convention for GGUF models.
|
||||
|
||||
An embedding model can be created using the following command:
|
||||
```console
|
||||
(venv) $ make hf-create-model-embedding MODEL_NAME='TestEmbeddingModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
|
||||
```
|
||||
The only difference is that the model card for an embedding model will be different
|
||||
with regards to the llama-server command and also how to access/call the embedding
|
||||
endpoint.
|
||||
|
||||
### Upload a GGUF model to model repository
|
||||
The following target uploads a model to an existing Hugging Face model repository.
|
||||
```console
|
||||
|
||||
@@ -112,6 +112,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_params.no_perf = false;
|
||||
if (embedding_mode) {
|
||||
ctx_params.embeddings = true;
|
||||
ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
ctx_params.n_ubatch = ctx_params.n_batch;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
---
|
||||
base_model:
|
||||
- {base_model}
|
||||
---
|
||||
# {model_name} GGUF
|
||||
|
||||
Recommended way to run this model:
|
||||
|
||||
```sh
|
||||
llama-server -hf {namespace}/{model_name}-GGUF
|
||||
```
|
||||
|
||||
Then the endpoint can be accessed at http://localhost:8080/embedding, for
|
||||
example using `curl`:
|
||||
```console
|
||||
curl --request POST \
|
||||
--url http://localhost:8080/embedding \
|
||||
--header "Content-Type: application/json" \
|
||||
--data '{{"input": "Hello embeddings"}}' \
|
||||
--silent
|
||||
```
|
||||
|
||||
Alternatively, the `llama-embedding` command line tool can be used:
|
||||
```sh
|
||||
llama-embedding -hf {namespace}/{model_name}-GGUF --verbose-prompt -p "Hello embeddings"
|
||||
```
|
||||
|
||||
#### embd_normalize
|
||||
When a model uses pooling, or the pooling method is specified using `--pooling`,
|
||||
the normalization can be controlled by the `embd_normalize` parameter.
|
||||
|
||||
The default value is `2` which means that the embeddings are normalized using
|
||||
the Euclidean norm (L2). Other options are:
|
||||
* -1 No normalization
|
||||
* 0 Max absolute
|
||||
* 1 Taxicab
|
||||
* 2 Euclidean/L2
|
||||
* \>2 P-Norm
|
||||
|
||||
This can be passed in the request body to `llama-server`, for example:
|
||||
```sh
|
||||
--data '{{"input": "Hello embeddings", "embd_normalize": -1}}' \
|
||||
```
|
||||
|
||||
And for `llama-embedding`, by passing `--embd-normalize <value>`, for example:
|
||||
```sh
|
||||
llama-embedding -hf {namespace}/{model_name}-GGUF --embd-normalize -1 -p "Hello embeddings"
|
||||
```
|
||||
@@ -26,21 +26,31 @@ parser.add_argument('--namespace', '-ns', help='Namespace to add the model to',
|
||||
parser.add_argument('--org-base-model', '-b', help='Original Base model name', default="")
|
||||
parser.add_argument('--no-card', action='store_true', help='Skip creating model card')
|
||||
parser.add_argument('--private', '-p', action='store_true', help='Create private model')
|
||||
parser.add_argument('--embedding', '-e', action='store_true', help='Use embedding model card template')
|
||||
parser.add_argument('--dry-run', '-d', action='store_true', help='Print repository info and template without creating repository')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
repo_id = f"{args.namespace}/{args.model_name}-GGUF"
|
||||
print("Repository ID: ", repo_id)
|
||||
|
||||
repo_url = api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
private=args.private,
|
||||
exist_ok=False
|
||||
)
|
||||
repo_url = None
|
||||
if not args.dry_run:
|
||||
repo_url = api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
private=args.private,
|
||||
exist_ok=False
|
||||
)
|
||||
|
||||
if not args.no_card:
|
||||
template_path = "scripts/readme.md.template"
|
||||
if args.embedding:
|
||||
template_path = "scripts/embedding/modelcard.template"
|
||||
else:
|
||||
template_path = "scripts/causal/modelcard.template"
|
||||
|
||||
print("Template path: ", template_path)
|
||||
|
||||
model_card_content = load_template_and_substitute(
|
||||
template_path,
|
||||
model_name=args.model_name,
|
||||
@@ -48,16 +58,21 @@ if not args.no_card:
|
||||
base_model=args.org_base_model,
|
||||
)
|
||||
|
||||
if model_card_content:
|
||||
api.upload_file(
|
||||
path_or_fileobj=model_card_content.encode('utf-8'),
|
||||
path_in_repo="README.md",
|
||||
repo_id=repo_id
|
||||
)
|
||||
print("Model card created successfully.")
|
||||
if args.dry_run:
|
||||
print("\nTemplate Content:\n")
|
||||
print(model_card_content)
|
||||
else:
|
||||
print("Failed to create model card.")
|
||||
if model_card_content:
|
||||
api.upload_file(
|
||||
path_or_fileobj=model_card_content.encode('utf-8'),
|
||||
path_in_repo="README.md",
|
||||
repo_id=repo_id
|
||||
)
|
||||
print("Model card created successfully.")
|
||||
else:
|
||||
print("Failed to create model card.")
|
||||
|
||||
print(f"Repository created: {repo_url}")
|
||||
if not args.dry_run and repo_url:
|
||||
print(f"Repository created: {repo_url}")
|
||||
|
||||
|
||||
|
||||
@@ -512,6 +512,7 @@ extern "C" {
|
||||
GGML_OP_IM2COL,
|
||||
GGML_OP_IM2COL_BACK,
|
||||
GGML_OP_CONV_2D,
|
||||
GGML_OP_CONV_3D,
|
||||
GGML_OP_CONV_2D_DW,
|
||||
GGML_OP_CONV_TRANSPOSE_2D,
|
||||
GGML_OP_POOL_1D,
|
||||
@@ -1940,6 +1941,23 @@ extern "C" {
|
||||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
|
||||
struct ggml_tensor * b, // input [W, H, D, C * N]
|
||||
int s0, // stride
|
||||
int s1,
|
||||
int s2,
|
||||
int p0, // padding
|
||||
int p1,
|
||||
int p2,
|
||||
int d0, // dilation
|
||||
int d1,
|
||||
int d2,
|
||||
int n_channels,
|
||||
int n_batch,
|
||||
int n_channels_out);
|
||||
|
||||
enum ggml_op_pool {
|
||||
GGML_OP_POOL_MAX,
|
||||
GGML_OP_POOL_AVG,
|
||||
|
||||
@@ -1257,12 +1257,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
|
||||
|
||||
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
|
||||
if(acl_dst == nullptr) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
|
||||
} else {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
|
||||
}
|
||||
}
|
||||
|
||||
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
|
||||
if(acl_dst == nullptr) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
|
||||
} else {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
|
||||
@@ -2221,13 +2229,54 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
|
||||
ggml_cann_release_resources(ctx, acl_index, acl_value);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Initializes and caches sine/cosine positional encoding values
|
||||
* (used in RoPE, Rotary Position Embedding) for attention layers.
|
||||
*
|
||||
* This function computes and caches the sin/cos values of
|
||||
* θ = position * theta_scale for RoPE encoding. The cache is shared
|
||||
* across attention layers, and only the first attention layer will
|
||||
* trigger initialization. The cache includes repeated sin/cos values
|
||||
* with different repeat methods depending on the @param is_neox flag.
|
||||
*
|
||||
* Steps performed by this function:
|
||||
* 1. Identify whether the target tensor belongs to Q/K in attention
|
||||
* and restrict computation to the first layer only.
|
||||
* 2. Initialize the theta scale array (arange → power → freq scaling).
|
||||
* 3. Allocate sin/cos caches if the max prompt length increases.
|
||||
* 4. Compute θ = position * theta_scale.
|
||||
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
|
||||
* 6. Expand sin/cos values by repeat or repeat_interleave depending
|
||||
* on whether @param is_neox is enabled.
|
||||
* 7. Store the computed values into persistent buffers
|
||||
* (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
|
||||
*
|
||||
* @param ctx The CANN backend context, holding memory pool,
|
||||
* stream, and persistent buffers for rope init/cache.
|
||||
* @param dst The destination ggml_tensor whose computation
|
||||
* depends on the cached RoPE values (usually Qcur/Kcur).
|
||||
* @param theta_scale Scalar exponent base for computing theta scale values.
|
||||
* @param freq_scale Frequency scaling factor, applied to theta scale.
|
||||
* @param attn_factor Attention scaling factor, applied to sin/cos.
|
||||
* @param is_neox Whether to use Neox-style repeat strategy
|
||||
* (dim expansion vs repeat_interleave).
|
||||
*/
|
||||
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
aclTensor* acl_cos_repeat_tensor,
|
||||
aclTensor* acl_sin_repeat_tensor,
|
||||
float theta_scale, float freq_scale,
|
||||
float attn_factor, bool is_neox) {
|
||||
// int sin/cos cache, cache has different repeat method depond on
|
||||
// @param.is_neox
|
||||
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
|
||||
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
|
||||
|
||||
// used for accuracy testing
|
||||
bool is_attention = is_q || is_k;
|
||||
|
||||
// just compute in first layer in attention
|
||||
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
|
||||
if(is_attention && !is_fisrt_layer) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_tensor* src0 = dst->src[0]; // input
|
||||
ggml_tensor* src1 = dst->src[1]; // position
|
||||
@@ -2253,21 +2302,16 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
|
||||
}
|
||||
|
||||
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
|
||||
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
|
||||
|
||||
// used for accuracy testing
|
||||
bool is_attention = is_q || is_k;
|
||||
|
||||
if(ctx.init_ptr == nullptr || !is_attention) {
|
||||
// init theta scale, just one time
|
||||
if(ctx.rope_init_ptr == nullptr || !is_attention) {
|
||||
// theta_scale arange, [0,1,...,ne00/2 - 1]
|
||||
if(ctx.init_ptr != nullptr){
|
||||
ACL_CHECK(aclrtFree(ctx.init_ptr));
|
||||
if(ctx.rope_init_ptr != nullptr){
|
||||
ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
|
||||
}
|
||||
ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
|
||||
aclTensor* acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
float start = 0;
|
||||
float step = 1;
|
||||
@@ -2297,67 +2341,55 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
|
||||
}
|
||||
|
||||
if(ctx.sin_ptr == nullptr) {
|
||||
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
|
||||
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
}
|
||||
// init sin_repeat && cos_repeat, one token just init in 0 layer
|
||||
if(position_length > ctx.max_prompt_length) {
|
||||
ctx.max_prompt_length = position_length;
|
||||
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
|
||||
ACL_CHECK(aclrtFree(ctx.sin_ptr));
|
||||
ACL_CHECK(aclrtFree(ctx.cos_ptr));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
|
||||
if(ctx.rope_sin_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
|
||||
ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
|
||||
}
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
}
|
||||
|
||||
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
|
||||
|
||||
if(is_fisrt_layer || !is_attention) {
|
||||
|
||||
aclTensor* acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
aclTensor* acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
|
||||
// position
|
||||
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
|
||||
src1->data, ggml_cann_type_mapping(src1->type),
|
||||
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
|
||||
// position
|
||||
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
|
||||
src1->data, ggml_cann_type_mapping(src1->type),
|
||||
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
|
||||
|
||||
// power * position
|
||||
int64_t theta_length = theta_scale_length * position_length;
|
||||
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* theta_buffer = theta_allocator.get();
|
||||
// power * position
|
||||
int64_t theta_length = theta_scale_length * position_length;
|
||||
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* theta_buffer = theta_allocator.get();
|
||||
|
||||
aclTensor* acl_theta_tensor =
|
||||
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
theta_ne, theta_nb, GGML_MAX_DIMS);
|
||||
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
|
||||
acl_theta_tensor);
|
||||
|
||||
// sin/cos
|
||||
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
|
||||
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
|
||||
|
||||
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
|
||||
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
|
||||
|
||||
// release
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
|
||||
acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
|
||||
}
|
||||
aclTensor* acl_theta_tensor =
|
||||
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
theta_ne, theta_nb, GGML_MAX_DIMS);
|
||||
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
|
||||
acl_theta_tensor);
|
||||
|
||||
// sin/cos
|
||||
ggml_cann_pool_alloc sin_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* sin_buffer = sin_allocator.get();
|
||||
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
|
||||
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
|
||||
|
||||
ggml_cann_pool_alloc cos_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* cos_buffer = cos_allocator.get();
|
||||
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
|
||||
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
|
||||
|
||||
// attn_factor
|
||||
if (attn_factor != 1) {
|
||||
@@ -2365,6 +2397,19 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
|
||||
}
|
||||
|
||||
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
|
||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||
sin_reshape_nb[0] = sizeof(float_t);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_sin_repeat_tensor =
|
||||
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclTensor* acl_cos_repeat_tensor =
|
||||
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
|
||||
// repeat
|
||||
if (is_neox) {
|
||||
int64_t repeatsArray[] = {1, 1, 1, 2};
|
||||
@@ -2380,8 +2425,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
num_repeats, output_size);
|
||||
}
|
||||
|
||||
// release
|
||||
ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
|
||||
acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
|
||||
acl_cos_repeat_tensor);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
@@ -2435,13 +2481,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
// init cos/sin cache
|
||||
ggml_cann_pool_alloc sin_allocator(
|
||||
ctx.pool(), ne00 * ne02 * sizeof(float_t));
|
||||
ggml_cann_pool_alloc cos_allocator(
|
||||
ctx.pool(), ne00 * ne02 * sizeof(float_t));
|
||||
void* sin_buffer = sin_allocator.get();
|
||||
void* cos_buffer = cos_allocator.get();
|
||||
// init ctx.rope_cos/rope_sin cache
|
||||
aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
|
||||
|
||||
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
|
||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||
@@ -2450,13 +2491,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_sin_reshape_tensor =
|
||||
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclTensor* acl_cos_reshape_tensor =
|
||||
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
|
||||
theta_scale, freq_scale, attn_factor, is_neox);
|
||||
|
||||
aclTensor* acl_src = ggml_cann_create_tensor(src0);
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
+18
-10
@@ -368,10 +368,6 @@ struct ggml_backend_cann_context {
|
||||
std::string name; /**< Name of the device. */
|
||||
std::string description; /**< Description of the device. */
|
||||
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
|
||||
void* init_ptr = nullptr;
|
||||
void* sin_ptr = nullptr;
|
||||
void* cos_ptr = nullptr;
|
||||
int64_t max_prompt_length = 65536;
|
||||
#ifdef USE_ACL_GRAPH
|
||||
/// Cached CANN ACL graph used for executing the current ggml computation graph.
|
||||
std::unique_ptr<ggml_cann_graph> cann_graph;
|
||||
@@ -379,6 +375,12 @@ struct ggml_backend_cann_context {
|
||||
cann_task_queue task_queue;
|
||||
bool async_mode;
|
||||
bool support_set_rows;
|
||||
// Rope Cache
|
||||
void* rope_init_ptr = nullptr;
|
||||
void* rope_sin_ptr = nullptr;
|
||||
void* rope_cos_ptr = nullptr;
|
||||
int64_t max_prompt_length = 0;
|
||||
// Constant Pool
|
||||
void* f32_zero_cache = nullptr;
|
||||
void* f32_one_cache = nullptr;
|
||||
int64_t f32_zero_cache_element = 0;
|
||||
@@ -422,14 +424,20 @@ struct ggml_backend_cann_context {
|
||||
ACL_CHECK(aclrtDestroyStream(streams[i]));
|
||||
}
|
||||
}
|
||||
if(init_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(init_ptr));
|
||||
if(rope_init_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(rope_init_ptr));
|
||||
}
|
||||
if(sin_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(sin_ptr));
|
||||
if(rope_sin_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(rope_sin_ptr));
|
||||
}
|
||||
if(cos_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(cos_ptr));
|
||||
if(rope_cos_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(rope_cos_ptr));
|
||||
}
|
||||
if(f32_zero_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(f32_zero_cache));
|
||||
}
|
||||
if(f32_one_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(f32_one_cache));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_conv_2d(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CONV_3D:
|
||||
{
|
||||
ggml_compute_forward_conv_3d(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
{
|
||||
ggml_compute_forward_conv_2d_dw(params, tensor);
|
||||
@@ -2252,6 +2256,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_3D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
@@ -2773,6 +2778,7 @@ struct ggml_cplan ggml_graph_plan(
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_3D:
|
||||
{
|
||||
cur = GGML_IM2COL_WORK_SIZE;
|
||||
} break;
|
||||
|
||||
@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
|
||||
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_3d
|
||||
|
||||
static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
|
||||
const ggml_tensor * kernel,
|
||||
const ggml_tensor * src,
|
||||
ggml_tensor * dst,
|
||||
ggml_type kernel_type) {
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(kernel));
|
||||
GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(kernel->type == kernel_type);
|
||||
|
||||
const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
|
||||
|
||||
const int32_t s0 = dst->op_params[0];
|
||||
const int32_t s1 = dst->op_params[1];
|
||||
const int32_t s2 = dst->op_params[2];
|
||||
const int32_t p0 = dst->op_params[3];
|
||||
const int32_t p1 = dst->op_params[4];
|
||||
const int32_t p2 = dst->op_params[5];
|
||||
const int32_t d0 = dst->op_params[6];
|
||||
const int32_t d1 = dst->op_params[7];
|
||||
const int32_t d2 = dst->op_params[8];
|
||||
const int32_t c = dst->op_params[9];
|
||||
const int32_t n = dst->op_params[10];
|
||||
const int32_t oc = dst->op_params[11];
|
||||
|
||||
const int64_t src_w = src->ne[0];
|
||||
const int64_t src_h = src->ne[1];
|
||||
const int64_t src_d = src->ne[2];
|
||||
const int64_t knl_w = kernel->ne[0];
|
||||
const int64_t knl_h = kernel->ne[1];
|
||||
const int64_t knl_d = kernel->ne[2];
|
||||
const int64_t dst_w = dst->ne[0];
|
||||
const int64_t dst_h = dst->ne[1];
|
||||
const int64_t dst_d = dst->ne[2];
|
||||
|
||||
const float * src_data = (float *) src->data;
|
||||
void * knl_data = kernel->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
|
||||
const int64_t knl_n_total = knl_n_per_channel * c;
|
||||
const int64_t patch_total = n * dst_w * dst_h * dst_d;
|
||||
|
||||
const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
|
||||
const int64_t batch_size = params->wsize / space_per_patch;
|
||||
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
|
||||
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
|
||||
|
||||
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
|
||||
|
||||
void * tmp = params->wdata;
|
||||
|
||||
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
|
||||
const int64_t patch_start_batch = batch_i * patches_per_batch;
|
||||
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
|
||||
const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
|
||||
|
||||
const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
||||
const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
|
||||
const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
|
||||
|
||||
for (int64_t p = patch_start; p < patch_end; ++p) {
|
||||
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
||||
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
||||
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
||||
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
||||
const int64_t dst_y = p_in_depth / dst_w;
|
||||
const int64_t dst_x = p_in_depth % dst_w;
|
||||
|
||||
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
|
||||
|
||||
for (int64_t ic = 0; ic < c; ++ic) {
|
||||
for (int64_t kz = 0; kz < knl_d; ++kz) {
|
||||
for (int64_t ky = 0; ky < knl_h; ++ky) {
|
||||
for (int64_t kx = 0; kx < knl_w; ++kx) {
|
||||
const int64_t sz = dst_z * s2 + kz * d2 - p2;
|
||||
const int64_t sy = dst_y * s1 + ky * d1 - p1;
|
||||
const int64_t sx = dst_x * s0 + kx * d0 - p0;
|
||||
|
||||
int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
|
||||
|
||||
float src_val;
|
||||
if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
|
||||
src_val = 0.0f;
|
||||
} else {
|
||||
const int64_t cn_idx = batch_idx * c + ic;
|
||||
const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
|
||||
src_val = *src_ptr;
|
||||
}
|
||||
|
||||
char * element_ptr = dst_row + dst_idx * traits->type_size;
|
||||
if (kernel_type == GGML_TYPE_F32) {
|
||||
*(float *)element_ptr = src_val;
|
||||
} else if (kernel_type == GGML_TYPE_F16) {
|
||||
*(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
|
||||
ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
||||
const int64_t permute_start = params->ith * permute_per_thread;
|
||||
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
|
||||
|
||||
for (int64_t i = permute_start; i < permute_end; ++i) {
|
||||
const int64_t p = patch_start_batch + i;
|
||||
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
||||
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
||||
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
||||
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
||||
const int64_t dst_y = p_in_depth / dst_w;
|
||||
const int64_t dst_x = p_in_depth % dst_w;
|
||||
|
||||
for (int64_t ioc = 0; ioc < oc; ++ioc) {
|
||||
const float value = gemm_output[i * oc + ioc];
|
||||
const int64_t ocn_idx = batch_idx * oc + ioc;
|
||||
float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
|
||||
*dst_ptr = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_conv_3d(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_transpose_2d
|
||||
|
||||
void ggml_compute_forward_conv_transpose_2d(
|
||||
|
||||
@@ -70,6 +70,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
|
||||
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -420,16 +420,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_all(int x) {
|
||||
#ifdef GGML_USE_HIP
|
||||
if (width == ggml_cuda_get_physical_warp_size()) {
|
||||
return __all_sync(0xffffffff, x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_any(int x) {
|
||||
if (width == ggml_cuda_get_physical_warp_size()) {
|
||||
return __any_sync(0xffffffff, x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
return x;
|
||||
#else
|
||||
static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
|
||||
return __all_sync(0xffffffff, x);
|
||||
#endif // GGML_USE_HIP
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
|
||||
@@ -258,7 +258,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const half val = hexp(sink - kqmax[j0/nwarps]);
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
|
||||
kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
|
||||
@@ -3485,11 +3485,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
|
||||
+177
-47
@@ -3,6 +3,140 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
|
||||
struct mmq_ids_helper_store {
|
||||
uint32_t data;
|
||||
|
||||
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
|
||||
data = (it & 0x003FFFFF) | (iex_used << 22);
|
||||
}
|
||||
|
||||
__device__ uint32_t it() const {
|
||||
return data & 0x003FFFFF;
|
||||
}
|
||||
|
||||
__device__ uint32_t iex_used() const {
|
||||
return data >> 22;
|
||||
}
|
||||
};
|
||||
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
|
||||
|
||||
// Helper function for mul_mat_id, converts ids to a more convenient format.
|
||||
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
|
||||
// ids_dst describes the same mapping but for the dst tensor.
|
||||
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
|
||||
template <int n_expert_used_template>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
||||
const int expert = blockIdx.x;
|
||||
|
||||
extern __shared__ char data_mmq_ids_helper[];
|
||||
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
|
||||
|
||||
int nex_prev = 0; // Number of columns for experts with a lower index.
|
||||
int it_compact = 0; // Running index for the compact slice of this expert.
|
||||
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
// Generic implementation:
|
||||
for (int it = 0; it < n_tokens; ++it) {
|
||||
int iex_used = -1; // The index at which the expert is used, if any.
|
||||
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
||||
const int expert_used = ids[it*si1 + iex];
|
||||
nex_prev += expert_used < expert;
|
||||
if (expert_used == expert) {
|
||||
iex_used = iex;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
||||
it_compact++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Implementation optimized for specific numbers of experts used:
|
||||
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
|
||||
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
|
||||
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
|
||||
const int it = it0 + threadIdx.x / neu_padded;
|
||||
|
||||
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
|
||||
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
|
||||
ids[it*si1 + iex] : INT_MAX;
|
||||
const int iex_used = expert_used == expert ? iex : -1;
|
||||
nex_prev += expert_used < expert;
|
||||
|
||||
// Whether the threads at this token position have used the expert:
|
||||
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
|
||||
|
||||
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
|
||||
int it_compact_add_lower = 0;
|
||||
#pragma unroll
|
||||
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
|
||||
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
|
||||
if (threadIdx.x >= offset) {
|
||||
it_compact_add_lower += tmp;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
|
||||
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
|
||||
}
|
||||
}
|
||||
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
|
||||
|
||||
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
|
||||
const mmq_ids_helper_store store_it = store[itc];
|
||||
const int it = store_it.it();
|
||||
const int iex_used = store_it.iex_used();
|
||||
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
|
||||
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
|
||||
}
|
||||
|
||||
if (threadIdx.x != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[expert] = nex_prev;
|
||||
|
||||
if (expert < gridDim.x - 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[gridDim.x] = nex_prev + it_compact;
|
||||
}
|
||||
|
||||
template <int n_expert_used_template>
|
||||
static void launch_mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
|
||||
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
|
||||
|
||||
const dim3 num_blocks(n_experts, 1, 1);
|
||||
const dim3 block_size(warp_size, 1, 1);
|
||||
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
|
||||
GGML_ASSERT(nbytes_shared <= smpbo);
|
||||
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
|
||||
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
||||
switch (args.type_x) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
|
||||
ne00, ne01, ne1, s01, ne11, s1,
|
||||
ne02, ne12, s02, s12, s2,
|
||||
ne03, ne13, s03, s13, s3,
|
||||
use_stream_k};
|
||||
use_stream_k, ne1};
|
||||
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
|
||||
return;
|
||||
}
|
||||
@@ -148,54 +282,50 @@ void ggml_cuda_mul_mat_q(
|
||||
|
||||
const int64_t n_expert_used = ids->ne[0];
|
||||
const int64_t ne_get_rows = ne12 * n_expert_used;
|
||||
GGML_ASSERT(ne1 == n_expert_used);
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
std::vector<int32_t> ids_src1_host;
|
||||
ids_src1_host.reserve(ne_get_rows);
|
||||
std::vector<int32_t> ids_dst_host;
|
||||
ids_dst_host.reserve(ne_get_rows);
|
||||
std::vector<int32_t> tokens_per_expert_host(ne02);
|
||||
std::vector<int32_t> expert_bounds_host(ne02 + 1);
|
||||
ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
|
||||
ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
|
||||
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
|
||||
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
|
||||
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
{
|
||||
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
|
||||
const int si1 = ids->nb[1] / ggml_element_size(ids);
|
||||
const int sis1 = nb12 / nb11;
|
||||
|
||||
for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
|
||||
for (int64_t iex = 0; iex < n_expert_used; ++iex) {
|
||||
const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
|
||||
assert(expert_to_use >= 0 && expert_to_use < ne02);
|
||||
if (expert_to_use == i02) {
|
||||
ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
|
||||
ids_dst_host.push_back(i12*ne1 + iex);
|
||||
tokens_per_expert_host[i02]++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
switch (n_expert_used) {
|
||||
case 2:
|
||||
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 16:
|
||||
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
default:
|
||||
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
int32_t cumsum = 0;
|
||||
for (int64_t i = 0; i < ne02; ++i) {
|
||||
expert_bounds_host[i] = cumsum;
|
||||
cumsum += tokens_per_expert_host[i];
|
||||
}
|
||||
expert_bounds_host[ne02] = cumsum;
|
||||
|
||||
std::vector<int32_t> ids_buf_host;
|
||||
ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
|
||||
ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
|
||||
ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
|
||||
ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
|
||||
ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
const int32_t * ids_src1_dev = ids_buf_dev.ptr;
|
||||
const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
|
||||
const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
|
||||
|
||||
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
|
||||
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
|
||||
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
|
||||
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[2] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
|
||||
|
||||
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
|
||||
const mmq_args args = {
|
||||
src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
|
||||
src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
|
||||
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
|
||||
ne02, ne02, s02, s12, s2,
|
||||
ne03, ne13, s03, s13, s3,
|
||||
use_stream_k};
|
||||
use_stream_k, ne12};
|
||||
|
||||
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
|
||||
}
|
||||
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
|
||||
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
|
||||
1, 1, 0, 0, 0,
|
||||
1, 1, 0, 0, 0,
|
||||
use_stream_k};
|
||||
use_stream_k, src1_ncols};
|
||||
|
||||
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
|
||||
|
||||
|
||||
+21
-13
@@ -3138,7 +3138,8 @@ static __global__ void mul_mat_q(
|
||||
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
|
||||
// Skip unused template specializations for faster compilation:
|
||||
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
||||
@@ -3152,7 +3153,7 @@ static __global__ void mul_mat_q(
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
|
||||
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
||||
|
||||
// Initialize the ids for writing back data with just the index.
|
||||
@@ -3376,7 +3377,8 @@ template <ggml_type type, int mmq_x, bool need_check>
|
||||
static __global__ void mul_mat_q_stream_k_fixup(
|
||||
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
|
||||
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
|
||||
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
||||
@@ -3387,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
||||
|
||||
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
|
||||
|
||||
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
@@ -3528,7 +3530,7 @@ struct mmq_args {
|
||||
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
|
||||
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
|
||||
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
|
||||
bool use_stream_k;
|
||||
bool use_stream_k; int64_t ncols_max;
|
||||
};
|
||||
|
||||
template<ggml_type type>
|
||||
@@ -3558,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
|
||||
|
||||
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
|
||||
const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
|
||||
const int ntzw = args.nchannels_y * args.nsamples_y;
|
||||
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
|
||||
|
||||
@@ -3574,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -3601,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
@@ -3609,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
@@ -3624,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3649,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||
continue;
|
||||
}
|
||||
|
||||
const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
|
||||
const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
|
||||
|
||||
if (ntiles_x < ntiles_x_best) {
|
||||
mmq_x_best = mmq_x;
|
||||
|
||||
@@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
|
||||
return ((const int *) x)[i32]; // assume at least 4 byte alignment
|
||||
}
|
||||
|
||||
// q4 contains 8 indices with 4 bit each.
|
||||
// This function selects those bytes from table that are at those indices and returns them as int2.
|
||||
// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
|
||||
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
|
||||
#if defined(GGML_USE_HIP)
|
||||
// Load the 16-byte table into four 32-bit unsigned integers.
|
||||
const uint32_t *values = (const uint32_t *)table;
|
||||
|
||||
const uint32_t q_even = q4;
|
||||
const uint32_t q_odd = (q4 >> 4);
|
||||
|
||||
// Perform lookups in the lower half of the table (indices 0-7).
|
||||
uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
|
||||
uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
|
||||
|
||||
// Perform lookups in the upper half of the table (indices 8-15).
|
||||
uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
|
||||
uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
|
||||
|
||||
// Select between the low and high results based on the MSB of each index nibble.
|
||||
uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
|
||||
uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
|
||||
uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
|
||||
uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
|
||||
|
||||
return make_int2(res_x, res_y);
|
||||
#elif !defined(GGML_USE_MUSA)
|
||||
// CUDA does not have an instruction for selecting bytes with 4 bit indices.
|
||||
// However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
|
||||
const uint32_t * table32 = (const uint32_t *) table;
|
||||
|
||||
// __byte_perm selects bytes based on the lower 16 bits in its third argument.
|
||||
// Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
|
||||
// To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
|
||||
// Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
|
||||
uint32_t tmp[2];
|
||||
const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 2; ++i) {
|
||||
const uint32_t shift = 16 * i;
|
||||
|
||||
const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
|
||||
const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
|
||||
tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
|
||||
}
|
||||
|
||||
// tmp contains the bytes from tyble in the same order as the 4 bit indices in q4.
|
||||
// However, for the result we need ints with all even/odd 4 bit indices in q4.
|
||||
// Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
|
||||
return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
|
||||
#else
|
||||
// Generic implementation.
|
||||
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
||||
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
||||
const char4 val0_8 = make_char4(
|
||||
@@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
|
||||
|
||||
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
||||
#endif
|
||||
}
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
|
||||
Vendored
+3
@@ -22,7 +22,10 @@
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
||||
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
#define __all_sync(mask, var) __all(var)
|
||||
#define __any_sync(mask, var) __any(var)
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasDestroy hipblasDestroy
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
|
||||
@@ -93,35 +93,37 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
||||
if (ctx->mtl_device == nil) {
|
||||
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
||||
|
||||
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
if (ctx->mtl_device) {
|
||||
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
|
||||
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
|
||||
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
||||
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
||||
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
||||
#endif
|
||||
|
||||
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
ctx->use_bfloat = ctx->has_bfloat;
|
||||
ctx->use_bfloat = ctx->has_bfloat;
|
||||
#else
|
||||
ctx->use_bfloat = false;
|
||||
ctx->use_bfloat = false;
|
||||
#endif
|
||||
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
||||
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
||||
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
|
||||
ctx->debug_fusion = val ? atoi(val) : 0;
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
|
||||
ctx->debug_fusion = val ? atoi(val) : 0;
|
||||
}
|
||||
|
||||
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
|
||||
|
||||
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
||||
|
||||
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
||||
}
|
||||
|
||||
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
|
||||
|
||||
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
||||
|
||||
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
||||
}
|
||||
|
||||
ctx->mtl_device_ref_count++;
|
||||
@@ -443,6 +445,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
||||
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
||||
@@ -452,6 +455,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
||||
@@ -461,6 +465,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
||||
@@ -470,6 +475,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
||||
@@ -479,6 +485,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
||||
@@ -488,6 +495,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
||||
@@ -497,6 +505,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
||||
@@ -506,6 +515,13 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
|
||||
@@ -1459,6 +1475,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
|
||||
@@ -1468,6 +1485,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
||||
@@ -1477,6 +1495,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
||||
@@ -1486,6 +1505,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
||||
@@ -1495,6 +1515,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
||||
@@ -1504,6 +1525,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
||||
@@ -1513,6 +1535,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
||||
@@ -1522,6 +1545,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40, flash_attn_ext_vec_f16_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40, flash_attn_ext_vec_bf16_h40, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40, flash_attn_ext_vec_q4_0_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40, flash_attn_ext_vec_q4_1_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40, flash_attn_ext_vec_q5_0_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40, flash_attn_ext_vec_q5_1_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40, flash_attn_ext_vec_q8_0_h40, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
|
||||
@@ -5130,6 +5160,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
@@ -5154,6 +5185,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
|
||||
@@ -5178,6 +5210,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
||||
@@ -5202,6 +5235,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
||||
@@ -5226,6 +5260,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
||||
@@ -5250,6 +5285,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
||||
@@ -5274,6 +5310,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
||||
@@ -5301,6 +5338,24 @@ static int ggml_metal_encode_node(
|
||||
use_vec_kernel = true;
|
||||
|
||||
switch (ne00) {
|
||||
case 40:
|
||||
{
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||
GGML_ABORT("add template specialization for this type");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case 64:
|
||||
{
|
||||
switch (src1->type) {
|
||||
|
||||
@@ -4663,6 +4663,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
||||
@@ -4674,6 +4675,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
||||
@@ -4685,6 +4687,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
||||
@@ -4695,6 +4698,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
||||
@@ -4705,6 +4709,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
||||
@@ -4715,6 +4720,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
||||
@@ -4725,6 +4731,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
||||
@@ -5115,6 +5122,16 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 40, 40, 8>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 40, 40, 8>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 40, 40, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 40, 40, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 40, 40, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 40, 40, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 40, 40, 8>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
|
||||
|
||||
@@ -2647,8 +2647,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM:
|
||||
return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_REPEAT:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
|
||||
case GGML_OP_PAD:
|
||||
|
||||
@@ -4391,10 +4391,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_ARGSORT:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,20 +1,34 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#if ADD_RMS
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_binary_head.comp"
|
||||
|
||||
const uint num_threads = 256;
|
||||
|
||||
layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
|
||||
|
||||
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#if ADD_RMS
|
||||
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
|
||||
shared FLOAT_TYPE sumsh[num_threads];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
uint idx = get_idx();
|
||||
uint orig_idx = idx;
|
||||
|
||||
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
|
||||
const uint num_iter = 2;
|
||||
|
||||
FLOAT_TYPE sum_sq = 0;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
if (idx >= p.ne) {
|
||||
continue;
|
||||
@@ -22,8 +36,34 @@ void main() {
|
||||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
|
||||
sum_sq += sum*sum;
|
||||
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
||||
#if ADD_RMS
|
||||
if (p.param3 != 0) {
|
||||
// reduce the sum within each subgroup, then across subgroups
|
||||
const uint NumSubgroups = num_threads / gl_SubgroupSize;
|
||||
sum_sq = subgroupAdd(sum_sq);
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
|
||||
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
|
||||
sum_sq += sumsh[gl_SubgroupID + s];
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
||||
partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -9,6 +9,10 @@ layout (constant_id = 4) const uint32_t HSV = 32;
|
||||
layout (constant_id = 5) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 6) const uint32_t D_split = 16;
|
||||
|
||||
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
|
||||
const uint32_t HSK_pad = (HSK + 15) & ~15;
|
||||
const uint32_t HSV_pad = (HSV + 15) & ~15;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
@@ -46,14 +46,14 @@ const uint32_t MatBc = 16;
|
||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
||||
|
||||
const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
|
||||
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
|
||||
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
|
||||
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
|
||||
shared ACC_TYPE sfsh[Bc * sfshstride];
|
||||
|
||||
const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
|
||||
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 ksh[Bc * kshstride];
|
||||
|
||||
shared float slope[Br];
|
||||
@@ -74,6 +74,21 @@ void main() {
|
||||
|
||||
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||
|
||||
// Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
|
||||
if ((HSK % 16) != 0) {
|
||||
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
|
||||
if (i + tid < Br * qstride) {
|
||||
Qf[i + tid] = f16vec4(0);
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
|
||||
if (i + tid < Bc * kshstride) {
|
||||
ksh[i + tid] = f16vec4(0);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
@@ -151,14 +166,14 @@ void main() {
|
||||
}
|
||||
barrier();
|
||||
|
||||
// K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
|
||||
// K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
|
||||
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||
// This is written transposed in order to allow for N being 8 if implementations need it
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
|
||||
for (uint32_t d = 0; d < HSK / 16; ++d) {
|
||||
for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
|
||||
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||
|
||||
@@ -104,16 +104,16 @@ void main() {
|
||||
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
||||
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
|
||||
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
|
||||
|
||||
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
|
||||
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
|
||||
Qf16 *= float16_t(p.scale);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
|
||||
|
||||
@@ -140,10 +140,10 @@ void main() {
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
|
||||
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
|
||||
S = coopMatMulAdd(Qf16, K_T, S);
|
||||
|
||||
if (p.logit_softcap != 0.0f) {
|
||||
@@ -208,31 +208,31 @@ void main() {
|
||||
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
|
||||
rowsum = coopMatMulAdd(P_A, One, rowsum);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
|
||||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
|
||||
|
||||
L = eM*L + rowsum;
|
||||
|
||||
// This is the "diagonal" matrix in the paper, but since we do componentwise
|
||||
// multiply rather than matrix multiply it has the diagonal element smeared
|
||||
// across the row
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
|
||||
|
||||
// resize eM by using smear/reduce
|
||||
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
// multiply with fp16 accumulation, then add to O.
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
|
||||
PV = coopMatMulAdd(P_A, V, PV);
|
||||
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
|
||||
}
|
||||
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
|
||||
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
@@ -243,16 +243,16 @@ void main() {
|
||||
return;
|
||||
}
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
|
||||
|
||||
// resize L by using smear/reduce
|
||||
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
|
||||
coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
|
||||
|
||||
// resize M by using smear/reduce
|
||||
coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
@@ -285,7 +285,7 @@ void main() {
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
|
||||
if (p.gqa_ratio > 1) {
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
} else {
|
||||
@@ -295,6 +295,6 @@ void main() {
|
||||
// permute dimensions
|
||||
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
|
||||
|
||||
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
|
||||
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#endif
|
||||
@@ -103,16 +106,79 @@ layout (constant_id = 10) const uint WARP = 32;
|
||||
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
uint _ne1;
|
||||
#ifdef COOPMAT
|
||||
shared uint _ne1_sh;
|
||||
#endif
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[BN];
|
||||
uint _ne1;
|
||||
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
shared uvec4 ballots_sh[NUM_WARPS];
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif // MUL_MAT_ID_USE_SUBGROUPS
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
@@ -177,51 +243,20 @@ void main() {
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#ifdef COOPMAT
|
||||
// Spread the search across all elements in the first subgroup
|
||||
if (gl_SubgroupID == 0) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
_ne1 = _ne1_sh;
|
||||
#else
|
||||
_ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
row_ids[_ne1] = u16vec2(ii0, ii1);
|
||||
if (_ne1 >= ic * BN) {
|
||||
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
@@ -767,7 +802,7 @@ void main() {
|
||||
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
|
||||
#if LOAD_VEC_B == 8
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||
const u16vec2 row_idx = row_ids[loadc_b + l];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
#else
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
@@ -783,7 +818,7 @@ void main() {
|
||||
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
||||
#elif LOAD_VEC_B == 4
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||
const u16vec2 row_idx = row_ids[loadc_b + l];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
#else
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
@@ -802,7 +837,7 @@ void main() {
|
||||
#else
|
||||
const uint row_i = ic * BN + loadc_b + l;
|
||||
if (row_i < _ne1 && block + loadr_b < end_k) {
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[loadc_b + l];
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
||||
} else {
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
||||
@@ -873,7 +908,7 @@ void main() {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
|
||||
if (dr + cm_row * TM + store_r < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
@@ -923,7 +958,7 @@ void main() {
|
||||
const uint row_i = dc_warp + cc;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
#ifdef MUL_MAT_ID
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#endif
|
||||
|
||||
#include "types.comp"
|
||||
#include "utils.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
@@ -92,14 +93,15 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
|
||||
shared u16vec4 row_ids[4096];
|
||||
shared u16vec4 row_ids[BN];
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
B_TYPE b[];
|
||||
};
|
||||
|
||||
uint _ne1;
|
||||
shared uint _ne1_sh;
|
||||
layout (constant_id = 5) const uint subgroup_size = 32;
|
||||
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
|
||||
|
||||
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
@@ -109,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
|
||||
return B_TYPE(0.0);
|
||||
}
|
||||
|
||||
const u16vec4 row_idx = row_ids[row_i];
|
||||
const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
|
||||
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
|
||||
|
||||
return ret;
|
||||
@@ -121,13 +123,74 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
|
||||
uint dc = ic * BN + c;
|
||||
|
||||
if (dr < p.M && dc < _ne1) {
|
||||
uint row_i = dc;
|
||||
uint row_i = c;
|
||||
const u16vec4 row_idx = row_ids[row_i];
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
@@ -157,45 +220,12 @@ void main() {
|
||||
const uint ic = gl_WorkGroupID.y;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
// Spread the search across all elements in the first subgroup
|
||||
if (gl_SubgroupID == 0) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
_ne1 = _ne1_sh;
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
#endif
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_nonuniform_qualifier : enable
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#if ADD_RMS
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "rte.comp"
|
||||
#include "types.comp"
|
||||
@@ -14,11 +18,18 @@ layout (push_constant) uniform parameter2
|
||||
uint ne20; uint ne21; uint ne22; uint ne23;
|
||||
|
||||
// strides for srcs+dst
|
||||
uint nb[8][4];
|
||||
uint nb[12][4];
|
||||
|
||||
uint rms_partials;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
|
||||
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
|
||||
// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
|
||||
// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
|
||||
// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
|
||||
layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
|
||||
layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
|
||||
|
||||
layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
|
||||
|
||||
layout(constant_id = 0) const uint num_srcs = 2;
|
||||
|
||||
@@ -42,14 +53,22 @@ const uint num_threads = 256;
|
||||
|
||||
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#if ADD_RMS
|
||||
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
|
||||
shared FLOAT_TYPE sumsh[num_threads];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
uint idx = get_idx();
|
||||
uint orig_idx = idx;
|
||||
|
||||
uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
|
||||
|
||||
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
|
||||
const uint num_iter = 2;
|
||||
|
||||
FLOAT_TYPE sum_sq = 0;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
if (idx >= ne) {
|
||||
continue;
|
||||
@@ -61,8 +80,32 @@ void main() {
|
||||
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
|
||||
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
|
||||
}
|
||||
sum_sq += sum*sum;
|
||||
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
||||
#if ADD_RMS
|
||||
if (p.rms_partials != 0) {
|
||||
// reduce the sum within each subgroup, then across subgroups
|
||||
const uint NumSubgroups = num_threads / gl_SubgroupSize;
|
||||
sum_sq = subgroupAdd(sum_sq);
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
|
||||
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
|
||||
sum_sq += sumsh[gl_SubgroupID + s];
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
||||
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
void rms_norm(uint num_iters) {
|
||||
const uint ncols = p.ne00;
|
||||
const uint nrows = gl_NumWorkGroups.x;
|
||||
const uint nchannels = gl_NumWorkGroups.y;
|
||||
@@ -30,38 +30,76 @@ void main() {
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
|
||||
sum[tid] += xi * xi;
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
FLOAT_TYPE xi = FLOAT_TYPE(0);
|
||||
if (col < ncols) {
|
||||
xi = FLOAT_TYPE(data_a[a_offset + col]);
|
||||
}
|
||||
sum += xi * xi;
|
||||
}
|
||||
|
||||
sumsh[tid] = sum;
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
sum[tid] += sum[tid + s];
|
||||
sum += sumsh[tid + s];
|
||||
sumsh[tid] = sum;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
sum = sumsh[0];
|
||||
|
||||
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||
|
||||
if (do_multiply) {
|
||||
if (ncols > p.ne10) {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
// instantiate the rms_norm function for several different
|
||||
// dimensions, to allow loop unrolling
|
||||
uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
if (num_blocks > 32) {
|
||||
rms_norm(num_blocks);
|
||||
} else if (num_blocks > 16) {
|
||||
rms_norm(32);
|
||||
} else if (num_blocks > 8) {
|
||||
rms_norm(16);
|
||||
} else if (num_blocks > 4) {
|
||||
rms_norm(8);
|
||||
} else if (num_blocks == 4) {
|
||||
rms_norm(4);
|
||||
} else if (num_blocks == 3) {
|
||||
rms_norm(3);
|
||||
} else if (num_blocks == 2) {
|
||||
rms_norm(2);
|
||||
} else if (num_blocks == 1) {
|
||||
rms_norm(1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_binary_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
|
||||
#define BLOCK_SIZE 128
|
||||
|
||||
layout (constant_id = 1) const bool do_multiply = false;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};
|
||||
|
||||
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint ncols = p.ne00;
|
||||
const uint nrows = gl_NumWorkGroups.x;
|
||||
const uint nchannels = gl_NumWorkGroups.y;
|
||||
|
||||
const uint row = 0;
|
||||
const uint channel = gl_WorkGroupID.y;
|
||||
const uint samp = gl_WorkGroupID.z;
|
||||
// The work is split across multiple workgroups in the x dimension. Each invocation
|
||||
// processes one element
|
||||
const uint tid = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint stride_row = p.nb01;
|
||||
const uint stride_channel = p.nb02;
|
||||
const uint stride_sample = p.nb03;
|
||||
|
||||
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
uint32_t num_partials = p.param3;
|
||||
for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {
|
||||
sum += partial_sums[i];
|
||||
}
|
||||
sum = subgroupAdd(sum);
|
||||
|
||||
uint col = tid;
|
||||
if (col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||
|
||||
if (do_multiply) {
|
||||
if (ncols > p.ne10) {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
||||
} else {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||
}
|
||||
} else {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
@@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint n_cols;
|
||||
uint ne01, ne02;
|
||||
uint nb01, nb02, nb03;
|
||||
uint nb11, nb12, nb13;
|
||||
float weight;
|
||||
uint misalign_offsets;
|
||||
uint ne0_12mp, ne0_12L;
|
||||
uint ne0_1mp, ne0_1L;
|
||||
} p;
|
||||
|
||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||
|
||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||
uint fastdiv(uint n, uint mp, uint L) {
|
||||
uint msbs, lsbs;
|
||||
// msbs = mulhi(n, mp)
|
||||
umulExtended(n, mp, msbs, lsbs);
|
||||
return (msbs + n) >> L;
|
||||
}
|
||||
|
||||
|
||||
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint col = gl_LocalInvocationID.x;
|
||||
const float weight = p.weight;
|
||||
|
||||
tmp[col] = FLOAT_TYPE(0.0f);
|
||||
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
|
||||
const uint i03_offset = i03 * p.ne01*p.ne02;
|
||||
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
|
||||
const uint i01 = row - i03_offset - i02*p.ne01;
|
||||
|
||||
for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
|
||||
tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
|
||||
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
|
||||
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
|
||||
|
||||
tmp[col] = FLOAT_TYPE(0.0);
|
||||
|
||||
for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
|
||||
tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);
|
||||
}
|
||||
|
||||
barrier();
|
||||
@@ -32,6 +65,6 @@ void main() {
|
||||
}
|
||||
|
||||
if (col == 0) {
|
||||
data_d[row] = D_TYPE(tmp[0]);
|
||||
data_d[dst_idx] = D_TYPE(tmp[0] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,6 +68,12 @@ const std::vector<std::string> type_names = {
|
||||
"bf16",
|
||||
};
|
||||
|
||||
enum MatMulIdType {
|
||||
NONE,
|
||||
DEFAULT,
|
||||
SUBGROUP,
|
||||
};
|
||||
|
||||
namespace {
|
||||
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
#ifdef _WIN32
|
||||
@@ -293,7 +299,7 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
|
||||
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
|
||||
}
|
||||
|
||||
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
|
||||
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
|
||||
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
|
||||
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
||||
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
||||
@@ -303,9 +309,13 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
};
|
||||
std::string shader_name = "matmul";
|
||||
|
||||
if (matmul_id) {
|
||||
if (matmul_id_type == MatMulIdType::DEFAULT) {
|
||||
base_dict["MUL_MAT_ID"] = "1";
|
||||
shader_name = "matmul_id";
|
||||
} else if (matmul_id_type == MatMulIdType::SUBGROUP) {
|
||||
base_dict["MUL_MAT_ID"] = "1";
|
||||
base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
|
||||
shader_name = "matmul_id_subgroup";
|
||||
}
|
||||
|
||||
if (fp16) {
|
||||
@@ -389,7 +399,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
||||
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
#endif
|
||||
@@ -401,26 +411,28 @@ void process_shaders() {
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
|
||||
|
||||
// matmul
|
||||
for (const auto& matmul_id : {false, true}) {
|
||||
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
|
||||
// No coopmats
|
||||
// fp32
|
||||
matmul_shaders(false, matmul_id, false, false, false);
|
||||
matmul_shaders(false, matmul_id_type, false, false, false);
|
||||
|
||||
// fp16, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id, false, false, false);
|
||||
matmul_shaders(true, matmul_id, false, false, true);
|
||||
matmul_shaders(true, matmul_id_type, false, false, false);
|
||||
matmul_shaders(true, matmul_id_type, false, false, true);
|
||||
|
||||
if (matmul_id_type != MatMulIdType::DEFAULT) {
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
// Coopmat, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id, true, false, false);
|
||||
matmul_shaders(true, matmul_id, true, false, true);
|
||||
// Coopmat, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id_type, true, false, false);
|
||||
matmul_shaders(true, matmul_id_type, true, false, true);
|
||||
#endif
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
// Coopmat2, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id, false, true, false);
|
||||
matmul_shaders(true, matmul_id, false, true, true);
|
||||
// Coopmat2, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id_type, false, true, false);
|
||||
matmul_shaders(true, matmul_id_type, false, true, true);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// flash attention
|
||||
@@ -503,6 +515,7 @@ void process_shaders() {
|
||||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -538,13 +551,15 @@ void process_shaders() {
|
||||
s += std::string(dst_f16 ? "_f16" : "_f32");
|
||||
return s;
|
||||
};
|
||||
for (std::string op : {"add", "sub", "mul", "div"}) {
|
||||
for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
|
||||
for (auto src0_f16 : {false, true}) {
|
||||
for (auto src1_f16 : {false, true}) {
|
||||
for (auto dst_f16 : {false, true}) {
|
||||
for (auto rte : {false, true}) {
|
||||
auto source = op == "add_rms" ? std::string("add") : op;
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
|
||||
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
auto add_rms = op == "add_rms" ? "1" : "0";
|
||||
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -687,7 +702,8 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
@@ -745,7 +761,7 @@ void write_output_files() {
|
||||
}
|
||||
|
||||
std::string suffixes[2] = {"_f32", "_f16"};
|
||||
for (const char *op : {"add", "sub", "mul", "div"}) {
|
||||
for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
|
||||
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
|
||||
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
|
||||
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
|
||||
|
||||
@@ -20,8 +20,8 @@ add_custom_command(
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
|
||||
COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
|
||||
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
|
||||
--input "${SHADER_DIR}"
|
||||
--output "${SHADER_HEADER}"
|
||||
--input_dir "${SHADER_DIR}"
|
||||
--output_file "${SHADER_HEADER}"
|
||||
DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
@@ -118,13 +118,11 @@ struct webgpu_context_struct {
|
||||
|
||||
std::recursive_mutex mutex;
|
||||
|
||||
bool device_init = false;
|
||||
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
webgpu_buf_pool set_rows_error_buf_pool;
|
||||
|
||||
wgpu::ComputePipeline memset_pipeline;
|
||||
wgpu::ComputePipeline mul_mat_pipeline;
|
||||
wgpu::ComputePipeline mul_mat_pipeline[30][2];
|
||||
wgpu::ComputePipeline set_rows_pipeline;
|
||||
wgpu::ComputePipeline cpy_pipeline;
|
||||
|
||||
@@ -238,7 +236,7 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
|
||||
}
|
||||
}),
|
||||
UINT64_MAX);
|
||||
@@ -278,7 +276,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
|
||||
}
|
||||
// Free the staged buffers
|
||||
ctx->param_buf_pool.free_bufs(staged_param_bufs);
|
||||
@@ -294,7 +292,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
|
||||
} else {
|
||||
const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
|
||||
if (*error_data) {
|
||||
@@ -331,6 +329,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
|
||||
// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
|
||||
// debug statements in the shader, and then call this function after encoding the commands and submitting them.
|
||||
static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
|
||||
ggml_backend_webgpu_submit_queue(ctx);
|
||||
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
@@ -421,15 +420,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor * tensor) {
|
||||
return webgpu_tensor_offset(tensor) + tensor->view_offs;
|
||||
}
|
||||
|
||||
static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor * tensor) {
|
||||
ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
|
||||
return ctx->buffer;
|
||||
}
|
||||
|
||||
/** End WebGPU Actions */
|
||||
|
||||
/** GGML Backend Interface */
|
||||
@@ -447,19 +437,36 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
|
||||
return webgpu_tensor_offset(tensor) + tensor->view_offs;
|
||||
}
|
||||
|
||||
static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
|
||||
ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
|
||||
return ctx->buffer;
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
|
||||
size_t offset = ggml_webgpu_tensor_offset(t);
|
||||
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
|
||||
size_t offset = ggml_webgpu_tensor_offset(t);
|
||||
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
|
||||
return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
|
||||
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
|
||||
// assumes power of 2 offset alignment
|
||||
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
// align to minimum offset alignment
|
||||
src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
|
||||
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
std::vector<uint32_t> params = { ne,
|
||||
(uint32_t) (src_misalignment / ggml_type_size(src->type)),
|
||||
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[0] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
@@ -477,15 +484,13 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(src),
|
||||
.offset = src_offset,
|
||||
.size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
|
||||
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(dst),
|
||||
.offset = dst_offset,
|
||||
.size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
|
||||
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
||||
@@ -504,21 +509,9 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
||||
error_bufs.host_buf.Unmap();
|
||||
}
|
||||
|
||||
size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
|
||||
// assumes power of 2 offset alignment
|
||||
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
// align to minimum offset alignment
|
||||
src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx);
|
||||
size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
|
||||
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
|
||||
std::vector<uint32_t> params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)),
|
||||
(uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
|
||||
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
|
||||
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
||||
@@ -540,18 +533,18 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(src),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(src),
|
||||
.size = ggml_nbytes(src) },
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(idx),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(idx),
|
||||
.size = ggml_nbytes(idx) },
|
||||
.buffer = ggml_webgpu_tensor_buf(idx),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, idx),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, idx) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(dst),
|
||||
.size = ggml_nbytes(dst) },
|
||||
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
||||
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
||||
@@ -565,15 +558,18 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
||||
|
||||
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) dst->ne[1], // number of rows in result (M)
|
||||
(uint32_t) dst->ne[0], // number of columns in result (N)
|
||||
(uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 1
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 1
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 2
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 2
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 3
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 3
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
|
||||
(uint32_t) src0->ne[2], // batch size in dimension 2
|
||||
(uint32_t) src0->ne[3], // batch size in dimension 3
|
||||
(uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
|
||||
@@ -582,22 +578,22 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(src0),
|
||||
.size = ggml_nbytes(src0) },
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(src1),
|
||||
.size = ggml_nbytes(src1) },
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(dst),
|
||||
.size = ggml_nbytes(dst) }
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
||||
};
|
||||
|
||||
uint32_t wg_x =
|
||||
(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x);
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
|
||||
}
|
||||
|
||||
// Returns true if node has enqueued work into the queue, false otherwise
|
||||
@@ -827,7 +823,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
|
||||
wgpu::Buffer buf;
|
||||
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
|
||||
buf,
|
||||
size,
|
||||
(size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
|
||||
"allocated_buffer");
|
||||
|
||||
@@ -907,7 +903,94 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f32_f32,
|
||||
"mul_mat_f32_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
|
||||
wgsl_mul_mat_f16_f16,
|
||||
"mul_mat_f16_f16");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f16_f32,
|
||||
"mul_mat_f16_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_0_f32,
|
||||
"mul_mat_q4_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_1_f32,
|
||||
"mul_mat_q4_1_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_0_f32,
|
||||
"mul_mat_q5_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_1_f32,
|
||||
"mul_mat_q5_1_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q8_0_f32,
|
||||
"mul_mat_q8_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q2_k_f32,
|
||||
"mul_mat_q2_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q3_k_f32,
|
||||
"mul_mat_q3_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_k_f32,
|
||||
"mul_mat_q4_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_k_f32,
|
||||
"mul_mat_q5_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q6_k_f32,
|
||||
"mul_mat_q6_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_xxs_f32,
|
||||
"mul_mat_iq2_xxs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_xs_f32,
|
||||
"mul_mat_iq2_xs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_s_f32,
|
||||
"mul_mat_iq2_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq3_xxs_f32,
|
||||
"mul_mat_iq3_xxs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq3_s_f32,
|
||||
"mul_mat_iq3_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq1_s_f32,
|
||||
"mul_mat_iq1_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq1_m_f32,
|
||||
"mul_mat_iq1_m_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_nl_f32,
|
||||
"mul_mat_iq4_nl_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_xs_f32,
|
||||
"mul_mat_iq4_xs_f32");
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||
@@ -933,79 +1016,6 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
|
||||
ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
||||
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
|
||||
|
||||
// Multiple threads may try to initialize the device
|
||||
std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
|
||||
if (!webgpu_ctx->device_init) {
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
|
||||
wgpu::FeatureName::ImplicitDeviceSynchronization };
|
||||
wgpu::DeviceDescriptor dev_desc;
|
||||
dev_desc.requiredLimits = &webgpu_ctx->limits;
|
||||
dev_desc.requiredFeatures = required_features.data();
|
||||
dev_desc.requiredFeatureCount = required_features.size();
|
||||
dev_desc.SetDeviceLostCallback(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
|
||||
});
|
||||
dev_desc.SetUncapturedErrorCallback(
|
||||
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
|
||||
});
|
||||
webgpu_ctx->instance.WaitAny(
|
||||
webgpu_ctx->adapter.RequestDevice(
|
||||
&dev_desc,
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
|
||||
return;
|
||||
}
|
||||
webgpu_ctx->device = std::move(device);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
GGML_ASSERT(webgpu_ctx->device != nullptr);
|
||||
|
||||
// Initialize (compute) queue
|
||||
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
|
||||
|
||||
// Create buffer pool for shader parameters
|
||||
webgpu_ctx->param_buf_pool.init(webgpu_ctx->device,
|
||||
WEBGPU_NUM_PARAM_BUFS,
|
||||
WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
||||
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device,
|
||||
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
||||
|
||||
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_set_rows_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
// Initialize debug buffers
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device,
|
||||
webgpu_ctx->debug_host_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
|
||||
"debug_host_buf");
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device,
|
||||
webgpu_ctx->debug_dev_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
|
||||
"debug_dev_buf");
|
||||
#endif
|
||||
webgpu_ctx->device_init = true;
|
||||
}
|
||||
|
||||
static ggml_backend_webgpu_context backend_ctx;
|
||||
backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
|
||||
backend_ctx.webgpu_ctx = webgpu_ctx;
|
||||
@@ -1053,10 +1063,45 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
return true;
|
||||
case GGML_OP_CPY | GGML_OP_SET_ROWS:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_MUL_MAT:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
{
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_TYPE_F32:
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -1123,20 +1168,87 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
wgpu::AdapterInfo info{};
|
||||
ctx->adapter.GetInfo(&info);
|
||||
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
|
||||
wgpu::FeatureName::ImplicitDeviceSynchronization };
|
||||
wgpu::DeviceDescriptor dev_desc;
|
||||
dev_desc.requiredLimits = &ctx->limits;
|
||||
dev_desc.requiredFeatures = required_features.data();
|
||||
dev_desc.requiredFeatureCount = required_features.size();
|
||||
dev_desc.SetDeviceLostCallback(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
|
||||
});
|
||||
dev_desc.SetUncapturedErrorCallback(
|
||||
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
|
||||
});
|
||||
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
||||
&dev_desc,
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
||||
return;
|
||||
}
|
||||
ctx->device = std::move(device);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
GGML_ASSERT(ctx->device != nullptr);
|
||||
|
||||
// Initialize (compute) queue
|
||||
ctx->queue = ctx->device.GetQueue();
|
||||
|
||||
// Create buffer pool for shader parameters
|
||||
ctx->param_buf_pool.init(ctx->device,
|
||||
WEBGPU_NUM_PARAM_BUFS,
|
||||
WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
||||
ctx->set_rows_error_buf_pool.init(ctx->device,
|
||||
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
||||
|
||||
ggml_webgpu_init_memset_pipeline(ctx);
|
||||
ggml_webgpu_init_mul_mat_pipeline(ctx);
|
||||
ggml_webgpu_init_set_rows_pipeline(ctx);
|
||||
ggml_webgpu_init_cpy_pipeline(ctx);
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
// Initialize debug buffers
|
||||
ggml_webgpu_create_buffer(ctx->device,
|
||||
ctx->debug_host_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
|
||||
"debug_host_buf");
|
||||
ggml_webgpu_create_buffer(ctx->device,
|
||||
ctx->debug_dev_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
|
||||
"debug_dev_buf");
|
||||
#endif
|
||||
|
||||
static ggml_backend_webgpu_device_context device_ctx;
|
||||
device_ctx.webgpu_ctx = ctx;
|
||||
device_ctx.device_name = GGML_WEBGPU_NAME;
|
||||
device_ctx.device_desc = std::string(info.description.data);
|
||||
device_ctx.device_desc = info.description;
|
||||
|
||||
GGML_LOG_INFO(
|
||||
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
|
||||
"device_desc: %s\n",
|
||||
info.vendorID,
|
||||
info.vendor.data,
|
||||
info.architecture.data,
|
||||
std::string(info.vendor).c_str(),
|
||||
std::string(info.architecture).c_str(),
|
||||
info.deviceID,
|
||||
info.device.data,
|
||||
info.description.data);
|
||||
std::string(info.device).c_str(),
|
||||
std::string(info.description).c_str());
|
||||
|
||||
// See GGML Backend Device Interface section
|
||||
static ggml_backend_device device = {
|
||||
|
||||
@@ -1,35 +1,85 @@
|
||||
import os
|
||||
import re
|
||||
import ast
|
||||
import argparse
|
||||
|
||||
|
||||
def escape_triple_quotes(wgsl):
|
||||
# Simple defense in case of embedded """
|
||||
return wgsl.replace('"""', '\\"""')
|
||||
def extract_block(text, name):
|
||||
pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError(f"Missing block: {name}")
|
||||
return match.group(1).strip()
|
||||
|
||||
|
||||
def to_cpp_string_literal(varname, content):
|
||||
return f'const char* wgsl_{varname} = R"({content})";\n'
|
||||
def parse_decls(decls_text):
|
||||
decls = {}
|
||||
for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
|
||||
decls[name.strip()] = code.strip()
|
||||
return decls
|
||||
|
||||
|
||||
def replace_placeholders(shader_text, replacements):
|
||||
for key, val in replacements.items():
|
||||
# Match {{KEY}} literally, where KEY is escaped
|
||||
pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
|
||||
shader_text = re.sub(pattern, str(val), shader_text)
|
||||
return shader_text
|
||||
|
||||
|
||||
def write_shader(shader_name, shader_code, output_dir, outfile):
|
||||
if output_dir:
|
||||
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
|
||||
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
|
||||
f_out.write(shader_code)
|
||||
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
|
||||
|
||||
|
||||
def generate_variants(shader_path, output_dir, outfile):
|
||||
shader_base_name = shader_path.split("/")[-1].split(".")[0]
|
||||
|
||||
with open(shader_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
try:
|
||||
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
|
||||
except ValueError:
|
||||
write_shader(shader_base_name, text, output_dir, outfile)
|
||||
else:
|
||||
decls_map = parse_decls(extract_block(text, "DECLS"))
|
||||
shader_template = extract_block(text, "SHADER")
|
||||
|
||||
for variant in variants:
|
||||
decls = variant["DECLS"]
|
||||
decls_code = ""
|
||||
for key in decls:
|
||||
if key not in decls_map:
|
||||
raise ValueError(f"DECLS key '{key}' not found.")
|
||||
decls_code += decls_map[key] + "\n\n"
|
||||
|
||||
shader_variant = replace_placeholders(shader_template, variant["REPLS"])
|
||||
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
|
||||
|
||||
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
|
||||
write_shader(output_name, final_shader, output_dir, outfile)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', required=True)
|
||||
parser.add_argument('--output', required=True)
|
||||
parser.add_argument("--input_dir", required=True)
|
||||
parser.add_argument("--output_file", required=True)
|
||||
parser.add_argument("--output_dir")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.output, 'w', encoding='utf-8') as out:
|
||||
out.write("// Auto-generated shader embedding \n\n")
|
||||
for fname in sorted(os.listdir(args.input)):
|
||||
if not fname.endswith('.wgsl'):
|
||||
continue
|
||||
shader_path = os.path.join(args.input, fname)
|
||||
varname = os.path.splitext(fname)[0]
|
||||
with open(shader_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
content = escape_triple_quotes(content)
|
||||
out.write(to_cpp_string_literal(varname, content))
|
||||
out.write('\n')
|
||||
if args.output_dir:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
with open(args.output_file, "w", encoding="utf-8") as out:
|
||||
out.write("// Auto-generated shader embedding\n\n")
|
||||
for fname in sorted(os.listdir(args.input_dir)):
|
||||
if fname.endswith(".wgsl"):
|
||||
generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -19,20 +19,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let start = params.offset;
|
||||
let end = params.offset + params.size;
|
||||
|
||||
for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) {
|
||||
for (var j: u32 = 0u; j < bytes_per_thread; j += 4) {
|
||||
let byte_index = start + i + j;
|
||||
if (byte_index + 4u <= end) {
|
||||
output_buffer[(byte_index >> 2u)] = params.value;
|
||||
if (byte_index + 4 <= end) {
|
||||
output_buffer[byte_index >> 2] = params.value;
|
||||
} else {
|
||||
// Handle tail (unaligned)
|
||||
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let idx = byte_index + k;
|
||||
if (idx < end) {
|
||||
let word_idx = idx >> 2u;
|
||||
let byte_offset = (idx & 3u) * 8u;
|
||||
let mask = ~(0xffu << byte_offset);
|
||||
let word_idx = idx >> 2;
|
||||
let bit_offset = (idx & 3) * 8u;
|
||||
let mask = ~(0xffu << bit_offset);
|
||||
let existing = output_buffer[word_idx];
|
||||
output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset);
|
||||
output_buffer[word_idx] = (existing & mask) | (params.value & (0xffu << bit_offset));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,56 +0,0 @@
|
||||
struct MulMatParams {
|
||||
m: u32,
|
||||
n: u32,
|
||||
k: u32,
|
||||
// all strides are in elements
|
||||
stride_01: u32,
|
||||
stride_11: u32,
|
||||
stride_02: u32,
|
||||
stride_12: u32,
|
||||
stride_03: u32,
|
||||
stride_13: u32,
|
||||
|
||||
bs02: u32,
|
||||
bs03: u32,
|
||||
broadcast2: u32,
|
||||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // N rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
@compute @workgroup_size(64)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
|
||||
if (global_id.x >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
let dst2_stride = params.m * params.n;
|
||||
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
|
||||
|
||||
let dst3_idx = global_id.x / dst3_stride;
|
||||
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
|
||||
let src13_idx = dst3_idx; // src1 is not broadcast
|
||||
let dst3_rem = global_id.x % dst3_stride;
|
||||
|
||||
let dst2_idx = dst3_rem / dst2_stride;
|
||||
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
|
||||
let src12_idx = dst2_idx; // src1 is not broadcast
|
||||
|
||||
let dst2_rem = dst3_rem % dst2_stride;
|
||||
|
||||
let row = dst2_rem / params.n; // output row
|
||||
let col = dst2_rem % params.n; // output column
|
||||
|
||||
var sum = 0.0;
|
||||
for (var i: u32 = 0u; i < params.k; i = i + 1u) {
|
||||
let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i;
|
||||
let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i;
|
||||
sum = sum + src0[src0_idx] * src1[src1_idx];
|
||||
}
|
||||
dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
|
||||
}
|
||||
+54
-2
@@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"IM2COL",
|
||||
"IM2COL_BACK",
|
||||
"CONV_2D",
|
||||
"CONV_3D",
|
||||
"CONV_2D_DW",
|
||||
"CONV_TRANSPOSE_2D",
|
||||
"POOL_1D",
|
||||
@@ -1017,7 +1018,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1077,6 +1078,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"im2col(x)",
|
||||
"im2col_back(x)",
|
||||
"conv_2d(x)",
|
||||
"conv_3d(x)",
|
||||
"conv_2d_dw(x)",
|
||||
"conv_transpose_2d(x)",
|
||||
"pool_1d(x)",
|
||||
@@ -1119,7 +1121,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"glu(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -4480,6 +4482,56 @@ struct ggml_tensor * ggml_conv_2d_direct(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_conv_3d
|
||||
|
||||
struct ggml_tensor * ggml_conv_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int s0,
|
||||
int s1,
|
||||
int s2,
|
||||
int p0,
|
||||
int p1,
|
||||
int p2,
|
||||
int d0,
|
||||
int d1,
|
||||
int d2,
|
||||
int c,
|
||||
int n,
|
||||
int oc) {
|
||||
|
||||
GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
|
||||
GGML_ASSERT(b->ne[3] == (int64_t) c * n);
|
||||
|
||||
int64_t ne[4];
|
||||
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
|
||||
ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
|
||||
ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
|
||||
ne[3] = (int64_t) oc * n;
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, s0);
|
||||
ggml_set_op_params_i32(result, 1, s1);
|
||||
ggml_set_op_params_i32(result, 2, s2);
|
||||
ggml_set_op_params_i32(result, 3, p0);
|
||||
ggml_set_op_params_i32(result, 4, p1);
|
||||
ggml_set_op_params_i32(result, 5, p2);
|
||||
ggml_set_op_params_i32(result, 6, d0);
|
||||
ggml_set_op_params_i32(result, 7, d1);
|
||||
ggml_set_op_params_i32(result, 8, d2);
|
||||
ggml_set_op_params_i32(result, 9, c);
|
||||
ggml_set_op_params_i32(result, 10, n);
|
||||
ggml_set_op_params_i32(result, 11, oc);
|
||||
|
||||
result->op = GGML_OP_CONV_3D;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_conv_transpose_2d_p0
|
||||
|
||||
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
|
||||
|
||||
@@ -385,6 +385,7 @@ class MODEL_ARCH(IntEnum):
|
||||
DREAM = auto()
|
||||
SMALLTHINKER = auto()
|
||||
LLADA = auto()
|
||||
SEED_OSS = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
@@ -717,6 +718,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.DREAM: "dream",
|
||||
MODEL_ARCH.SMALLTHINKER: "smallthinker",
|
||||
MODEL_ARCH.LLADA: "llada",
|
||||
MODEL_ARCH.SEED_OSS: "seed_oss",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
@@ -1973,6 +1975,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.SEED_OSS: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
],
|
||||
MODEL_ARCH.OLMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
||||
@@ -37,7 +37,6 @@ LLAMA_BENCH_DB_TYPES = [
|
||||
"TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT",
|
||||
"REAL",
|
||||
"INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "REAL", "REAL",
|
||||
]
|
||||
|
||||
@@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_DREAM, "dream" },
|
||||
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
|
||||
{ LLM_ARCH_LLADA, "llada" },
|
||||
{ LLM_ARCH_SEED_OSS, "seed_oss" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -2068,6 +2069,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_SEED_OSS,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
||||
@@ -97,6 +97,7 @@ enum llm_arch {
|
||||
LLM_ARCH_DREAM,
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
LLM_ARCH_LLADA,
|
||||
LLM_ARCH_SEED_OSS,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
+13
-2
@@ -16,10 +16,10 @@
|
||||
static std::string trim(const std::string & str) {
|
||||
size_t start = 0;
|
||||
size_t end = str.size();
|
||||
while (start < end && isspace(str[start])) {
|
||||
while (start < end && isspace(static_cast<unsigned char>(str[start]))) {
|
||||
start += 1;
|
||||
}
|
||||
while (end > start && isspace(str[end - 1])) {
|
||||
while (end > start && isspace(static_cast<unsigned char>(str[end - 1]))) {
|
||||
end -= 1;
|
||||
}
|
||||
return str.substr(start, end - start);
|
||||
@@ -69,6 +69,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
|
||||
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
|
||||
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
|
||||
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
@@ -201,6 +202,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
|
||||
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
|
||||
return LLM_CHAT_TEMPLATE_KIMI_K2;
|
||||
} else if (tmpl_contains("<seed:bos>")) {
|
||||
return LLM_CHAT_TEMPLATE_SEED_OSS;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
@@ -752,6 +755,14 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|im_assistant|>assistant<|im_middle|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) {
|
||||
for (auto message: chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<seed:bos>" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << "<seed:eos>";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<seed:bos>assistant\n";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
||||
@@ -49,6 +49,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_OPENAI_MOE,
|
||||
LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
|
||||
LLM_CHAT_TEMPLATE_KIMI_K2,
|
||||
LLM_CHAT_TEMPLATE_SEED_OSS,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
bool llama_hparams::has_kv(uint32_t il) const {
|
||||
if (n_layer_kv_from_start >= 0) {
|
||||
if (il < (uint32_t) n_layer_kv_from_start) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// by default, all layers have kv
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_layer_kv() const {
|
||||
uint32_t res = 0;
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (has_kv(il)) {
|
||||
res++;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ struct llama_hparams {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_embd_features = 0;
|
||||
uint32_t n_layer;
|
||||
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
|
||||
uint32_t n_rot;
|
||||
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
||||
@@ -221,6 +222,11 @@ struct llama_hparams {
|
||||
uint32_t n_pos_per_embd() const;
|
||||
|
||||
bool is_swa(uint32_t il) const;
|
||||
|
||||
bool has_kv(uint32_t il) const;
|
||||
|
||||
// number of layers for which has_kv() returns true
|
||||
uint32_t n_layer_kv() const;
|
||||
};
|
||||
|
||||
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
||||
|
||||
@@ -22,9 +22,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
|
||||
llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
uint32_t n_pad,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
|
||||
|
||||
// chain filters
|
||||
const layer_filter_cb filter_base = [&](int32_t il) {
|
||||
if (filter && !filter(il)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !model.hparams.is_swa(il);
|
||||
};
|
||||
|
||||
const layer_filter_cb filter_swa = [&](int32_t il) {
|
||||
if (filter && !filter(il)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return model.hparams.is_swa(il);
|
||||
};
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
@@ -41,16 +58,16 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
model, type_k, type_v,
|
||||
v_trans, offload, unified, size_base, n_seq_max, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
model, type_k, type_v,
|
||||
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
|
||||
}
|
||||
|
||||
void llama_kv_cache_iswa::clear(bool data) {
|
||||
|
||||
@@ -20,11 +20,13 @@ public:
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool ,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad);
|
||||
uint32_t n_pad,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse);
|
||||
|
||||
~llama_kv_cache_iswa() = default;
|
||||
|
||||
|
||||
+35
-33
@@ -17,32 +17,25 @@
|
||||
//
|
||||
|
||||
llama_kv_cache::llama_kv_cache(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type) :
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse) :
|
||||
model(model), hparams(model.hparams), v_trans(v_trans),
|
||||
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
||||
|
||||
GGML_ASSERT(kv_size % n_pad == 0);
|
||||
|
||||
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
|
||||
auto n_layer_cache = hparams.n_layer;
|
||||
if (model.arch == LLM_ARCH_GEMMA3N) {
|
||||
n_layer_cache = 20;
|
||||
}
|
||||
if (model.arch == LLM_ARCH_GLM4_MOE) {
|
||||
// GLM-4.5: Only process up to last layer, skip final NextN layer
|
||||
n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
}
|
||||
const uint32_t n_layer_kv = hparams.n_layer_kv();
|
||||
|
||||
// create a context for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
@@ -50,7 +43,7 @@ llama_kv_cache::llama_kv_cache(
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
|
||||
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
|
||||
__func__, hparams.n_embd_v_gqa_max());
|
||||
}
|
||||
|
||||
for (uint32_t il = 0; il < n_layer_cache; il++) {
|
||||
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
||||
if (!hparams.has_kv(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (filter && !filter(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
|
||||
layers.push_back({ il, k, v, k_stream, v_stream, });
|
||||
}
|
||||
|
||||
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
|
||||
if (model.arch == LLM_ARCH_GEMMA3N) {
|
||||
LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
|
||||
if (reuse) {
|
||||
LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
|
||||
|
||||
for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
|
||||
if (filter && !filter(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
|
||||
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
||||
const int32_t il_reuse = reuse(il);
|
||||
|
||||
if (il_reuse < 0) {
|
||||
LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
|
||||
if (filter && !filter(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
|
||||
continue;
|
||||
}
|
||||
|
||||
GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
|
||||
|
||||
map_layer_ids[il] = map_layer_ids[il_reuse];
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
|
||||
LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+13
-15
@@ -21,9 +21,6 @@ class llama_kv_cache : public llama_memory_i {
|
||||
public:
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
struct stream_copy_info {
|
||||
bool empty() const {
|
||||
assert(ssrc.size() == sdst.size());
|
||||
@@ -82,18 +79,19 @@ public:
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
|
||||
llama_kv_cache(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type);
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
const layer_filter_cb & filter,
|
||||
const layer_reuse_cb & reuse);
|
||||
|
||||
~llama_kv_cache() = default;
|
||||
|
||||
|
||||
+29
-28
@@ -9,32 +9,29 @@
|
||||
//
|
||||
|
||||
llama_memory_hybrid::llama_memory_hybrid(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
layer_filter_cb && filter_attn,
|
||||
layer_filter_cb && filter_recr) :
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
const layer_filter_cb & filter_attn,
|
||||
const layer_filter_cb & filter_recr) :
|
||||
hparams(model.hparams),
|
||||
mem_attn(new llama_kv_cache(
|
||||
model,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||
: filter_attn,
|
||||
type_k,
|
||||
type_v,
|
||||
v_trans,
|
||||
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
|
||||
n_seq_max,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
swa_type,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||
: filter_attn,
|
||||
nullptr
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
model,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr,
|
||||
type_r,
|
||||
type_s,
|
||||
offload,
|
||||
rs_size,
|
||||
n_seq_max
|
||||
n_seq_max,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr
|
||||
)) {}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
|
||||
+18
-22
@@ -18,31 +18,27 @@
|
||||
|
||||
class llama_memory_hybrid : public llama_memory_i {
|
||||
public:
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
llama_memory_hybrid(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
layer_filter_cb && filter_attn = nullptr,
|
||||
layer_filter_cb && filter_recr = nullptr);
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
const layer_filter_cb & filter_attn = nullptr,
|
||||
const layer_filter_cb & filter_recr = nullptr);
|
||||
|
||||
~llama_memory_hybrid() = default;
|
||||
|
||||
|
||||
@@ -16,13 +16,13 @@
|
||||
//
|
||||
|
||||
llama_memory_recurrent::llama_memory_recurrent(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
||||
const llama_model & model,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max,
|
||||
const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
||||
const int32_t n_layer = hparams.n_layer;
|
||||
|
||||
head = 0;
|
||||
|
||||
@@ -15,18 +15,14 @@
|
||||
// see the implementation of llama_kv_cache_context_i for an example how to do it
|
||||
class llama_memory_recurrent : public llama_memory_i {
|
||||
public:
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
llama_memory_recurrent(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max);
|
||||
const llama_model & model,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max,
|
||||
const layer_filter_cb & filter);
|
||||
|
||||
~llama_memory_recurrent() = default;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "llama.h"
|
||||
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
struct llama_ubatch;
|
||||
|
||||
@@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
|
||||
// general concept of LLM memory
|
||||
// the KV cache is a type of LLM memory, but there can be other types
|
||||
struct llama_memory_i {
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
// this callback is used to specify which layers should reuse memory from other layers
|
||||
// return negative value to indicate that the layer il should not reuse memory
|
||||
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
|
||||
|
||||
virtual ~llama_memory_i() = default;
|
||||
|
||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||
|
||||
+210
-11
@@ -83,6 +83,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_32B: return "32B";
|
||||
case LLM_TYPE_34B: return "34B";
|
||||
case LLM_TYPE_35B: return "35B";
|
||||
case LLM_TYPE_36B: return "36B";
|
||||
case LLM_TYPE_40B: return "40B";
|
||||
case LLM_TYPE_65B: return "65B";
|
||||
case LLM_TYPE_70B: return "70B";
|
||||
@@ -1114,6 +1115,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.set_swa_pattern(5);
|
||||
|
||||
hparams.n_layer_kv_from_start = 20;
|
||||
hparams.rope_freq_base_train_swa = 10000.0f;
|
||||
hparams.rope_freq_scale_train_swa = 1.0f;
|
||||
hparams.f_attention_scale = 1.0f;
|
||||
@@ -1288,6 +1290,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
case 64: type = LLM_TYPE_36B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_OLMOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@@ -1465,12 +1475,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
// Expert gating function (GLM-4.5 uses sigmoid)
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
||||
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
|
||||
}
|
||||
|
||||
// NextN/MTP parameters
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
|
||||
case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
|
||||
@@ -3967,6 +3980,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
{
|
||||
const uint32_t head_dim = hparams.n_embd_head_k;
|
||||
const int64_t n_qo_dim = n_head * head_dim;
|
||||
const int64_t n_kv_dim = n_head_kv * head_dim;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0);
|
||||
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
|
||||
case LLM_ARCH_OLMOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -10478,7 +10528,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
const int64_t n_embd_altup;
|
||||
const int64_t n_altup;
|
||||
const int i_altup_act;
|
||||
const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
|
||||
const int n_layer_sparsity = 10; // number of layers using activation sparsity
|
||||
const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
|
||||
|
||||
@@ -10528,8 +10577,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
|
||||
const bool has_kv = (il < n_layer_kv);
|
||||
|
||||
const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
||||
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
||||
|
||||
@@ -10549,7 +10596,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
|
||||
|
||||
// self-attention
|
||||
if (has_kv) {
|
||||
if (hparams.has_kv(il)) {
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
@@ -10589,7 +10636,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
|
||||
} else {
|
||||
// no KV layers
|
||||
// reuse KV cache of earlier layers
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
@@ -17934,6 +17981,137 @@ struct llm_build_lfm2 : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_seed_oss : public llm_graph_context {
|
||||
llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].attn_post_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool iswa>
|
||||
struct llm_build_smallthinker : public llm_graph_context{
|
||||
llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
|
||||
@@ -18079,12 +18257,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
if (llm_arch_is_recurrent(arch)) {
|
||||
res = new llama_memory_recurrent(
|
||||
*this,
|
||||
nullptr,
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F32,
|
||||
cparams.offload_kqv,
|
||||
std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
cparams.n_seq_max);
|
||||
cparams.n_seq_max,
|
||||
nullptr);
|
||||
} else if (llm_arch_is_hybrid(arch)) {
|
||||
const auto padding = llama_kv_cache::get_padding(cparams);
|
||||
|
||||
@@ -18125,6 +18303,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA3N) {
|
||||
reuse = [&](int32_t il) {
|
||||
if (il >= (int32_t) hparams.n_layer_kv_from_start) {
|
||||
return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
|
||||
}
|
||||
|
||||
return -1;
|
||||
};
|
||||
}
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(hparams.is_swa_any());
|
||||
|
||||
@@ -18139,13 +18329,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
padding);
|
||||
padding,
|
||||
nullptr,
|
||||
reuse);
|
||||
} else {
|
||||
GGML_ASSERT(!hparams.is_swa_any());
|
||||
|
||||
res = new llama_kv_cache(
|
||||
*this,
|
||||
nullptr,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
@@ -18155,7 +18346,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
cparams.n_seq_max,
|
||||
padding,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type);
|
||||
hparams.swa_type,
|
||||
nullptr,
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18472,6 +18665,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
{
|
||||
llm = std::make_unique<llm_build_bailingmoe>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
{
|
||||
llm = std::make_unique<llm_build_seed_oss>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_DOTS1:
|
||||
{
|
||||
llm = std::make_unique<llm_build_dots1>(*this, params);
|
||||
@@ -18530,6 +18727,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
return llm->res->get_gf();
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
@@ -18724,6 +18922,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_SMALLTHINKER:
|
||||
case LLM_ARCH_GLM4_MOE:
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
||||
@@ -76,6 +76,7 @@ enum llm_type {
|
||||
LLM_TYPE_32B,
|
||||
LLM_TYPE_34B,
|
||||
LLM_TYPE_35B,
|
||||
LLM_TYPE_36B,
|
||||
LLM_TYPE_40B,
|
||||
LLM_TYPE_65B,
|
||||
LLM_TYPE_70B,
|
||||
|
||||
+177
-8
@@ -2209,6 +2209,26 @@ struct test_count_equal : public test_case {
|
||||
double max_nmse_err() override {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
std::random_device rd;
|
||||
std::default_random_engine rng(rd());
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_F32) {
|
||||
// initialize with unique values to avoid ties
|
||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||
std::vector<float> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = i;
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
|
||||
}
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_REPEAT
|
||||
@@ -2858,6 +2878,7 @@ struct test_rms_norm_mul_add : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float eps;
|
||||
const bool broadcast;
|
||||
const bool multi_add; // test a sequence of adds feeding into rms_norm
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
@@ -2867,13 +2888,13 @@ struct test_rms_norm_mul_add : public test_case {
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, eps, broadcast);
|
||||
return VARS_TO_STR5(type, ne, eps, broadcast, multi_add);
|
||||
}
|
||||
|
||||
test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 5, 4, 3},
|
||||
float eps = 1e-6f, bool broadcast = false)
|
||||
: type(type), ne(ne), eps(eps), broadcast(broadcast) {}
|
||||
float eps = 1e-6f, bool broadcast = false, bool multi_add = false)
|
||||
: type(type), ne(ne), eps(eps), broadcast(broadcast), multi_add(multi_add) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
|
||||
@@ -2891,6 +2912,9 @@ struct test_rms_norm_mul_add : public test_case {
|
||||
|
||||
// Use a, b and c early, so we don't end up with an OP_NONE between rms_norm and mul
|
||||
a = ggml_add(ctx, ggml_add(ctx, a, b), c);
|
||||
if (multi_add) {
|
||||
a = ggml_add(ctx, ggml_add(ctx, a, b), c);
|
||||
}
|
||||
ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
@@ -4091,6 +4115,75 @@ struct test_conv_2d_dw : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_CONV_3D
|
||||
struct test_conv_3d : public test_case {
|
||||
// Logical 5D dimensions
|
||||
const int64_t N, IC, ID, IH, IW;
|
||||
const int64_t OC, KD, KH, KW;
|
||||
// Conv params
|
||||
const int s0, s1, s2;
|
||||
const int p0, p1, p2;
|
||||
const int d0, d1, d2;
|
||||
// Types
|
||||
const ggml_type type_kernel;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "CONV_3D";
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR11(N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1) + "," +
|
||||
VARS_TO_STR8(s2, p0, p1, p2, d0, d1, d2, type_kernel);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 5e-4;
|
||||
}
|
||||
|
||||
uint64_t op_flops(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
|
||||
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
||||
};
|
||||
const int64_t OD = calc_conv_output_size(ID, KD, s2, p2, d2);
|
||||
const int64_t OH = calc_conv_output_size(IH, KH, s1, p1, d1);
|
||||
const int64_t OW = calc_conv_output_size(IW, KW, s0, p0, d0);
|
||||
|
||||
return (uint64_t)N * OC * OD * OH * OW * (2 * IC * KD * KH * KW - 1);
|
||||
}
|
||||
|
||||
test_conv_3d(
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW,
|
||||
int64_t OC, int64_t KD, int64_t KH, int64_t KW,
|
||||
int s0, int s1, int s2,
|
||||
int p0, int p1, int p2,
|
||||
int d0, int d1, int d2,
|
||||
ggml_type type_kernel
|
||||
) : N(N), IC(IC), ID(ID), IH(IH), IW(IW),
|
||||
OC(OC), KD(KD), KH(KH), KW(KW),
|
||||
s0(s0), s1(s1), s2(s2),
|
||||
p0(p0), p1(p1), p2(p2),
|
||||
d0(d0), d1(d1), d2(d2),
|
||||
type_kernel(type_kernel) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
// GGML input tensor is packed as [W, H, D, C*N]
|
||||
const int64_t ne_input[] = {IW, IH, ID, IC * N};
|
||||
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input);
|
||||
ggml_set_name(input, "input");
|
||||
|
||||
// GGML kernel tensor is packed as [KW, KH, KD, IC*OC]
|
||||
const int64_t ne_kernel[] = {KW, KH, KD, IC * OC};
|
||||
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel);
|
||||
ggml_set_name(kernel, "kernel");
|
||||
|
||||
ggml_tensor * out = ggml_conv_3d(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_CONCAT
|
||||
struct test_concat : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -4231,20 +4324,32 @@ struct test_sum : public test_case {
|
||||
struct test_sum_rows : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const bool permute;
|
||||
const bool slice;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
return VARS_TO_STR4(type, ne, permute, slice);
|
||||
}
|
||||
|
||||
test_sum_rows(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 5, 4, 3})
|
||||
: type(type), ne(ne) {}
|
||||
std::array<int64_t, 4> ne = {10, 5, 4, 3},
|
||||
bool permute = false, bool slice = false)
|
||||
: type(type), ne(ne), permute(permute), slice(slice) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
if (slice) {
|
||||
a = ggml_view_4d(ctx, a,
|
||||
ne[0], ne[1], ne[2] / 2, ne[3] - 1,
|
||||
a->nb[1], a->nb[2] * 2, a->nb[3], /*offset=*/a->nb[3]);
|
||||
}
|
||||
if (permute) {
|
||||
a = ggml_permute(ctx, a, 0, 2, 3, 1);
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_sum_rows(ctx, a);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
@@ -5528,6 +5633,61 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
|
||||
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
|
||||
|
||||
// CONV_3D
|
||||
auto calc_conv_output_size_3d = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
|
||||
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
||||
};
|
||||
|
||||
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
for (int N : {1, 2}) {
|
||||
for (int IC : {1, 3}) {
|
||||
for (int OC : {1, 4}) {
|
||||
for (int s0 : {1, 2}) {
|
||||
for (int p1 : {0, 1}) {
|
||||
for (int d2 : {1, 2}) {
|
||||
int64_t IW = 20, IH = 22, ID = 18;
|
||||
int64_t KW = 3, KH = 3, KD = 3;
|
||||
int s1 = s0, s2 = s0;
|
||||
int p0 = p1, p2 = p1;
|
||||
int d0 = d2, d1 = d2;
|
||||
|
||||
if (calc_conv_output_size_3d(IW, KW, s0, p0, d0) <= 0 ||
|
||||
calc_conv_output_size_3d(IH, KH, s1, p1, d1) <= 0 ||
|
||||
calc_conv_output_size_3d(ID, KD, s2, p2, d2) <= 0) {
|
||||
continue;
|
||||
}
|
||||
test_cases.emplace_back(new test_conv_3d(
|
||||
N, IC, ID, IH, IW,
|
||||
OC, KD, KH, KW,
|
||||
s0, s1, s2, p0, p1, p2, d0, d1, d2,
|
||||
kernel_type));
|
||||
|
||||
// Asymmetric kernel and params
|
||||
int64_t asym_KW = 5, asym_KH = 1, asym_KD = 3;
|
||||
int asym_s0 = 2, asym_s1 = 1, asym_s2 = 1;
|
||||
int asym_p0 = 2, asym_p1 = 0, asym_p2 = 1;
|
||||
int asym_d0 = 1, asym_d1 = 1, asym_d2 = 2;
|
||||
|
||||
if (calc_conv_output_size_3d(IW, asym_KW, asym_s0, asym_p0, asym_d0) <= 0 ||
|
||||
calc_conv_output_size_3d(IH, asym_KH, asym_s1, asym_p1, asym_d1) <= 0 ||
|
||||
calc_conv_output_size_3d(ID, asym_KD, asym_s2, asym_p2, asym_d2) <= 0) {
|
||||
continue;
|
||||
}
|
||||
test_cases.emplace_back(new test_conv_3d(
|
||||
N, IC, ID, IH, IW,
|
||||
OC, asym_KD, asym_KH, asym_KW,
|
||||
asym_s0, asym_s1, asym_s2, asym_p0, asym_p1, asym_p2, asym_d0, asym_d1, asym_d2,
|
||||
kernel_type));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Case with kernel size 1
|
||||
test_cases.emplace_back(new test_conv_3d(1, 4, 8, 8, 8, 8, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, kernel_type));
|
||||
}
|
||||
|
||||
for(uint32_t Cout : {1, 9}){
|
||||
for(uint32_t Cin : {1, 7}){
|
||||
for(uint32_t K : {1, 3, 1337}){
|
||||
@@ -5706,6 +5866,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
||||
}
|
||||
for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
|
||||
for (bool multi_add : {false, true}) {
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add));
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
||||
|
||||
@@ -5852,6 +6017,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
// test large experts*tokens
|
||||
for (bool b : {false, true}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
|
||||
@@ -6071,6 +6237,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
|
||||
test_cases.emplace_back(new test_sum());
|
||||
test_cases.emplace_back(new test_sum_rows());
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
|
||||
test_cases.emplace_back(new test_mean());
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
@@ -6091,8 +6260,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_timestep_embedding());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
|
||||
for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
|
||||
for (int hsk : { 40, 64, 80, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 40, 64, 80, 128, 192, 256, 512 }) {
|
||||
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
|
||||
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
|
||||
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
|
||||
|
||||
@@ -290,6 +290,14 @@ int main(void) {
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "",
|
||||
},
|
||||
{
|
||||
/* .name= */ "ByteDance-Seed/Seed-OSS-36B-Instruct",
|
||||
/* .template_str */ "{# <seed:bos> #}{%- for message in messages %}{%- if message.role in [\"user\", \"system\"] %}{{ bos_token + message.role + \"\\n\" + message.content + eos_token }}{%- elif message.role == \"assistant\" %}{{ bos_token + message.role }}{%- if message.content is defined and message.content is string and message.content|trim|length > 0 %}{{ \"\\n\" + message.content|trim + eos_token }}{%- endif %}{%- else %}{{ bos_token + message.role + \"\\n\" + message.content + eos_token }}{%- endif %}{%- endfor %}{%- if add_generation_prompt %}{{ bos_token + \"assistant\\n\" }}{%- endif %}",
|
||||
/* .expected_output= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
|
||||
/* .expected_output_jinja= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
|
||||
/* .bos_token= */ "<seed:bos>",
|
||||
/* .eos_token= */ "<seed:eos>",
|
||||
}
|
||||
};
|
||||
std::vector<char> formatted_chat(1024);
|
||||
int32_t res;
|
||||
|
||||
+20
-12
@@ -358,7 +358,7 @@ static std::pair<int, int> test_forward_backward(
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
|
||||
const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);
|
||||
const bool subtest_ok = ndata == 0 && almost_equal(loss, 0.0, 1e-6) && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);
|
||||
helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass);
|
||||
}
|
||||
|
||||
@@ -381,10 +381,12 @@ static std::pair<int, int> test_forward_backward(
|
||||
{
|
||||
float weights;
|
||||
ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
|
||||
const bool subtest_ok = weights == ndata/2;
|
||||
const bool subtest_ok = almost_equal(weights, ndata/2, 1e-10);
|
||||
helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass);
|
||||
}
|
||||
{
|
||||
constexpr double atol = 1e-10;
|
||||
|
||||
int64_t ndata;
|
||||
ggml_opt_result_ndata(cd.result, &ndata);
|
||||
bool subtest_ok = ndata == 6;
|
||||
@@ -392,7 +394,7 @@ static std::pair<int, int> test_forward_backward(
|
||||
double loss;
|
||||
double loss_unc;
|
||||
ggml_opt_result_loss(cd.result, &loss, &loss_unc);
|
||||
subtest_ok = subtest_ok && loss == 33.0 && almost_equal(loss_unc, sqrt(3.5), 1e-10);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, 33.0, atol) && almost_equal(loss_unc, sqrt(3.5), atol);
|
||||
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
@@ -437,7 +439,7 @@ static std::pair<int, int> test_forward_backward(
|
||||
{
|
||||
float weights;
|
||||
ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
|
||||
const bool subtest_ok = weights == -ndata * .5;
|
||||
const bool subtest_ok = almost_equal(weights, -ndata * 0.5, 1e-10);
|
||||
helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
|
||||
}
|
||||
{
|
||||
@@ -448,7 +450,7 @@ static std::pair<int, int> test_forward_backward(
|
||||
double loss;
|
||||
double loss_unc;
|
||||
ggml_opt_result_loss(cd.result, &loss, &loss_unc);
|
||||
subtest_ok = subtest_ok && loss == 18.0 && (shuffle || loss_unc == 0.0);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, 18.0, 1e-10) && (shuffle || loss_unc == 0.0);
|
||||
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
@@ -550,10 +552,12 @@ static std::pair<int, int> test_idata_split(
|
||||
if (adamw) {
|
||||
float weights;
|
||||
ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
|
||||
const bool subtest_ok = weights == ndata/2 - epoch*idata_split;
|
||||
const bool subtest_ok = almost_equal(weights, ndata/2 - epoch*idata_split, 1e-10);
|
||||
helper_after_test_idata_split(optim, __func__, high_level, epoch, "weights", subtest_ok, ntest, npass);
|
||||
}
|
||||
if (adamw) {
|
||||
constexpr double atol = 1e-10;
|
||||
|
||||
int64_t ndata_result;
|
||||
ggml_opt_result_ndata(cd.result, &ndata_result);
|
||||
bool subtest_ok = ndata_result == idata_split;
|
||||
@@ -561,7 +565,7 @@ static std::pair<int, int> test_idata_split(
|
||||
double loss;
|
||||
double loss_unc;
|
||||
ggml_opt_result_loss(cd.result, &loss, &loss_unc);
|
||||
subtest_ok = subtest_ok && loss == 28.0 - epoch*16.0 && loss_unc == 0.0;
|
||||
subtest_ok = subtest_ok && almost_equal(loss, 28.0 - epoch*16.0, atol) && almost_equal(loss_unc, 0.0, atol);
|
||||
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
@@ -571,6 +575,8 @@ static std::pair<int, int> test_idata_split(
|
||||
helper_after_test_idata_split(optim, __func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass);
|
||||
}
|
||||
if (adamw) {
|
||||
constexpr double atol = 1e-10;
|
||||
|
||||
int64_t ndata_result;
|
||||
ggml_opt_result_ndata(cd.result2, &ndata_result);
|
||||
bool subtest_ok = ndata_result == ndata - idata_split;
|
||||
@@ -578,7 +584,7 @@ static std::pair<int, int> test_idata_split(
|
||||
double loss;
|
||||
double loss_unc;
|
||||
ggml_opt_result_loss(cd.result2, &loss, &loss_unc);
|
||||
subtest_ok = subtest_ok && loss == 15.0 - epoch*8 && almost_equal(loss_unc, sqrt(0.5), 1e-10);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, 15.0 - epoch*8, atol) && almost_equal(loss_unc, sqrt(0.5), atol);
|
||||
|
||||
double accuracy;
|
||||
double accuracy_unc;
|
||||
@@ -687,22 +693,24 @@ static std::pair<int, int> test_gradient_accumulation(
|
||||
}
|
||||
bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
||||
if (adamw) {
|
||||
constexpr double atol = 1e-6;
|
||||
float weights;
|
||||
ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
|
||||
const bool subtest_ok = weights == (ndata/2) - epoch;
|
||||
const bool subtest_ok = almost_equal(weights, (ndata/2) - epoch, atol);
|
||||
helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass);
|
||||
}
|
||||
{
|
||||
constexpr double atol = 1e-6;
|
||||
int64_t ndata_result;
|
||||
ggml_opt_result_ndata(cd.result, &ndata_result);
|
||||
bool subtest_ok = ndata_result == ndata/nbatch_physical;
|
||||
bool subtest_ok = almost_equal(ndata_result, ndata/nbatch_physical, atol);
|
||||
|
||||
double loss;
|
||||
ggml_opt_result_loss(cd.result, &loss, /*loss_unc =*/ nullptr);
|
||||
if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {
|
||||
subtest_ok = subtest_ok && loss == (39.0 - epoch*6.0);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0), atol);
|
||||
} else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
|
||||
subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, 1e-6);
|
||||
subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, atol);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
|
||||
const int tg = n_tg[i_tg];
|
||||
const int pl = n_pl[i_pl];
|
||||
|
||||
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
|
||||
const int n_ctx_req = is_pp_shared ? (params.kv_unified ? pp : pl*pp) + pl*tg : pl*(pp + tg);
|
||||
|
||||
if (n_ctx_req > n_kv_max) {
|
||||
continue;
|
||||
@@ -147,13 +147,24 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto t_pp_end = ggml_time_us();
|
||||
|
||||
if (is_pp_shared) {
|
||||
for (int32_t i = 1; i < pl; ++i) {
|
||||
llama_memory_seq_cp(mem, 0, i, -1, -1);
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_pp_end = ggml_time_us();
|
||||
if (!params.kv_unified) {
|
||||
// run one dummy token to apply the memory copy
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
llama_memory_seq_rm(mem, 0, pp, -1);
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_tg_start = ggml_time_us();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user