mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 10:07:44 +02:00
Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 053b3f9aae | |||
| e2f560175a | |||
| 36ee06dd2d | |||
| 3cd3a39532 | |||
| 2d77d88e70 | |||
| c95fa362b3 | |||
| 2b65ae3029 | |||
| 48d7021c61 | |||
| 3361e2deba | |||
| 00d53800e0 | |||
| 7ea75035b6 | |||
| c54f6b7988 | |||
| 9b169a4d4e | |||
| 77f9c6bbe5 | |||
| 18b663d8e4 | |||
| fbdfefe74e | |||
| ba932dfb50 | |||
| fac63a3d78 | |||
| eddfb43850 | |||
| 4375415b4a | |||
| 30c42ef5cb | |||
| af04481e6b | |||
| 960e726077 | |||
| ea1518e839 | |||
| 1aa87ee53d | |||
| 9ffcc9e374 | |||
| e04643063b | |||
| dbb3a4739e | |||
| 3d82dbcbce | |||
| 732b5fbf5e | |||
| 568013d0cd | |||
| 517b5ddbf0 | |||
| a9b59288e2 | |||
| 0fd8487b14 | |||
| 108e53c2f1 | |||
| a686171ea7 | |||
| c446b2edd2 | |||
| d84635b1b0 | |||
| 75422e8bc4 | |||
| bb115d2bf7 | |||
| 29fff308c7 | |||
| c6af2161b2 | |||
| 99aa304fb9 | |||
| 8551c44d84 | |||
| 35cae5ba05 | |||
| 810e0af3f5 | |||
| eba92d64c3 | |||
| d9a14523bb | |||
| fd123cfead | |||
| a53f7f7b88 | |||
| 7dfad387e3 | |||
| 60c902926c | |||
| b1b132efcb | |||
| 01e8f2138b | |||
| 484a8ab513 | |||
| cf2270e4d3 | |||
| f07690c930 |
@@ -676,6 +676,35 @@ jobs:
|
||||
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
|
||||
|
||||
macOS-latest-cmake-visionos:
|
||||
runs-on: macos-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
cmake -B build -G Xcode \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DLLAMA_BUILD_EXAMPLES=OFF \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=1.0 \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
|
||||
|
||||
macOS-latest-swift:
|
||||
runs-on: macos-latest
|
||||
|
||||
|
||||
@@ -432,8 +432,8 @@ cmake -B build-visionos -G Xcode \
|
||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
||||
-DCMAKE_OSX_SYSROOT=xros \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-S .
|
||||
cmake --build build-visionos --config Release -- -quiet
|
||||
|
||||
@@ -445,8 +445,8 @@ cmake -B build-visionos-sim -G Xcode \
|
||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
||||
-DCMAKE_OSX_SYSROOT=xrsimulator \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-S .
|
||||
cmake --build build-visionos-sim --config Release -- -quiet
|
||||
|
||||
|
||||
@@ -26,4 +26,43 @@ GG_BUILD_CUDA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
# with SYCL support
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
GG_BUILD_SYCL=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
|
||||
# with MUSA support
|
||||
GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
```
|
||||
|
||||
## Running MUSA CI in a Docker Container
|
||||
|
||||
Assuming `$PWD` is the root of the `llama.cpp` repository, follow these steps to set up and run MUSA CI in a Docker container:
|
||||
|
||||
### 1. Create a local directory to store cached models, configuration files and venv:
|
||||
|
||||
```bash
|
||||
mkdir -p $HOME/llama.cpp/ci-cache
|
||||
```
|
||||
|
||||
### 2. Create a local directory to store CI run results:
|
||||
|
||||
```bash
|
||||
mkdir -p $HOME/llama.cpp/ci-results
|
||||
```
|
||||
|
||||
### 3. Start a Docker container and run the CI:
|
||||
|
||||
```bash
|
||||
docker run --privileged -it \
|
||||
-v $HOME/llama.cpp/ci-cache:/ci-cache \
|
||||
-v $HOME/llama.cpp/ci-results:/ci-results \
|
||||
-v $PWD:/ws -w /ws \
|
||||
mthreads/musa:rc3.1.1-devel-ubuntu22.04
|
||||
```
|
||||
|
||||
Inside the container, execute the following commands:
|
||||
|
||||
```bash
|
||||
apt update -y && apt install -y cmake git python3.10-venv wget
|
||||
git config --global --add safe.directory /ws
|
||||
GG_BUILD_MUSA=1 bash ./ci/run.sh /ci-results /ci-cache
|
||||
```
|
||||
|
||||
This setup ensures that the CI runs within an isolated Docker environment while maintaining cached files and results across runs.
|
||||
|
||||
@@ -16,6 +16,9 @@
|
||||
# # with VULKAN support
|
||||
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
# # with MUSA support
|
||||
# GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
|
||||
if [ -z "$2" ]; then
|
||||
echo "usage: $0 <output-dir> <mnt-dir>"
|
||||
@@ -52,13 +55,22 @@ if [ ! -z ${GG_BUILD_SYCL} ]; then
|
||||
echo "source /opt/intel/oneapi/setvars.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use only main GPU
|
||||
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
|
||||
# Enable sysman for correct memory reporting
|
||||
export ZES_ENABLE_SYSMAN=1
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON"
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_VULKAN} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_VULKAN=1"
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_MUSA} ]; then
|
||||
# Use qy1 by default (MTT S80)
|
||||
MUSA_ARCH=${MUSA_ARCH:-21}
|
||||
CMAKE_EXTRA="-DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}"
|
||||
fi
|
||||
## helpers
|
||||
|
||||
# download a file if it does not exist or if it is outdated
|
||||
@@ -808,7 +820,7 @@ export LLAMA_LOG_PREFIX=1
|
||||
export LLAMA_LOG_TIMESTAMPS=1
|
||||
|
||||
if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||
# Create symlink: ./llama.cpp/models-mnt -> $MNT/models/models-mnt
|
||||
# Create symlink: ./llama.cpp/models-mnt -> $MNT/models
|
||||
rm -rf ${SRC}/models-mnt
|
||||
mnt_models=${MNT}/models
|
||||
mkdir -p ${mnt_models}
|
||||
@@ -826,8 +838,10 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||
fi
|
||||
|
||||
ret=0
|
||||
|
||||
test $ret -eq 0 && gg_run ctest_debug
|
||||
if [ -z ${GG_BUILD_SYCL} ]; then
|
||||
# SYCL build breaks with debug build flags
|
||||
test $ret -eq 0 && gg_run ctest_debug
|
||||
fi
|
||||
test $ret -eq 0 && gg_run ctest_release
|
||||
|
||||
if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||
@@ -835,7 +849,9 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||
test $ret -eq 0 && gg_run rerank_tiny
|
||||
|
||||
if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
|
||||
test $ret -eq 0 && gg_run test_scripts_debug
|
||||
if [ -z ${GG_BUILD_SYCL} ]; then
|
||||
test $ret -eq 0 && gg_run test_scripts_debug
|
||||
fi
|
||||
test $ret -eq 0 && gg_run test_scripts_release
|
||||
fi
|
||||
|
||||
@@ -846,7 +862,9 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||
test $ret -eq 0 && gg_run pythia_2_8b
|
||||
#test $ret -eq 0 && gg_run open_llama_7b_v2
|
||||
fi
|
||||
test $ret -eq 0 && gg_run ctest_with_model_debug
|
||||
if [ -z ${GG_BUILD_SYCL} ]; then
|
||||
test $ret -eq 0 && gg_run ctest_with_model_debug
|
||||
fi
|
||||
test $ret -eq 0 && gg_run ctest_with_model_release
|
||||
fi
|
||||
fi
|
||||
|
||||
+239
-55
@@ -180,7 +180,8 @@ class Model:
|
||||
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
|
||||
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
||||
if len(extra) == 0 and len(missing_files) > 0:
|
||||
raise ValueError(f"Missing or incomplete model files: {missing_files}")
|
||||
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
|
||||
f"Missing tensors: {missing}")
|
||||
else:
|
||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||
f"Missing tensors: {missing}\n"
|
||||
@@ -528,6 +529,8 @@ class Model:
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
@@ -537,13 +540,13 @@ class Model:
|
||||
if token in added_vocab:
|
||||
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not tokenizer.added_tokens_decoder[i].normalized:
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||
if added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
# NOTE: this was added for Gemma.
|
||||
@@ -702,6 +705,9 @@ class Model:
|
||||
if chkhsh == "ccc2ef013c104be7bae2965776d611e1d7a8a2a9c547dd93a682c9a9fc80352e":
|
||||
# ref: https://huggingface.co/Xenova/gpt-4o
|
||||
res = "gpt-4o"
|
||||
if chkhsh == "7dec86086fcc38b66b7bc1575a160ae21cf705be7718b9d5598190d7c12db76f":
|
||||
# ref: https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k
|
||||
res = "superbpe"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@@ -908,6 +914,40 @@ class Model:
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_rwkv_world(self):
|
||||
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
|
||||
vocab_size = self.hparams.get("vocab_size", 65536)
|
||||
|
||||
tokens: list[bytes] = ['<s>'.encode("utf-8")]
|
||||
toktypes: list[int] = [gguf.TokenType.CONTROL]
|
||||
|
||||
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.split(' ')
|
||||
assert len(parts) >= 3
|
||||
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
|
||||
token = token.encode("utf-8") if isinstance(token, str) else token
|
||||
assert isinstance(token, bytes)
|
||||
assert len(token) == token_len
|
||||
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
|
||||
tokens.append(token_text.encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
remainder = vocab_size - len(tokens)
|
||||
assert remainder >= 0
|
||||
for i in range(len(tokens), vocab_size):
|
||||
tokens.append(f"[PAD{i}]".encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("rwkv")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
|
||||
special_vocab.chat_template = "rwkv-world"
|
||||
# hack: Add '\n\n' as the EOT token to make it chat normally
|
||||
special_vocab._set_special_token("eot", 261)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
|
||||
tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
|
||||
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
|
||||
@@ -1065,13 +1105,6 @@ class BloomModel(Model):
|
||||
|
||||
tensors.append((self.map_tensor_name(name), data_torch))
|
||||
|
||||
if name == "word_embeddings.weight":
|
||||
assert self.tensor_names is not None
|
||||
|
||||
# TODO: tie them at runtime, don't duplicate in the model file
|
||||
if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
|
||||
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
|
||||
|
||||
return tensors
|
||||
|
||||
|
||||
@@ -1713,6 +1746,25 @@ class LlamaModel(Model):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("Mistral3ForConditionalGeneration")
|
||||
class Mistral3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
|
||||
# we need to merge the text_config into the root level of hparams
|
||||
def __init__(self, *args, **kwargs):
|
||||
hparams = Model.load_hparams(kwargs["dir_model"])
|
||||
if "text_config" in hparams:
|
||||
hparams = {**hparams, **hparams["text_config"]}
|
||||
kwargs["hparams"] = hparams
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("language_model.", "")
|
||||
if "multi_modal_projector" in name or "vision_tower" in name:
|
||||
return []
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@Model.register("DeciLMForCausalLM")
|
||||
class DeciModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.DECI
|
||||
@@ -2370,10 +2422,6 @@ class GPT2Model(Model):
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
# note: GPT2 output is tied to (same as) wte in original model
|
||||
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
|
||||
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
|
||||
|
||||
return tensors
|
||||
|
||||
|
||||
@@ -2703,21 +2751,26 @@ class CodeShellModel(Model):
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(1.0)
|
||||
|
||||
_has_tok_embd = False
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
|
||||
tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
|
||||
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
tensors: list[tuple[str, Tensor]] = [(new_name, data_torch)]
|
||||
# assuming token_embd.weight is seen before output.weight
|
||||
if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
|
||||
# even though the tensor file(s) does not contain the word embeddings they are still in the weight map
|
||||
if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
|
||||
logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
|
||||
self.tensor_names.remove("transformer.wte.weight")
|
||||
elif new_name == tok_embd_name:
|
||||
self._has_tok_embd = True
|
||||
|
||||
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
|
||||
assert self.tensor_names is not None
|
||||
|
||||
if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
|
||||
# copy tok_embd.weight to output.weight
|
||||
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
|
||||
|
||||
return tensors
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
|
||||
@Model.register("InternLM2ForCausalLM")
|
||||
@@ -3412,38 +3465,7 @@ class Rwkv6Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.RWKV6
|
||||
|
||||
def set_vocab(self):
|
||||
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
|
||||
vocab_size = self.hparams.get("vocab_size", 65536)
|
||||
|
||||
tokens: list[bytes] = ['<s>'.encode("utf-8")]
|
||||
toktypes: list[int] = [gguf.TokenType.CONTROL]
|
||||
|
||||
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.split(' ')
|
||||
assert len(parts) >= 3
|
||||
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
|
||||
token = token.encode("utf-8") if isinstance(token, str) else token
|
||||
assert isinstance(token, bytes)
|
||||
assert len(token) == token_len
|
||||
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
|
||||
tokens.append(token_text.encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
remainder = vocab_size - len(tokens)
|
||||
assert remainder >= 0
|
||||
for i in range(len(tokens), vocab_size):
|
||||
tokens.append(f"[PAD{i}]".encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("rwkv")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
|
||||
special_vocab.chat_template = "rwkv-world"
|
||||
# hack: Add '\n\n' as the EOT token to make it chat normally
|
||||
special_vocab._set_special_token("eot", 261)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
self._set_vocab_rwkv_world()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
@@ -3565,6 +3587,168 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
||||
yield (new_name, data)
|
||||
|
||||
|
||||
@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
|
||||
class Rwkv7Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.RWKV7
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_rwkv_world()
|
||||
|
||||
def calc_lora_rank(self, hidden_size, exponent, multiplier):
|
||||
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
try:
|
||||
head_size = self.hparams["head_size"]
|
||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||
except KeyError:
|
||||
head_size = self.hparams["head_dim"]
|
||||
layer_norm_eps = self.hparams["norm_eps"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
|
||||
|
||||
# ICLR: In-Context-Learning-Rate
|
||||
try:
|
||||
lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
|
||||
lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
|
||||
lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
|
||||
lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
|
||||
except KeyError:
|
||||
lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
|
||||
lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
|
||||
lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
|
||||
lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
|
||||
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
|
||||
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
|
||||
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
# required by llama.cpp, unused
|
||||
self.gguf_writer.add_head_count(0)
|
||||
|
||||
lerp_weights: dict[int, dict[str, Tensor]] = {}
|
||||
lora_needs_transpose: bool = True
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# unify tensor names here to make life easier
|
||||
name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
|
||||
name = name.replace("self_attn", "attention").replace("attn", "attention")
|
||||
name = name.replace("time_mixer.", "")
|
||||
# lora layer names in fla-hub's impl
|
||||
if "_lora.lora" in name:
|
||||
self.lora_needs_transpose = False
|
||||
name = name.replace("_lora.lora.0.weight", "1.weight")
|
||||
name = name.replace("_lora.lora.2.weight", "2.weight")
|
||||
name = name.replace("_lora.lora.2.bias", "0.weight")
|
||||
|
||||
name = name.replace("feed_forward_norm", "ln2")
|
||||
name = name.replace("g_norm", "ln_x")
|
||||
|
||||
if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
|
||||
# some models have dummy v0/v1/v2 on first layer while others don't
|
||||
# ignore them all since they are not used
|
||||
return
|
||||
|
||||
wkv_has_gate = self.hparams.get("wkv_has_gate", True)
|
||||
lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
|
||||
|
||||
if bid is not None and "attention.x_" in name:
|
||||
if "attention.x_x" in name:
|
||||
# already concatenated
|
||||
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
|
||||
data = data_torch.reshape(len(lerp_list), 1, 1, -1)
|
||||
yield (new_name, data)
|
||||
else:
|
||||
try:
|
||||
self.lerp_weights[bid][name] = data_torch
|
||||
except KeyError:
|
||||
self.lerp_weights[bid] = {name: data_torch}
|
||||
if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
|
||||
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
|
||||
data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
|
||||
yield (new_name, data)
|
||||
return
|
||||
else:
|
||||
data_torch = data_torch.squeeze()
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
|
||||
new_name += ".weight"
|
||||
|
||||
if self.lora_needs_transpose and any(
|
||||
new_name.endswith(t) for t in [
|
||||
"time_mix_w1.weight", "time_mix_w2.weight",
|
||||
"time_mix_a1.weight", "time_mix_a2.weight",
|
||||
"time_mix_v1.weight", "time_mix_v2.weight",
|
||||
"time_mix_g1.weight", "time_mix_g2.weight",
|
||||
]
|
||||
):
|
||||
data_torch = data_torch.transpose(0, 1)
|
||||
|
||||
if 'r_k' in new_name:
|
||||
data_torch = data_torch.flatten()
|
||||
|
||||
if bid == 0 and "time_mix_a" in new_name:
|
||||
# dummy v0/v1/v2 on first layer
|
||||
# easist way to make llama happy
|
||||
yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
|
||||
@Model.register("RwkvHybridForCausalLM")
|
||||
class ARwkv7Model(Rwkv7Model):
|
||||
model_arch = gguf.MODEL_ARCH.ARWKV7
|
||||
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
head_size = self.hparams["head_size"]
|
||||
rms_norm_eps = self.hparams["rms_norm_eps"]
|
||||
intermediate_size = self.hparams["intermediate_size"]
|
||||
wkv_has_gate = self.hparams["wkv_has_gate"]
|
||||
assert self.hparams["wkv_version"] == 7
|
||||
|
||||
# ICLR: In-Context-Learning-Rate
|
||||
lora_rank_decay = 64
|
||||
lora_rank_iclr = 64
|
||||
lora_rank_value_residual_mix = 32
|
||||
lora_rank_gate = 128 if wkv_has_gate else 0
|
||||
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
|
||||
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
|
||||
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_token_shift_count(1)
|
||||
|
||||
# required by llama.cpp, unused
|
||||
self.gguf_writer.add_head_count(0)
|
||||
|
||||
|
||||
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
|
||||
class MambaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MAMBA
|
||||
|
||||
@@ -110,6 +110,7 @@ models = [
|
||||
{"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
|
||||
{"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
|
||||
{"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", },
|
||||
{"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", },
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -14,9 +14,7 @@ In this guide we setup [Nvidia CUDA](https://docs.nvidia.com/cuda/) in a toolbox
|
||||
- [Creating a Fedora Toolbox Environment](#creating-a-fedora-toolbox-environment)
|
||||
- [Installing Essential Development Tools](#installing-essential-development-tools)
|
||||
- [Adding the CUDA Repository](#adding-the-cuda-repository)
|
||||
- [Installing `nvidia-driver-libs`](#installing-nvidia-driver-libs)
|
||||
- [Manually Resolving Package Conflicts](#manually-resolving-package-conflicts)
|
||||
- [Finalizing the Installation of `nvidia-driver-libs`](#finalizing-the-installation-of-nvidia-driver-libs)
|
||||
- [Installing Nvidia Driver Libraries](#installing-nvidia-driver-libraries)
|
||||
- [Installing the CUDA Meta-Package](#installing-the-cuda-meta-package)
|
||||
- [Configuring the Environment](#configuring-the-environment)
|
||||
- [Verifying the Installation](#verifying-the-installation)
|
||||
@@ -67,7 +65,7 @@ This guide focuses on Fedora hosts, but with small adjustments, it can work for
|
||||
sudo dnf distro-sync
|
||||
```
|
||||
|
||||
2. **Install the Default Text Editor (Optional):**
|
||||
2. **Install **Vim** the default text editor (Optional):**
|
||||
|
||||
```bash
|
||||
sudo dnf install vim-default-editor --allowerasing
|
||||
@@ -97,36 +95,48 @@ After adding the repository, synchronize the package manager again:
|
||||
sudo dnf distro-sync
|
||||
```
|
||||
|
||||
## Installing `nvidia-driver-libs` and `nvidia-driver-cuda-libs`
|
||||
## Installing Nvidia Driver Libraries
|
||||
|
||||
We need to detect if the host is supplying the [NVIDIA driver libraries into the toolbox](https://github.com/containers/toolbox/blob/main/src/pkg/nvidia/nvidia.go).
|
||||
First, we need to detect if the host is supplying the [NVIDIA driver libraries into the toolbox](https://github.com/containers/toolbox/blob/main/src/pkg/nvidia/nvidia.go):
|
||||
|
||||
```bash
|
||||
ls -la /usr/lib64/libcuda.so.1
|
||||
```
|
||||
|
||||
### If *`libcuda.so.1`* is missing:
|
||||
|
||||
```
|
||||
ls: cannot access '/usr/lib64/libcuda.so.1': No such file or directory
|
||||
```
|
||||
|
||||
**Explanation:**
|
||||
The host dose not supply the CUDA drivers, **install them now:**
|
||||
|
||||
- `nvidia-driver-libs` and `nvidia-driver-cuda-libs` contains necessary NVIDIA driver libraries required by CUDA,
|
||||
on hosts with NVIDIA drivers installed the Fedora Container will supply the host libraries.
|
||||
|
||||
### Install Nvidia Driver Libraries on Guest (if `libcuda.so.1` was NOT found).
|
||||
#### Install the Nvidia Driver Libraries on Guest:
|
||||
|
||||
```bash
|
||||
sudo dnf install nvidia-driver-libs nvidia-driver-cuda-libs
|
||||
sudo dnf install nvidia-driver-cuda nvidia-driver-libs nvidia-driver-cuda-libs nvidia-persistenced
|
||||
```
|
||||
|
||||
### Manually Updating the RPM database for host-supplied NVIDIA drivers (if `libcuda.so.1` was found).
|
||||
### If *`libcuda.so.1`* exists:
|
||||
```
|
||||
lrwxrwxrwx. 1 root root 21 Mar 24 11:26 /usr/lib64/libcuda.so.1 -> libcuda.so.570.133.07
|
||||
```
|
||||
|
||||
If the installation fails due to conflicts, we'll manually download and install the required packages, excluding conflicting files.
|
||||
**Explanation:**
|
||||
The host is supply the CUDA drivers, **we need to update the guest RPM Database accordingly:**
|
||||
|
||||
#### 1. Download `nvidia-driver-libs` and `nvidia-driver-cuda-libs` RPM's (with dependencies)
|
||||
#### Update the Toolbox RPM Database to include the Host-Supplied Libraries:
|
||||
|
||||
Note: we do not actually install the libraries, we just update the DB so that the guest system knows they are supplied by the host.
|
||||
|
||||
##### 1. Download `nvidia-` parts that are supplied by the host RPM's (with dependencies)
|
||||
|
||||
```bash
|
||||
sudo dnf download --destdir=/tmp/nvidia-driver-libs --resolve --arch x86_64 nvidia-driver-libs nvidia-driver-cuda-libs
|
||||
sudo dnf download --destdir=/tmp/nvidia-driver-libs --resolve --arch x86_64 nvidia-driver-cuda nvidia-driver-libs nvidia-driver-cuda-libs nvidia-persistenced
|
||||
```
|
||||
|
||||
#### 2. Update the RPM database to assume the installation of these packages.
|
||||
##### 2. Update the RPM database to assume the installation of these packages.
|
||||
|
||||
```bash
|
||||
sudo rpm --install --verbose --hash --justdb /tmp/nvidia-driver-libs/*
|
||||
@@ -134,23 +144,26 @@ sudo rpm --install --verbose --hash --justdb /tmp/nvidia-driver-libs/*
|
||||
|
||||
**Note:**
|
||||
|
||||
- The `--justdb` option only updates the RPM database, without touching the filesystem.
|
||||
- The `--justdb` option only updates the RPM database, without touching the filesystem elsewhere.
|
||||
|
||||
#### Finalizing the Installation of `nvidia-driver-libs` and `nvidia-driver-cuda-libs`
|
||||
##### Check that the RPM Database has been correctly updated:
|
||||
|
||||
**Note:** This is the same command as in the *"Install the Nvidia Driver Libraries on Guest"* for if *`libcuda.so.1`* was missing.
|
||||
|
||||
After manually installing the dependencies, run:
|
||||
|
||||
```bash
|
||||
sudo dnf install nvidia-driver-libs nvidia-driver-cuda-libs
|
||||
sudo dnf install nvidia-driver-cuda nvidia-driver-libs nvidia-driver-cuda-libs nvidia-persistenced
|
||||
```
|
||||
|
||||
You should receive a message indicating the package is already installed:
|
||||
*(this time it will not install anything, as the database things that these packages are already installed)*
|
||||
|
||||
```
|
||||
Updating and loading repositories:
|
||||
Repositories loaded.
|
||||
Package "nvidia-driver-libs-3:570.86.10-1.fc41.x86_64" is already installed.
|
||||
Package "nvidia-driver-cuda-libs-3:570.86.10-1.fc41.x86_64" is already installed.
|
||||
Package "nvidia-driver-cuda-3:570.124.06-1.fc41.x86_64" is already installed.
|
||||
Package "nvidia-driver-libs-3:570.124.06-1.fc41.x86_64" is already installed.
|
||||
Package "nvidia-driver-cuda-libs-3:570.124.06-1.fc41.x86_64" is already installed.
|
||||
Package "nvidia-persistenced-3:570.124.06-1.fc41.x86_64" is already installed.
|
||||
|
||||
Nothing to do.
|
||||
```
|
||||
@@ -207,9 +220,9 @@ You should see output similar to:
|
||||
```
|
||||
nvcc: NVIDIA (R) Cuda compiler driver
|
||||
Copyright (c) 2005-2025 NVIDIA Corporation
|
||||
Built on Wed_Jan_15_19:20:09_PST_2025
|
||||
Cuda compilation tools, release 12.8, V12.8.61
|
||||
Build cuda_12.8.r12.8/compiler.35404655_0
|
||||
Built on Fri_Feb_21_20:23:50_PST_2025
|
||||
Cuda compilation tools, release 12.8, V12.8.93
|
||||
Build cuda_12.8.r12.8/compiler.35583870_0
|
||||
```
|
||||
|
||||
This output confirms that the CUDA compiler is accessible and indicates the installed version.
|
||||
+14
-3
@@ -237,6 +237,15 @@ cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENAB
|
||||
cmake --build buildWithCublas --config Release
|
||||
```
|
||||
|
||||
**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target:
|
||||
|
||||
```sh
|
||||
git clone https://github.com/oneapi-src/oneDNN.git
|
||||
cd oneDNN
|
||||
cmake -GNinja -Bbuild-nvidia -DDNNL_CPU_RUNTIME=DPCPP -DDNNL_GPU_RUNTIME=DPCPP -DDNNL_GPU_VENDOR=NVIDIA -DONEDNN_BUILD_GRAPH=OFF -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
|
||||
cmake --build build-nvidia --config Release
|
||||
```
|
||||
|
||||
- **Adding support to AMD GPUs**
|
||||
|
||||
**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.
|
||||
@@ -327,10 +336,10 @@ export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR
|
||||
GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture
|
||||
|
||||
# Option 1: Use FP32 (recommended for better performance in most cases)
|
||||
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
|
||||
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
|
||||
|
||||
# Option 2: Use FP16
|
||||
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON
|
||||
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
|
||||
|
||||
# build all binary
|
||||
cmake --build build --config Release -j -v
|
||||
@@ -660,8 +669,9 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
|--------------------|---------------------------------------|---------------------------------------------|
|
||||
| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.<br>FP32 path - recommended for better perforemance than FP16 on quantized model|
|
||||
| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. |
|
||||
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
|
||||
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
|
||||
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
|
||||
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
|
||||
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
|
||||
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
|
||||
|
||||
@@ -671,6 +681,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
|
||||
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
|
||||
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
|
||||
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
|
||||
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
|
||||
|
||||
|
||||
|
||||
+26
-4
@@ -132,12 +132,14 @@ You may find the official downloads here: [NVIDIA developer site](https://develo
|
||||
|
||||
|
||||
#### Compile and run inside a Fedora Toolbox Container
|
||||
We also have a [guide](./cuda-fedora.md) for setting up CUDA toolkit in a Fedora [toolbox container](https://containertoolbx.org/).
|
||||
We also have a [guide](./backend/CUDA-FEDORA.md) for setting up CUDA toolkit in a Fedora [toolbox container](https://containertoolbx.org/).
|
||||
|
||||
**Recommended for:**
|
||||
|
||||
- ***Particularly*** *convenient* for users of [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/); such as: [Silverblue](https://fedoraproject.org/atomic-desktops/silverblue/) and [Kinoite](https://fedoraproject.org/atomic-desktops/kinoite/).
|
||||
- Toolbox is installed by default: [Fedora Workstation](https://fedoraproject.org/workstation/) or [Fedora KDE Plasma Desktop](https://fedoraproject.org/spins/kde).
|
||||
- ***Necessary*** for users of [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/); such as: [Silverblue](https://fedoraproject.org/atomic-desktops/silverblue/) and [Kinoite](https://fedoraproject.org/atomic-desktops/kinoite/).
|
||||
- (there are no supported CUDA packages for these systems)
|
||||
- ***Necessary*** for users that have a host that is not a: [Supported Nvidia CUDA Release Platform](https://developer.nvidia.com/cuda-downloads).
|
||||
- (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your your host operating system)
|
||||
- ***Convenient*** For those running [Fedora Workstation](https://fedoraproject.org/workstation/) or [Fedora KDE Plasma Desktop](https://fedoraproject.org/spins/kde), and want to keep their host system clean.
|
||||
- *Optionally* toolbox packages are available: [Arch Linux](https://archlinux.org/), [Red Hat Enterprise Linux >= 8.5](https://www.redhat.com/en/technologies/linux-platforms/enterprise-linux), or [Ubuntu](https://ubuntu.com/download)
|
||||
|
||||
|
||||
@@ -433,6 +435,26 @@ llama_new_context_with_model: CANN compute buffer size = 1260.81 MiB
|
||||
|
||||
For detailed info, such as model/device supports, CANN install, please refer to [llama.cpp for CANN](./backend/CANN.md).
|
||||
|
||||
## Arm® KleidiAI™
|
||||
KleidiAI is a library of optimized microkernels for AI workloads, specifically designed for Arm CPUs. These microkernels enhance performance and can be enabled for use by the CPU backend.
|
||||
|
||||
To enable KleidiAI, go to the llama.cpp directory and build using CMake
|
||||
```bash
|
||||
cmake -B build -DGGML_CPU_KLEIDIAI=ON
|
||||
cmake --build build --config Release
|
||||
```
|
||||
You can verify that KleidiAI is being used by running
|
||||
```bash
|
||||
./build/bin/llama-cli -m PATH_TO_MODEL -p "What is a car?"
|
||||
```
|
||||
If KleidiAI is enabled, the ouput will contain a line similar to:
|
||||
```
|
||||
load_tensors: CPU_KLEIDIAI model buffer size = 3474.00 MiB
|
||||
```
|
||||
KleidiAI's microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm and SME. llama.cpp selects the most efficient kernel based on runtime CPU feature detection. However, on platforms that support SME, you must manually enable SME microkernels by setting the environment variable `GGML_KLEIDIAI_SME=1`.
|
||||
|
||||
Depending on your build target, other higher priority backends may be enabled by default. To ensure the CPU backend is used, you must disable the higher priority backends either at compile time, e.g. -DGGML_METAL=OFF, or during run-time using the command line option `--device none`.
|
||||
|
||||
## Android
|
||||
|
||||
To read documentation for how to build on Android, [click here](./android.md)
|
||||
|
||||
@@ -9,6 +9,13 @@ brew install llama.cpp
|
||||
```
|
||||
The formula is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggml-org/llama.cpp/discussions/7668
|
||||
|
||||
## MacPorts
|
||||
|
||||
```sh
|
||||
sudo port install llama.cpp
|
||||
```
|
||||
see also: https://ports.macports.org/port/llama.cpp/details/
|
||||
|
||||
## Nix
|
||||
|
||||
On Mac and Linux, the Nix package manager can be used via
|
||||
|
||||
+36
-5
@@ -27,12 +27,24 @@ Once downloaded, place your model in the models folder in llama.cpp.
|
||||
##### Input prompt (One-and-done)
|
||||
|
||||
```bash
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --prompt "Once upon a time"
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
|
||||
```
|
||||
##### Conversation mode (Allow for continuous interaction with the model)
|
||||
|
||||
```bash
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -cnv --chat-template gemma
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
|
||||
```
|
||||
|
||||
##### Conversation mode using built-in jinja chat template
|
||||
|
||||
```bash
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja
|
||||
```
|
||||
|
||||
##### One-and-done query using jinja with custom system prompt and a starting prompt
|
||||
|
||||
```bash
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
|
||||
```
|
||||
|
||||
##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
|
||||
@@ -44,12 +56,24 @@ Once downloaded, place your model in the models folder in llama.cpp.
|
||||
|
||||
##### Input prompt (One-and-done)
|
||||
```powershell
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --prompt "Once upon a time"
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
|
||||
```
|
||||
##### Conversation mode (Allow for continuous interaction with the model)
|
||||
|
||||
```powershell
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -cnv --chat-template gemma
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
|
||||
```
|
||||
|
||||
##### Conversation mode using built-in jinja chat template
|
||||
|
||||
```powershell
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja
|
||||
```
|
||||
|
||||
##### One-and-done query using jinja with custom system prompt and a starting prompt
|
||||
|
||||
```powershell
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
|
||||
```
|
||||
|
||||
#### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
|
||||
@@ -77,6 +101,8 @@ The `llama-cli` program provides several ways to interact with the LLaMA models
|
||||
|
||||
- `--prompt PROMPT`: Provide a prompt directly as a command-line option.
|
||||
- `--file FNAME`: Provide a file containing a prompt or multiple prompts.
|
||||
- `--system-prompt PROMPT`: Provide a system prompt (will otherwise use the default one in the chat template (if provided)).
|
||||
- `--system-prompt-file FNAME`: Provide a file containing a system prompt.
|
||||
- `--interactive-first`: Run the program in interactive mode and wait for input right away. (More on this below.)
|
||||
|
||||
## Interaction
|
||||
@@ -89,7 +115,10 @@ In interactive mode, users can participate in text generation by injecting their
|
||||
|
||||
- `-i, --interactive`: Run the program in interactive mode, allowing users to engage in real-time conversations or provide specific instructions to the model.
|
||||
- `--interactive-first`: Run the program in interactive mode and immediately wait for user input before starting the text generation.
|
||||
- `-cnv, --conversation`: Run the program in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: false)
|
||||
- `-cnv, --conversation`: Run the program in conversation mode (does not print special tokens and suffix/prefix, use default or provided chat template) (default: true if chat template found)
|
||||
- `-no-cnv`: Disable conversation mode (default: false)
|
||||
- `-st, --single-turn`: Only process a single conversation turn (user input) and then exit.
|
||||
- `--jinja`: Enable jinja chat template parser, will use the model's built-in template or a user-provided one (default: false)
|
||||
- `--color`: Enable colorized output to differentiate visually distinguishing between prompts, user input, and generated text.
|
||||
|
||||
By understanding and utilizing these interaction options, you can create engaging and dynamic experiences with the LLaMA models, tailoring the text generation process to your specific needs.
|
||||
@@ -125,6 +154,8 @@ When --in-prefix or --in-suffix options are enabled the chat template ( --chat-t
|
||||
|
||||
Example usage: `--chat-template gemma`
|
||||
|
||||
`--chat-template-file FNAME`: Load a custom jinja chat template from an external file, useful if the model contains outdated or incompatible template, some examples can be found in models/templates. Up-to-date chat templates can be downloaded from Hugging Face using scripts/get_chat_template.py
|
||||
|
||||
## Context Management
|
||||
|
||||
During text generation, LLaMA models have a limited context size, which means they can only consider a certain number of tokens from the input and generated text. When the context fills up, the model resets internally, potentially losing some information from the beginning of the conversation or instructions. Context management options help maintain continuity and coherence in these situations.
|
||||
|
||||
Binary file not shown.
@@ -830,6 +830,11 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
ret.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
ret["__verbose"] = to_json_non_oaicompat();
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
@@ -1872,6 +1877,10 @@ struct server_context {
|
||||
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
|
||||
params_dft.n_parallel = 1;
|
||||
|
||||
// force F16 KV cache for the draft model for extra performance
|
||||
params_dft.cache_type_k = GGML_TYPE_F16;
|
||||
params_dft.cache_type_v = GGML_TYPE_F16;
|
||||
|
||||
llama_init_dft = common_init_from_params(params_dft);
|
||||
|
||||
model_dft = llama_init_dft.model.get();
|
||||
@@ -1892,10 +1901,6 @@ struct server_context {
|
||||
cparams_dft = common_context_params_to_llama(params_dft);
|
||||
cparams_dft.n_batch = n_ctx_dft;
|
||||
|
||||
// force F16 KV cache for the draft model for extra performance
|
||||
cparams_dft.type_k = GGML_TYPE_F16;
|
||||
cparams_dft.type_v = GGML_TYPE_F16;
|
||||
|
||||
// the context is not needed - we will create one for each slot
|
||||
llama_init_dft.context.reset();
|
||||
}
|
||||
|
||||
@@ -99,13 +99,9 @@ export default function ChatScreen() {
|
||||
canvasData,
|
||||
replaceMessageAndGenerate,
|
||||
} = useAppContext();
|
||||
const [inputMsg, setInputMsg] = useState(prefilledMsg.content());
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
const textarea = useOptimizedTextarea(prefilledMsg.content());
|
||||
|
||||
const { extraContext, clearExtraContext } = useVSCodeContext(
|
||||
inputRef,
|
||||
setInputMsg
|
||||
);
|
||||
const { extraContext, clearExtraContext } = useVSCodeContext(textarea);
|
||||
// TODO: improve this when we have "upload file" feature
|
||||
const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;
|
||||
|
||||
@@ -135,9 +131,10 @@ export default function ChatScreen() {
|
||||
};
|
||||
|
||||
const sendNewMessage = async () => {
|
||||
if (inputMsg.trim().length === 0 || isGenerating(currConvId ?? '')) return;
|
||||
const lastInpMsg = inputMsg;
|
||||
setInputMsg('');
|
||||
const lastInpMsg = textarea.value();
|
||||
if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? ''))
|
||||
return;
|
||||
textarea.setValue('');
|
||||
scrollToBottom(false);
|
||||
setCurrNodeId(-1);
|
||||
// get the last message node
|
||||
@@ -146,13 +143,13 @@ export default function ChatScreen() {
|
||||
!(await sendMessage(
|
||||
currConvId,
|
||||
lastMsgNodeId,
|
||||
inputMsg,
|
||||
lastInpMsg,
|
||||
currExtra,
|
||||
onChunk
|
||||
))
|
||||
) {
|
||||
// restore the input message if failed
|
||||
setInputMsg(lastInpMsg);
|
||||
textarea.setValue(lastInpMsg);
|
||||
}
|
||||
// OK
|
||||
clearExtraContext();
|
||||
@@ -195,16 +192,13 @@ export default function ChatScreen() {
|
||||
// send the prefilled message if needed
|
||||
sendNewMessage();
|
||||
} else {
|
||||
// otherwise, focus on the input and move the cursor to the end
|
||||
if (inputRef.current) {
|
||||
inputRef.current.focus();
|
||||
inputRef.current.selectionStart = inputRef.current.value.length;
|
||||
}
|
||||
// otherwise, focus on the input
|
||||
textarea.focus();
|
||||
}
|
||||
prefilledMsg.clear();
|
||||
// no need to keep track of sendNewMessage
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [inputRef]);
|
||||
}, [textarea.ref]);
|
||||
|
||||
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
|
||||
const pendingMsgDisplay: MessageDisplay[] =
|
||||
@@ -258,9 +252,7 @@ export default function ChatScreen() {
|
||||
<textarea
|
||||
className="textarea textarea-bordered w-full"
|
||||
placeholder="Type a message (Shift+Enter to add a new line)"
|
||||
ref={inputRef}
|
||||
value={inputMsg}
|
||||
onChange={(e) => setInputMsg(e.target.value)}
|
||||
ref={textarea.ref}
|
||||
onKeyDown={(e) => {
|
||||
if (e.nativeEvent.isComposing || e.keyCode === 229) return;
|
||||
if (e.key === 'Enter' && e.shiftKey) return;
|
||||
@@ -280,11 +272,7 @@ export default function ChatScreen() {
|
||||
Stop
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="btn btn-primary ml-2"
|
||||
onClick={sendNewMessage}
|
||||
disabled={inputMsg.trim().length === 0}
|
||||
>
|
||||
<button className="btn btn-primary ml-2" onClick={sendNewMessage}>
|
||||
Send
|
||||
</button>
|
||||
)}
|
||||
@@ -298,3 +286,43 @@ export default function ChatScreen() {
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export interface OptimizedTextareaValue {
|
||||
value: () => string;
|
||||
setValue: (value: string) => void;
|
||||
focus: () => void;
|
||||
ref: React.RefObject<HTMLTextAreaElement>;
|
||||
}
|
||||
|
||||
// This is a workaround to prevent the textarea from re-rendering when the inner content changes
|
||||
// See https://github.com/ggml-org/llama.cpp/pull/12299
|
||||
function useOptimizedTextarea(initValue: string): OptimizedTextareaValue {
|
||||
const [savedInitValue, setSavedInitValue] = useState<string>(initValue);
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (textareaRef.current && savedInitValue) {
|
||||
textareaRef.current.value = savedInitValue;
|
||||
setSavedInitValue('');
|
||||
}
|
||||
}, [textareaRef, savedInitValue, setSavedInitValue]);
|
||||
|
||||
return {
|
||||
value: () => {
|
||||
return textareaRef.current?.value ?? savedInitValue;
|
||||
},
|
||||
setValue: (value: string) => {
|
||||
if (textareaRef.current) {
|
||||
textareaRef.current.value = value;
|
||||
}
|
||||
},
|
||||
focus: () => {
|
||||
if (textareaRef.current) {
|
||||
// focus and move the cursor to the end
|
||||
textareaRef.current.focus();
|
||||
textareaRef.current.selectionStart = textareaRef.current.value.length;
|
||||
}
|
||||
},
|
||||
ref: textareaRef,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { MessageExtraContext } from './types';
|
||||
import { OptimizedTextareaValue } from '../components/ChatScreen';
|
||||
|
||||
// Extra context when using llama.cpp WebUI from llama-vscode, inside an iframe
|
||||
// Ref: https://github.com/ggml-org/llama.cpp/pull/11940
|
||||
@@ -14,10 +15,7 @@ interface SetTextEvData {
|
||||
* window.postMessage({ command: 'setText', text: 'Spot the syntax error', context: 'def test()\n return 123' }, '*');
|
||||
*/
|
||||
|
||||
export const useVSCodeContext = (
|
||||
inputRef: React.RefObject<HTMLTextAreaElement>,
|
||||
setInputMsg: (text: string) => void
|
||||
) => {
|
||||
export const useVSCodeContext = (textarea: OptimizedTextareaValue) => {
|
||||
const [extraContext, setExtraContext] = useState<MessageExtraContext | null>(
|
||||
null
|
||||
);
|
||||
@@ -27,20 +25,20 @@ export const useVSCodeContext = (
|
||||
const handleMessage = (event: MessageEvent) => {
|
||||
if (event.data?.command === 'setText') {
|
||||
const data: SetTextEvData = event.data;
|
||||
setInputMsg(data?.text);
|
||||
textarea.setValue(data?.text);
|
||||
if (data?.context && data.context.length > 0) {
|
||||
setExtraContext({
|
||||
type: 'context',
|
||||
content: data.context,
|
||||
});
|
||||
}
|
||||
inputRef.current?.focus();
|
||||
textarea.focus();
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('message', handleMessage);
|
||||
return () => window.removeEventListener('message', handleMessage);
|
||||
}, [inputRef, setInputMsg]);
|
||||
}, [textarea]);
|
||||
|
||||
// Add a keydown listener that sends the "escapePressed" message to the parent window
|
||||
useEffect(() => {
|
||||
|
||||
@@ -331,11 +331,11 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
active_seqs.erase(s);
|
||||
for(int i = 0; i < n_seq_dft; i++) {
|
||||
for (int i = 0; i < n_seq_dft; i++) {
|
||||
if (i == s) {
|
||||
continue;
|
||||
}
|
||||
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
||||
if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
||||
// synchronize active status for sequences with the same drafted token
|
||||
drafts[i].active = drafts[i].active && accept;
|
||||
if (!drafts[i].active) {
|
||||
|
||||
@@ -571,6 +571,10 @@ int main(int argc, char ** argv) {
|
||||
model_ttc = llama_init_ttc.model.get();
|
||||
ctx_ttc = llama_init_ttc.context.get();
|
||||
|
||||
if (model_ttc == nullptr || ctx_ttc == nullptr) {
|
||||
return ENOENT;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model_ttc);
|
||||
|
||||
// TODO: refactor in a common struct
|
||||
@@ -586,6 +590,10 @@ int main(int argc, char ** argv) {
|
||||
model_cts = llama_init_cts.model.get();
|
||||
ctx_cts = llama_init_cts.context.get();
|
||||
|
||||
if (model_cts == nullptr || ctx_cts == nullptr) {
|
||||
return ENOENT;
|
||||
}
|
||||
|
||||
std::vector<common_sampler *> smpl(n_parallel);
|
||||
for (int i = 0; i < n_parallel; ++i) {
|
||||
params.sampling.no_perf = (i != 0);
|
||||
|
||||
@@ -186,6 +186,7 @@ option(GGML_OPENMP "ggml: use OpenMP"
|
||||
option(GGML_RPC "ggml: use RPC" OFF)
|
||||
option(GGML_SYCL "ggml: use SYCL" OFF)
|
||||
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
|
||||
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
|
||||
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
|
||||
"ggml: sycl target device")
|
||||
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
|
||||
|
||||
@@ -454,6 +454,7 @@ extern "C" {
|
||||
GGML_OP_RMS_NORM,
|
||||
GGML_OP_RMS_NORM_BACK,
|
||||
GGML_OP_GROUP_NORM,
|
||||
GGML_OP_L2_NORM,
|
||||
|
||||
GGML_OP_MUL_MAT,
|
||||
GGML_OP_MUL_MAT_ID,
|
||||
@@ -502,6 +503,7 @@ extern "C" {
|
||||
GGML_OP_ADD_REL_POS,
|
||||
GGML_OP_RWKV_WKV6,
|
||||
GGML_OP_GATED_LINEAR_ATTN,
|
||||
GGML_OP_RWKV_WKV7,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
@@ -1095,6 +1097,18 @@ extern "C" {
|
||||
int n_groups,
|
||||
float eps);
|
||||
|
||||
// l2 normalize along rows
|
||||
// used in rwkv v7
|
||||
GGML_API struct ggml_tensor * ggml_l2_norm(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps);
|
||||
|
||||
// a - x
|
||||
// b - dy
|
||||
GGML_API struct ggml_tensor * ggml_rms_norm_back(
|
||||
@@ -1890,6 +1904,16 @@ extern "C" {
|
||||
struct ggml_tensor * state,
|
||||
float scale);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * r,
|
||||
struct ggml_tensor * w,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * state);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
|
||||
|
||||
@@ -76,7 +76,11 @@ if (GGML_CCACHE)
|
||||
set(GGML_CCACHE_VARIANT sccache)
|
||||
endif()
|
||||
# TODO: should not be set globally
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
|
||||
if (GGML_SYCL AND GGML_CCACHE_FOUND AND WIN32)
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache compiler_type=icl")
|
||||
else ()
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
|
||||
endif ()
|
||||
set(ENV{CCACHE_SLOPPINESS} time_macros)
|
||||
message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
|
||||
else()
|
||||
@@ -325,6 +329,10 @@ if (CMAKE_SYSTEM_NAME MATCHES "Android")
|
||||
target_link_libraries(ggml-base PRIVATE dl)
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "visionOS")
|
||||
target_compile_definitions(ggml-base PUBLIC _DARWIN_C_SOURCE)
|
||||
endif()
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
foreach (target ggml-base ggml)
|
||||
set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
@@ -287,17 +287,25 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||
elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
|
||||
message(STATUS "PowerPC detected")
|
||||
execute_process(COMMAND bash -c "grep POWER /proc/cpuinfo | head -n 1" OUTPUT_VARIABLE POWER_M)
|
||||
if (${POWER_M} MATCHES "POWER10")
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10)
|
||||
elseif (${POWER_M} MATCHES "POWER9")
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9)
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||
file(READ "/proc/cpuinfo" POWER10_M)
|
||||
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
|
||||
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
|
||||
endif()
|
||||
|
||||
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
|
||||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||
|
||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
|
||||
elseif (EXTRACTED_NUMBER EQUAL 9)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
||||
else()
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64 -mtune=native)
|
||||
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
||||
message(STATUS "loongarch64 detected")
|
||||
@@ -351,9 +359,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.3.0")
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.5.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "060bd2dc64642b091f461cc8dd7426d9")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e")
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#ifdef __ARM_NEON
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||
float sum = 0;
|
||||
svuint8_t m4b = svdup_n_u8(0xf);
|
||||
svint32_t vzero = svdup_n_s32(0);
|
||||
svuint8_t mone = svdup_n_u8(0x30);
|
||||
svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
|
||||
svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d_all = GGML_FP16_TO_FP32(x[i].d);
|
||||
|
||||
const uint8_t * GGML_RESTRICT q6 = x[i].ql;
|
||||
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
const int8_t * GGML_RESTRICT scale = x[i].scales;
|
||||
|
||||
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
||||
const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
|
||||
const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
|
||||
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
|
||||
const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
|
||||
const svint64_t prod = svdup_n_s64(0);
|
||||
int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
|
||||
svdot_s64(prod, q8sums_2, q6scales_2)));
|
||||
int32_t isum = 0;
|
||||
|
||||
switch (vector_length) {
|
||||
case 128:
|
||||
{
|
||||
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
|
||||
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
|
||||
svint32_t isum_tmp = svdup_n_s32(0);
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
|
||||
svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
|
||||
qh += 32;
|
||||
svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
|
||||
svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
|
||||
svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
|
||||
svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
|
||||
q6 += 64;
|
||||
svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
|
||||
svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
||||
svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
||||
svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
||||
q8 += 64;
|
||||
|
||||
q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
|
||||
q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
|
||||
q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
|
||||
q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
|
||||
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
|
||||
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
|
||||
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
||||
|
||||
scale += 4;
|
||||
q8bytes_1 = svld1_s8(pg8_16, q8);
|
||||
q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
||||
q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
||||
q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
||||
q8 += 64;
|
||||
|
||||
q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
|
||||
q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
|
||||
q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
|
||||
q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
|
||||
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
|
||||
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
|
||||
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
||||
scale += 4;
|
||||
}
|
||||
isum += svaddv_s32(pg32_4, isum_tmp);
|
||||
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
||||
}
|
||||
break;
|
||||
case 256:
|
||||
case 512:
|
||||
{
|
||||
const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
|
||||
const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
|
||||
const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
|
||||
svint32_t isum_tmp = svdup_n_s32(0);
|
||||
for (int j = 0; j < QK_K/128; j++) {
|
||||
svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
|
||||
qh += 32;
|
||||
svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
|
||||
svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
|
||||
q6 += 64;
|
||||
svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
|
||||
svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
|
||||
svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
|
||||
svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
|
||||
q8 += 128;
|
||||
q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
|
||||
q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
|
||||
q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
|
||||
q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
|
||||
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
|
||||
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
|
||||
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
|
||||
|
||||
svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
|
||||
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
||||
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
||||
svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
|
||||
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
||||
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
||||
svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
|
||||
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
||||
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
||||
svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
|
||||
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
||||
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
||||
svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
|
||||
svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
|
||||
svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
|
||||
svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
|
||||
|
||||
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
|
||||
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
|
||||
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
|
||||
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
|
||||
scale += 8;
|
||||
}
|
||||
isum += svaddv_s32(pg32_8, isum_tmp);
|
||||
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
assert(false && "Unsupported vector length");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
*s = sum;
|
||||
|
||||
#elif __ARM_NEON
|
||||
float sum = 0;
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
||||
|
||||
+284
-29
@@ -3110,17 +3110,17 @@ static void ggml_compute_forward_dup_same_cont(
|
||||
const int ith = params->ith; // thread index
|
||||
const int nth = params->nth; // number of threads
|
||||
|
||||
// parallelize by elements
|
||||
const int ne = ggml_nelements(dst);
|
||||
const int dr = (ne + nth - 1) / nth;
|
||||
const int ie0 = dr * ith;
|
||||
const int ie1 = MIN(ie0 + dr, ne);
|
||||
// parallelize by blocks
|
||||
const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
|
||||
const int dr = (nk + nth - 1) / nth;
|
||||
const int k0 = dr * ith;
|
||||
const int k1 = MIN(k0 + dr, nk);
|
||||
|
||||
if (ie0 < ie1) {
|
||||
if (k0 < k1) {
|
||||
memcpy(
|
||||
((char *) dst->data + ie0*nb0),
|
||||
((char *) src0->data + ie0*nb0),
|
||||
(ie1 - ie0) * nb0);
|
||||
((char *) dst->data + k0*nb0),
|
||||
((char *) src0->data + k0*nb0),
|
||||
(k1 - k0) * nb0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4055,7 +4055,6 @@ static void ggml_compute_forward_dup_f32(
|
||||
static void ggml_compute_forward_dup_bytes(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
||||
@@ -4069,10 +4068,10 @@ static void ggml_compute_forward_dup_bytes(
|
||||
}
|
||||
|
||||
const size_t type_size = ggml_type_size(src0->type);
|
||||
|
||||
const int ith = params->ith; // thread index
|
||||
const int nth = params->nth; // number of threads
|
||||
|
||||
|
||||
// parallelize by rows
|
||||
const int nr = ne01;
|
||||
// number of rows per thread
|
||||
@@ -4082,10 +4081,10 @@ static void ggml_compute_forward_dup_bytes(
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
if (src0->type == dst->type &&
|
||||
ne00 == ne0 &&
|
||||
ggml_are_same_shape(src0, dst) &&
|
||||
nb00 == type_size && nb0 == type_size) {
|
||||
// copy by rows
|
||||
const size_t rs = ne00 * type_size;
|
||||
const size_t rs = ggml_row_size(src0->type, ne00);
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
@@ -4140,17 +4139,20 @@ static void ggml_compute_forward_dup_bytes(
|
||||
}
|
||||
|
||||
// dst counters
|
||||
|
||||
int64_t i10 = 0;
|
||||
int64_t k10 = 0;
|
||||
int64_t i11 = 0;
|
||||
int64_t i12 = 0;
|
||||
int64_t i13 = 0;
|
||||
|
||||
// number of blocks in a row
|
||||
const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
|
||||
const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
i10 += ne00 * ir0;
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
k10 += nk00 * ir0;
|
||||
while (k10 >= nk0) {
|
||||
k10 -= nk0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
@@ -4162,14 +4164,14 @@ static void ggml_compute_forward_dup_bytes(
|
||||
}
|
||||
}
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
for (int64_t k00 = 0; k00 < nk00; k00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
memcpy(dst_ptr, src0_ptr, type_size);
|
||||
|
||||
if (++i10 == ne0) {
|
||||
i10 = 0;
|
||||
if (++k10 == nk0) {
|
||||
k10 = 0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
@@ -4182,9 +4184,9 @@ static void ggml_compute_forward_dup_bytes(
|
||||
}
|
||||
}
|
||||
}
|
||||
i10 += ne00 * (ne01 - ir1);
|
||||
while (i10 >= ne0) {
|
||||
i10 -= ne0;
|
||||
k10 += nk00 * (ne01 - ir1);
|
||||
while (k10 >= nk0) {
|
||||
k10 -= nk0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
@@ -8548,6 +8550,69 @@ static void ggml_compute_forward_group_norm(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_l2_norm
|
||||
|
||||
static void ggml_compute_forward_l2_norm_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
}
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
|
||||
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_l2_norm(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_l2_norm_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_mul_mat
|
||||
|
||||
static void ggml_compute_forward_mul_mat_one_chunk(
|
||||
@@ -13604,6 +13669,184 @@ static void ggml_compute_forward_gla(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_rwkv_wkv7
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
const int64_t T = dst->src[1]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t HEADS = dst->src[1]->ne[1];
|
||||
const int64_t n_seqs = dst->src[6]->ne[1];
|
||||
const int64_t head_size = C / HEADS;
|
||||
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * state = ((float *) dst->data) + C * T;
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int h_start = (HEADS * ith) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
|
||||
float * r = (float *) dst->src[0]->data;
|
||||
float * w = (float *) dst->src[1]->data;
|
||||
float * k = (float *) dst->src[2]->data;
|
||||
float * v = (float *) dst->src[3]->data;
|
||||
float * a = (float *) dst->src[4]->data;
|
||||
float * b = (float *) dst->src[5]->data;
|
||||
|
||||
int64_t t_stride = HEADS * head_size; // Same to C
|
||||
|
||||
int64_t h_stride = C / HEADS;
|
||||
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
|
||||
int64_t h_stride_2d = head_size * head_size;
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
int64_t t_offset = t * t_stride;
|
||||
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
|
||||
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
int64_t h_offset = h * h_stride;
|
||||
int64_t t_h_offset = t_offset + h_offset;
|
||||
int64_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t ii = 0; ii < head_size; ii++) {
|
||||
int64_t t_h_i_offset = t_h_offset + ii;
|
||||
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
|
||||
|
||||
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
|
||||
|
||||
float sa = 0;
|
||||
{
|
||||
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||||
GGML_F32_VEC ax[GGML_F32_ARR];
|
||||
GGML_F32_VEC ay[GGML_F32_ARR];
|
||||
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
|
||||
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
|
||||
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
|
||||
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
|
||||
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
|
||||
}
|
||||
}
|
||||
GGML_F32_VEC_REDUCE(sa, sum);
|
||||
}
|
||||
|
||||
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
|
||||
|
||||
int64_t j = 0;
|
||||
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||||
for (; j < head_size; j += GGML_F32_STEP) {
|
||||
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
|
||||
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
|
||||
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
|
||||
|
||||
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
|
||||
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
|
||||
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
|
||||
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
|
||||
|
||||
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
|
||||
|
||||
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
|
||||
// kv + s * decay + sa * b
|
||||
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
|
||||
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
|
||||
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
|
||||
|
||||
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
|
||||
}
|
||||
}
|
||||
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
|
||||
|
||||
// There shouldn't be left-overs though.
|
||||
for (; j < head_size; j++) {
|
||||
int64_t t_h_j_offset = t_h_offset + j;
|
||||
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
|
||||
float r_val = r[t_h_j_offset];
|
||||
float w_val = w[t_h_j_offset];
|
||||
float k_val = k[t_h_j_offset];
|
||||
float b_val = b[t_h_j_offset];
|
||||
float kv_val = v[t_h_i_offset] * k_val;
|
||||
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
|
||||
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
int64_t t_offset = t * t_stride;
|
||||
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
|
||||
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
int64_t h_offset = h * h_stride;
|
||||
int64_t t_h_offset = t_offset + h_offset;
|
||||
int64_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t i = 0; i < head_size; i++) {
|
||||
int64_t t_h_i_offset = t_h_offset + i;
|
||||
int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||
|
||||
float v_val = v[t_h_i_offset];
|
||||
|
||||
float sa = 0, result = 0;
|
||||
for (int64_t j = 0; j < head_size; j++) {
|
||||
sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < head_size; j++) {
|
||||
int64_t t_h_j_offset = t_h_offset + j;
|
||||
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
|
||||
float r_val = r[t_h_j_offset];
|
||||
float w_val = w[t_h_j_offset];
|
||||
float k_val = k[t_h_j_offset];
|
||||
float b_val = b[t_h_j_offset];
|
||||
float kv_val = v_val * k_val;
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
|
||||
result += state_cur[h_2d_i_j_offset] * r_val;
|
||||
}
|
||||
dst_data[t_h_i_offset] = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_unary
|
||||
|
||||
static void ggml_compute_forward_map_unary_f32(
|
||||
@@ -14067,7 +14310,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
}
|
||||
|
||||
// extra_buffer op?
|
||||
if (ggml_cpu_extra_compute_forward(params, tensor)) return;
|
||||
if (ggml_cpu_extra_compute_forward(params, tensor)) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_DUP:
|
||||
@@ -14170,6 +14415,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_group_norm(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_L2_NORM:
|
||||
{
|
||||
ggml_compute_forward_l2_norm(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
ggml_compute_forward_mul_mat(params, tensor);
|
||||
@@ -14357,6 +14606,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_gla(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_UNARY:
|
||||
{
|
||||
ggml_unary_op_f32_t fun;
|
||||
@@ -14582,6 +14835,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_MUL_MAT:
|
||||
@@ -14648,14 +14902,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_WIN_PART:
|
||||
case GGML_OP_WIN_UNPART:
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
|
||||
@@ -51,11 +51,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .require_aligned_m_idx = */ true,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
@@ -100,7 +99,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .require_aligned_m_idx = */ false,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
@@ -144,7 +142,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .require_aligned_m_idx = */ false,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
@@ -189,7 +186,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .require_aligned_m_idx = */ false,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
@@ -233,7 +229,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .require_aligned_m_idx = */ false,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
|
||||
@@ -40,7 +40,6 @@ struct lhs_packing_info {
|
||||
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
||||
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
||||
size_t lhs_stride, void* lhs_packed);
|
||||
bool require_aligned_m_idx;
|
||||
};
|
||||
|
||||
struct rhs_packing_info {
|
||||
|
||||
@@ -124,8 +124,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
// Calculate number of columns to be processed per thread
|
||||
const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
|
||||
const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
|
||||
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
|
||||
const size_t m_start = ith * num_m_per_thread;
|
||||
size_t m_to_process = num_m_per_thread;
|
||||
if ((m_start + m_to_process) > m) {
|
||||
@@ -135,11 +134,11 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
if(m_start < m) {
|
||||
// Transform LHS
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
|
||||
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
@@ -41,14 +41,17 @@
|
||||
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
|
||||
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
|
||||
|
||||
#define GGML_CUDA_CC_PASCAL 600
|
||||
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
||||
#define GGML_CUDA_CC_VOLTA 700
|
||||
#define GGML_CUDA_CC_TURING 750
|
||||
#define GGML_CUDA_CC_AMPERE 800
|
||||
#define GGML_CUDA_CC_ADA_LOVELACE 890
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
||||
#define GGML_CUDA_CC_PASCAL 600
|
||||
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
||||
#define GGML_CUDA_CC_VOLTA 700
|
||||
#define GGML_CUDA_CC_TURING 750
|
||||
#define GGML_CUDA_CC_AMPERE 800
|
||||
#define GGML_CUDA_CC_ADA_LOVELACE 890
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
||||
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
|
||||
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
|
||||
|
||||
// AMD
|
||||
// GCN/CNDA, wave size is 64
|
||||
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
|
||||
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
|
||||
@@ -70,8 +73,17 @@
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
#define GGML_CUDA_CC_QY1 210
|
||||
#define GGML_CUDA_CC_QY2 220
|
||||
// Moore Threads
|
||||
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
|
||||
|
||||
#define GGML_CUDA_CC_QY1 (GGML_MUSA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
|
||||
#define GGML_CUDA_CC_QY2 (GGML_MUSA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
|
||||
#define GGML_CUDA_CC_NG (GGML_MUSA_CC_OFFSET_MTHREADS + 0x310) // TBD
|
||||
|
||||
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
|
||||
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
|
||||
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NEXT)
|
||||
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
|
||||
|
||||
#ifdef __CUDA_ARCH_LIST__
|
||||
constexpr bool ggml_cuda_has_arch_impl(int) {
|
||||
@@ -209,21 +221,21 @@ typedef float2 dfloat2;
|
||||
#define CP_ASYNC_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
|
||||
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
|
||||
#define FLASH_ATTN_AVAILABLE
|
||||
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
|
||||
|
||||
static bool fp16_available(const int cc) {
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
|
||||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return fp16_available(cc) && cc != 610;
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
static bool fast_fp16_hardware_available(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
|
||||
}
|
||||
|
||||
// Any FP16 tensor core instructions are available for ggml code.
|
||||
@@ -231,20 +243,20 @@ static bool fp16_mma_available(const int cc) {
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
return false;
|
||||
#else
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
|
||||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
|
||||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
static bool fp16_mma_hardware_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
|
||||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
|
||||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
|
||||
}
|
||||
|
||||
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
||||
static bool new_mma_available(const int cc) {
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
|
||||
}
|
||||
|
||||
static bool cp_async_available(const int cc) {
|
||||
@@ -678,7 +690,7 @@ struct ggml_tensor_extra_gpu {
|
||||
};
|
||||
|
||||
|
||||
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
|
||||
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
|
||||
#define USE_CUDA_GRAPH
|
||||
#endif
|
||||
|
||||
|
||||
@@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
*dst = dst_val / rowsum;
|
||||
}
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
template<int D> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_combine_results(
|
||||
const float * __restrict__ VKQ_parts,
|
||||
const float2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst) {
|
||||
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
|
||||
dst += D * gridDim.y*blockIdx.x;
|
||||
float * __restrict__ dst,
|
||||
const int parallel_blocks) {
|
||||
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
|
||||
dst += D * gridDim.z*blockIdx.x;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
__shared__ float2 meta[parallel_blocks];
|
||||
extern __shared__ float2 meta[];
|
||||
if (tid < 2*parallel_blocks) {
|
||||
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
|
||||
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
float kqmax = meta[0].x;
|
||||
#pragma unroll
|
||||
for (int l = 1; l < parallel_blocks; ++l) {
|
||||
kqmax = max(kqmax, meta[l].x);
|
||||
}
|
||||
|
||||
float VKQ_numerator = 0.0f;
|
||||
float VKQ_denominator = 0.0f;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < parallel_blocks; ++l) {
|
||||
const float diff = meta[l].x - kqmax;
|
||||
const float KQ_max_scale = expf(diff);
|
||||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
}
|
||||
|
||||
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
static void on_no_fattn_vec_case(const int D) {
|
||||
@@ -671,12 +670,10 @@ static void on_no_fattn_vec_case(const int D) {
|
||||
}
|
||||
}
|
||||
|
||||
// parallel_blocks == 0 is stream-k decomposition
|
||||
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
|
||||
template <int D, int ncols1, int ncols2, int KQ_stride>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
||||
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
|
||||
const int warp_size = WARP_SIZE
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
@@ -748,12 +745,14 @@ void launch_fattn(
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
int parallel_blocks = 1;
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
dim3 blocks_num;
|
||||
if (parallel_blocks == 0) {
|
||||
if (stream_k) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int max_blocks = 2*nsm;
|
||||
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
||||
@@ -769,9 +768,43 @@ void launch_fattn(
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
|
||||
} else {
|
||||
blocks_num.x = parallel_blocks*ntiles_x;
|
||||
blocks_num.y = Q->ne[2];
|
||||
blocks_num.z = Q->ne[3];
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
|
||||
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
||||
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
|
||||
// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
|
||||
// Test whether parallel_blocks can be set to a higher value for better efficiency.
|
||||
const int blocks_per_wave = nsm * max_blocks_per_sm;
|
||||
int nwaves_best = 0;
|
||||
int efficiency_percent_best = 0;
|
||||
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
|
||||
const int nblocks_total = ntiles_total * parallel_blocks_test;
|
||||
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
|
||||
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
|
||||
|
||||
// Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
|
||||
if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (efficiency_percent > efficiency_percent_best) {
|
||||
nwaves_best = nwaves;
|
||||
efficiency_percent_best = efficiency_percent;
|
||||
parallel_blocks = parallel_blocks_test;
|
||||
}
|
||||
}
|
||||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = Q->ne[2]*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
@@ -803,7 +836,7 @@ void launch_fattn(
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
@@ -815,7 +848,7 @@ void launch_fattn(
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if constexpr (parallel_blocks == 0) {
|
||||
if (stream_k) {
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
@@ -824,13 +857,14 @@ void launch_fattn(
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
}
|
||||
} else if constexpr (parallel_blocks > 1) {
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
||||
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
||||
|
||||
flash_attn_combine_results<D, parallel_blocks>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
flash_attn_combine_results<D>
|
||||
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
@@ -970,7 +970,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
|
||||
}
|
||||
|
||||
launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
|
||||
launch_fattn<D, ncols1, ncols2, KQ_per_iter>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#define FATTN_KQ_STRIDE_TILE_F16 64
|
||||
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
|
||||
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
|
||||
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
half kqmax_new[ncols/nwarps];
|
||||
@@ -271,16 +269,16 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
|
||||
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
if (parallel_blocks == 1) {
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= __half2half2(kqsum_j);
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -288,7 +286,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
template <int cols_per_block, bool use_logit_softcap>
|
||||
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
@@ -296,15 +294,17 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
@@ -324,37 +324,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
constexpr int cols_per_block = 16;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#define FATTN_KQ_STRIDE_TILE_F32 32
|
||||
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
|
||||
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
|
||||
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
||||
@@ -103,8 +102,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
|
||||
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float kqmax_new[ncols/nwarps];
|
||||
@@ -269,17 +267,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
if (parallel_blocks == 1) {
|
||||
if (gridDim.y == 1) {
|
||||
dst_val.x /= kqsum_j;
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -287,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
template <int cols_per_block, bool use_logit_softcap>
|
||||
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
@@ -295,15 +293,17 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
@@ -320,37 +320,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
constexpr int cols_per_block = 16;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -55,17 +55,16 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb02* blockIdx.y + nb01*ic0;
|
||||
K += nb12*(blockIdx.y / gqa_ratio);
|
||||
V += nb22*(blockIdx.y / gqa_ratio);
|
||||
Q += nb02* blockIdx.z + nb01*ic0;
|
||||
K += nb12*(blockIdx.z / gqa_ratio);
|
||||
V += nb22*(blockIdx.z / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
@@ -172,8 +171,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
|
||||
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
||||
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
||||
@@ -283,29 +281,29 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
|
||||
|
||||
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
|
||||
if (parallel_blocks == 1) {
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
@@ -325,65 +323,48 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
constexpr int cols_per_block = 1;
|
||||
constexpr int parallel_blocks = 4;
|
||||
constexpr int cols_per_block = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 2) {
|
||||
constexpr int cols_per_block = 2;
|
||||
constexpr int parallel_blocks = 4;
|
||||
constexpr int cols_per_block = 2;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 4) {
|
||||
constexpr int cols_per_block = 4;
|
||||
constexpr int parallel_blocks = 4;
|
||||
constexpr int cols_per_block = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 1;
|
||||
constexpr int cols_per_block = 8;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -55,16 +55,15 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb02* blockIdx.y + nb01*ic0;
|
||||
K += nb12*(blockIdx.y / gqa_ratio);
|
||||
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
|
||||
Q += nb02* blockIdx.z + nb01*ic0;
|
||||
K += nb12*(blockIdx.z / gqa_ratio);
|
||||
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = D / WARP_SIZE;
|
||||
@@ -167,8 +166,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
|
||||
float VKQ[ncols] = {0.0f};
|
||||
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float kqmax_new_arr[ncols];
|
||||
@@ -268,29 +266,29 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
||||
|
||||
float dst_val = VKQ[j_VKQ];
|
||||
if (parallel_blocks == 1) {
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
||||
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
@@ -307,65 +305,48 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
constexpr int cols_per_block = 1;
|
||||
constexpr int parallel_blocks = 4;
|
||||
constexpr int cols_per_block = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 2) {
|
||||
constexpr int cols_per_block = 2;
|
||||
constexpr int parallel_blocks = 4;
|
||||
constexpr int cols_per_block = 2;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 4) {
|
||||
constexpr int cols_per_block = 4;
|
||||
constexpr int parallel_blocks = 4;
|
||||
constexpr int cols_per_block = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 1;
|
||||
constexpr int cols_per_block = 8;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ namespace wmma = rocwmma;
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
|
||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
|
||||
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
@@ -67,8 +67,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
|
||||
|
||||
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
||||
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
||||
@@ -91,16 +90,16 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
||||
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
const half2 slope2 = make_half2(slopef, slopef);
|
||||
|
||||
@@ -176,7 +175,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
__syncthreads();
|
||||
|
||||
// Iterate over ne11 == previous tokens:
|
||||
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
|
||||
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
|
||||
// Calculate tile of KQ:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
|
||||
@@ -395,7 +394,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
|
||||
float KQ_rowsum_j;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
@@ -411,13 +410,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
break;
|
||||
}
|
||||
float dst_val = VKQ[j_VKQ*D_padded + i];
|
||||
if (parallel_blocks == 1) {
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= KQ_rowsum_j;
|
||||
}
|
||||
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
|
||||
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks == 1 || threadIdx.x != 0) {
|
||||
if (gridDim.y == 1 || threadIdx.x != 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -428,7 +427,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
||||
}
|
||||
dst_meta_val.y = KQ_rowsum_j;
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
@@ -462,60 +461,26 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
||||
template <int D, int cols_per_block, typename KQ_acc_t>
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
constexpr int nwarps = 4;
|
||||
|
||||
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
|
||||
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (4*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 4;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
|
||||
return;
|
||||
}
|
||||
if (2*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 2;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
|
||||
return;
|
||||
}
|
||||
constexpr int parallel_blocks = 1;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
+10
-10
@@ -253,7 +253,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
|
||||
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
if (fp16_mma_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
@@ -281,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
|
||||
if (!fp16_mma_available(cc)) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] <= 8) {
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
@@ -296,17 +296,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
return;
|
||||
}
|
||||
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
|
||||
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
|
||||
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
||||
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
||||
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
|
||||
const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
|
||||
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
return;
|
||||
} else if(Q->ne[0] <= 128) {
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv6.cuh"
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
@@ -262,9 +262,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
|
||||
device_vmm ? "yes" : "no", prop.warpSize);
|
||||
#elif defined(GGML_USE_MUSA)
|
||||
// TODO: refine the .cc to reflect MUSA's actual CC capabilities
|
||||
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
|
||||
info.devices[id].warp_size = 32;
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
|
||||
info.devices[id].cc += prop.minor * 0x10;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
#else
|
||||
@@ -1186,11 +1188,11 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
// ldc == nrows of the matrix that cuBLAS writes into
|
||||
int64_t ldc = id == ctx.device ? ne0 : row_diff;
|
||||
|
||||
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
|
||||
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
|
||||
|
||||
if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
|
||||
if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
|
||||
if (src0->type != GGML_TYPE_F16) {
|
||||
@@ -1214,7 +1216,7 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||||
|
||||
if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
CUBLAS_CHECK(
|
||||
@@ -2196,6 +2198,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_GROUP_NORM:
|
||||
ggml_cuda_op_group_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_cuda_op_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
ggml_cuda_op_concat(ctx, dst);
|
||||
break;
|
||||
@@ -2304,6 +2309,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_cuda_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_cuda_op_rwkv_wkv7(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||
break;
|
||||
@@ -2610,13 +2618,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
|
||||
|
||||
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
hipGraphNode_t errorNode;
|
||||
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
||||
#else
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
#endif
|
||||
#else
|
||||
cudaGraphNode_t errorNode;
|
||||
cudaGraphExecUpdateResult result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
|
||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|
||||
@@ -3159,6 +3169,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
break;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
|
||||
@@ -3213,11 +3224,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
return false;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -27,8 +27,8 @@ void ggml_cuda_op_mul_mat_q(
|
||||
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
|
||||
// Also its fixup needs to allocate a temporary buffer in the memory pool.
|
||||
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
|
||||
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
|
||||
cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
|
||||
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
|
||||
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
|
||||
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
|
||||
|
||||
switch (src0->type) {
|
||||
@@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return true;
|
||||
#endif //GGML_CUDA_FORCE_MMQ
|
||||
|
||||
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
|
||||
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ struct tile_x_sizes {
|
||||
|
||||
static int get_mmq_x_max_host(const int cc) {
|
||||
return new_mma_available(cc) ? 128 :
|
||||
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
|
||||
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
|
||||
#ifdef GGML_CUDA_FORCE_MMQ
|
||||
128 : 64;
|
||||
#else
|
||||
@@ -123,8 +123,8 @@ static constexpr __device__ int get_mmq_x_max_device() {
|
||||
}
|
||||
|
||||
static int get_mmq_y_host(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
|
||||
(ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
|
||||
return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
|
||||
((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_mmq_y_device() {
|
||||
@@ -2772,14 +2772,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
|
||||
const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
if (!shmem_limit_raised[id]) {
|
||||
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
|
||||
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
|
||||
shmem_limit_raised[id] = true;
|
||||
}
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
|
||||
const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
|
||||
const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
|
||||
@@ -2832,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||
const int mmq_x_max = get_mmq_x_max_host(cc);
|
||||
const int mmq_y = get_mmq_y_host(cc);
|
||||
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
|
||||
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
|
||||
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
|
||||
|
||||
int mmq_x_best = 0;
|
||||
int nparts_best = INT_MAX;
|
||||
|
||||
@@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// template <int block_size>
|
||||
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
// const int tid = threadIdx.x;
|
||||
|
||||
// float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
// for (int col = tid; col < ncols; col += block_size) {
|
||||
// const float xi = x[row*ncols + col];
|
||||
// tmp += xi * xi;
|
||||
// }
|
||||
|
||||
// // sum up partial sums
|
||||
// tmp = warp_reduce_sum(tmp);
|
||||
// if (block_size > WARP_SIZE) {
|
||||
// __shared__ float s_sum[32];
|
||||
// int warp_id = threadIdx.x / WARP_SIZE;
|
||||
// int lane_id = threadIdx.x % WARP_SIZE;
|
||||
// if (lane_id == 0) {
|
||||
// s_sum[warp_id] = tmp;
|
||||
// }
|
||||
// __syncthreads();
|
||||
// tmp = s_sum[lane_id];
|
||||
// tmp = warp_reduce_sum(tmp);
|
||||
// }
|
||||
|
||||
// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
||||
// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
||||
|
||||
// for (int col = tid; col < ncols; col += block_size) {
|
||||
// dst[row*ncols + col] = scale * x[row*ncols + col];
|
||||
// }
|
||||
// }
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void l2_norm_f32(
|
||||
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps) {
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
|
||||
const int row = blockIdx.x;
|
||||
const int channel = blockIdx.y;
|
||||
const int sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if constexpr (block_size > WARP_SIZE) {
|
||||
static_assert(block_size == 1024, "unexpected block_size");
|
||||
__shared__ float s_sum[32];
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
tmp = s_sum[lane_id];
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
||||
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale * x[col];
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
@@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
@@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
|
||||
|
||||
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
|
||||
l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
|
||||
}
|
||||
|
||||
@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
Vendored
+2
-1
@@ -112,7 +112,7 @@
|
||||
#define cudaGraphExecDestroy hipGraphExecDestroy
|
||||
#define cudaGraphLaunch hipGraphLaunch
|
||||
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
||||
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
|
||||
#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
|
||||
#define cudaGraphNodeType hipGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
||||
#define cudaGraphInstantiate hipGraphInstantiate
|
||||
@@ -129,6 +129,7 @@
|
||||
#define cudaGraph_t hipGraph_t
|
||||
#define cudaStream_t hipStream_t
|
||||
#define cudaSuccess hipSuccess
|
||||
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
|
||||
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
||||
|
||||
Vendored
+3
-1
@@ -119,7 +119,7 @@
|
||||
#define cudaGraphExecDestroy musaGraphExecDestroy
|
||||
#define cudaGraphExec_t musaGraphExec_t
|
||||
#define cudaGraphExecUpdate musaGraphExecUpdate
|
||||
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
||||
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
|
||||
#define cudaGraphGetNodes musaGraphGetNodes
|
||||
#define cudaGraphInstantiate musaGraphInstantiate
|
||||
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
||||
@@ -132,6 +132,8 @@
|
||||
#define cudaGraph_t musaGraph_t
|
||||
#define cudaKernelNodeParams musaKernelNodeParams
|
||||
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
||||
#define cudaStreamBeginCapture musaStreamBeginCapture
|
||||
#define cudaStreamEndCapture musaStreamEndCapture
|
||||
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
|
||||
|
||||
typedef mt_bfloat16 nv_bfloat16;
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
#include "common.cuh"
|
||||
#include "wkv.cuh"
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
__syncthreads();
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
__syncthreads();
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& tf = (float4&)(_tf[j]);
|
||||
const float4& td = (float4&)(_td[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
y += r.x * (tf.x * kv.x + s.x);
|
||||
y += r.y * (tf.y * kv.y + s.y);
|
||||
y += r.z * (tf.z * kv.z + s.z);
|
||||
y += r.w * (tf.w * kv.w + s.w);
|
||||
|
||||
s.x = s.x * td.x + kv.x;
|
||||
s.y = s.y * td.y + kv.y;
|
||||
s.z = s.z * td.z + kv.z;
|
||||
s.w = s.w * td.w + kv.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
|
||||
|
||||
#ifndef GGML_USE_MUSA
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
|
||||
}
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
__syncthreads();
|
||||
|
||||
float sa = 0;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4)
|
||||
{
|
||||
const float4& a = (float4&)(_a[j]);
|
||||
const float4& s = (float4&)(state[j]);
|
||||
sa += a.x * s.x;
|
||||
sa += a.y * s.y;
|
||||
sa += a.z * s.z;
|
||||
sa += a.w * s.w;
|
||||
}
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& w = (float4&)(_w[j]);
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& b = (float4&)(_b[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
s.x = s.x * w.x + kv.x + sa * b.x;
|
||||
s.y = s.y * w.y + kv.y + sa * b.y;
|
||||
s.z = s.z * w.z + kv.z + sa * b.z;
|
||||
s.w = s.w * w.w + kv.w + sa * b.w;
|
||||
|
||||
y += s.x * r.x;
|
||||
y += s.y * r.y;
|
||||
y += s.z * r.z;
|
||||
y += s.w * r.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
const float * tf_d = (const float *)dst->src[3]->data;
|
||||
const float * td_d = (const float *)dst->src[4]->data;
|
||||
const float * s_d = (const float *)dst->src[5]->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
|
||||
|
||||
if (C / H == CUDA_WKV_BLOCK_SIZE) {
|
||||
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
} else {
|
||||
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * r_d = (const float *)dst->src[0]->data;
|
||||
const float * w_d = (const float *)dst->src[1]->data;
|
||||
const float * k_d = (const float *)dst->src[2]->data;
|
||||
const float * v_d = (const float *)dst->src[3]->data;
|
||||
const float * a_d = (const float *)dst->src[4]->data;
|
||||
const float * b_d = (const float *)dst->src[5]->data;
|
||||
const float * s_d = (const float *)dst->src[6]->data;
|
||||
|
||||
const int64_t B = dst->src[6]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
|
||||
|
||||
if (C / H == CUDA_WKV_BLOCK_SIZE) {
|
||||
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
|
||||
} else {
|
||||
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
|
||||
}
|
||||
}
|
||||
@@ -3,3 +3,5 @@
|
||||
#define CUDA_WKV_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -1,89 +0,0 @@
|
||||
#include "common.cuh"
|
||||
#include "wkv6.cuh"
|
||||
|
||||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = CUDA_WKV_BLOCK_SIZE;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
__syncthreads();
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
__syncthreads();
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& tf = (float4&)(_tf[j]);
|
||||
const float4& td = (float4&)(_td[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
y += r.x * (tf.x * kv.x + s.x);
|
||||
y += r.y * (tf.y * kv.y + s.y);
|
||||
y += r.z * (tf.z * kv.z + s.z);
|
||||
y += r.w * (tf.w * kv.w + s.w);
|
||||
|
||||
s.x = s.x * td.x + kv.x;
|
||||
s.y = s.y * td.y + kv.y;
|
||||
s.z = s.z * td.z + kv.z;
|
||||
s.w = s.w * td.w + kv.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
const float * tf_d = (const float *)dst->src[3]->data;
|
||||
const float * td_d = (const float *)dst->src[4]->data;
|
||||
const float * s_d = (const float *)dst->src[5]->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
|
||||
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
}
|
||||
@@ -285,6 +285,13 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_rms_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne00_4;
|
||||
uint64_t nb01;
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
||||
@@ -184,10 +184,13 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
||||
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
||||
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
||||
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
||||
@@ -810,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
||||
@@ -1251,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
case GGML_OP_ARGMAX:
|
||||
return true;
|
||||
@@ -1288,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
@@ -2216,6 +2225,83 @@ static void ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == 64);
|
||||
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
|
||||
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
[encoder setBytes:&B length:sizeof(B) atIndex:7];
|
||||
[encoder setBytes:&T length:sizeof(T) atIndex:8];
|
||||
[encoder setBytes:&C length:sizeof(C) atIndex:9];
|
||||
[encoder setBytes:&H length:sizeof(H) atIndex:10];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
const int64_t B = dst->src[6]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == 64);
|
||||
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
size_t offs_src6 = 0;
|
||||
|
||||
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
||||
id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
||||
|
||||
[encoder setBytes:&B length:sizeof(B) atIndex:8];
|
||||
[encoder setBytes:&T length:sizeof(T) atIndex:9];
|
||||
[encoder setBytes:&C length:sizeof(C) atIndex:10];
|
||||
[encoder setBytes:&H length:sizeof(H) atIndex:11];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
@@ -3122,6 +3208,42 @@ static void ggml_metal_encode_node(
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_L2_NORM:
|
||||
{
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
ggml_metal_kargs_l2_norm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne00_4 =*/ ne00/4,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
|
||||
@@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rwkv_wkv6_f32(
|
||||
device const float * k,
|
||||
device const float * v,
|
||||
device const float * r,
|
||||
device const float * tf,
|
||||
device const float * td,
|
||||
device const float * state_in,
|
||||
device float * dst,
|
||||
constant uint & B,
|
||||
constant uint & T,
|
||||
constant uint & C,
|
||||
constant uint & H,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint head_size = 64; // TODO: support head_size = 128
|
||||
const uint batch_id = tgpig.x / H;
|
||||
const uint head_id = tgpig.x % H;
|
||||
const uint tid = tpitg.x;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
threadgroup float _k[head_size];
|
||||
threadgroup float _r[head_size];
|
||||
threadgroup float _tf[head_size];
|
||||
threadgroup float _td[head_size];
|
||||
|
||||
float state[head_size];
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_tf[tid] = tf[head_id * head_size + tid];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float v_val = v[t];
|
||||
float y = 0.0;
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
float4 kv = k_vec * v_val;
|
||||
|
||||
float4 temp = tf_vec * kv + s_vec;
|
||||
y += dot(r_vec, temp);
|
||||
|
||||
s_vec = s_vec * td_vec + kv;
|
||||
state[j] = s_vec[0];
|
||||
state[j+1] = s_vec[1];
|
||||
state[j+2] = s_vec[2];
|
||||
state[j+3] = s_vec[3];
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rwkv_wkv7_f32(
|
||||
device const float * r,
|
||||
device const float * w,
|
||||
device const float * k,
|
||||
device const float * v,
|
||||
device const float * a,
|
||||
device const float * b,
|
||||
device const float * state_in,
|
||||
device float * dst,
|
||||
constant uint & B,
|
||||
constant uint & T,
|
||||
constant uint & C,
|
||||
constant uint & H,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint head_size = 64; // TODO: support head_size = 128
|
||||
const uint batch_id = tgpig.x / H;
|
||||
const uint head_id = tgpig.x % H;
|
||||
const uint tid = tpitg.x;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
threadgroup float _r[head_size];
|
||||
threadgroup float _w[head_size];
|
||||
threadgroup float _k[head_size];
|
||||
threadgroup float _a[head_size];
|
||||
threadgroup float _b[head_size];
|
||||
|
||||
float state[head_size];
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i];
|
||||
}
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float v_val = v[t];
|
||||
float y = 0.0, sa = 0.0;
|
||||
|
||||
float4 sa_vec(0.0);
|
||||
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
sa_vec += a_vec * s_vec;
|
||||
}
|
||||
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
float4 kv = k_vec * v_val;
|
||||
|
||||
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
||||
y += dot(s_vec, r_vec);
|
||||
|
||||
state[j] = s_vec[0];
|
||||
state[j+1] = s_vec[1];
|
||||
state[j+2] = s_vec[2];
|
||||
state[j+3] = s_vec[3];
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_argmax(
|
||||
device const void * x,
|
||||
device int32_t * dst,
|
||||
@@ -1463,6 +1641,49 @@ kernel void kernel_rms_norm(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_l2_norm(
|
||||
constant ggml_metal_kargs_l2_norm & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
||||
|
||||
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
||||
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
|
||||
add_compile_definitions(GGML_USE_MUSA)
|
||||
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
|
||||
|
||||
if (GGML_CUDA_GRAPHS)
|
||||
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_FORCE_MMQ)
|
||||
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
|
||||
endif()
|
||||
|
||||
@@ -25,124 +25,46 @@ endif ()
|
||||
if (GGML_OPENCL_EMBED_KERNELS)
|
||||
add_compile_definitions(GGML_OPENCL_EMBED_KERNELS)
|
||||
|
||||
set(OPENCL_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl.cl.h")
|
||||
set(OPENCL_MM_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mm.cl.h")
|
||||
set(OPENCL_CVT_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_cvt.cl.h")
|
||||
set(EMBED_KERNEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py")
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
|
||||
|
||||
set(OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle.cl.h")
|
||||
set(OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle_general.cl.h")
|
||||
set(OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h")
|
||||
set(OPENCL_TRANSPOSE_16_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_16.cl.h")
|
||||
set(OPENCL_TRANSPOSE_32_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32.cl.h")
|
||||
set(OPENCL_TRANSPOSE_32_16_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32_16.cl.h")
|
||||
|
||||
set(EMBED_KERNEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py")
|
||||
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
|
||||
|
||||
include_directories("${CMAKE_BINARY_DIR}/autogenerated")
|
||||
|
||||
# Python must be accessible from command line
|
||||
add_custom_command(
|
||||
OUTPUT ${OPENCL_CL_SOURCE_EMBED}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl.cl
|
||||
${OPENCL_CL_SOURCE_EMBED}
|
||||
DEPENDS kernels/ggml-opencl.cl ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ggml-opencl.cl.h"
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${OPENCL_MM_CL_SOURCE_EMBED}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mm.cl
|
||||
${OPENCL_MM_CL_SOURCE_EMBED}
|
||||
DEPENDS kernels/ggml-opencl_mm.cl ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ggml-opencl_mm.cl.h"
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${OPENCL_CVT_CL_SOURCE_EMBED}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_cvt.cl
|
||||
${OPENCL_CVT_CL_SOURCE_EMBED}
|
||||
DEPENDS kernels/ggml-opencl_cvt.cl ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ggml-opencl_cvt.cl.h"
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle.cl
|
||||
${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
|
||||
DEPENDS kernels/ggml-opencl_gemv_noshuffle.cl ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ggml-opencl_gemv_noshuffle.cl.h"
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle_general.cl
|
||||
${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
|
||||
DEPENDS kernels/ggml-opencl_gemv_noshuffle_general.cl ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ggml-opencl_gemv_noshuffle_general.cl.h"
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl
|
||||
${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
|
||||
DEPENDS kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ggml-opencl_mul_mat_Ab_Bi_8x4.cl.cl.h"
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_16.cl
|
||||
${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
|
||||
DEPENDS kernels/ggml-opencl_transpose_16.cl ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ggml-opencl_transpose_16.cl.h"
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32.cl
|
||||
${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
|
||||
DEPENDS kernels/ggml-opencl_transpose_32.cl ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ggml-opencl_transpose_32.cl.h"
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32_16.cl
|
||||
${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}
|
||||
DEPENDS kernels/ggml-opencl_transpose_32_16.cl ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ggml-opencl_transpose_32_16.cl.h"
|
||||
)
|
||||
|
||||
target_sources(${TARGET_NAME} PRIVATE
|
||||
${OPENCL_CL_SOURCE_EMBED}
|
||||
${OPENCL_MM_CL_SOURCE_EMBED}
|
||||
${OPENCL_CVT_CL_SOURCE_EMBED}
|
||||
${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
|
||||
${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
|
||||
${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
|
||||
${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
|
||||
${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
|
||||
${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED})
|
||||
else ()
|
||||
# copy ggml-opencl.cl to bin directory
|
||||
configure_file(kernels/ggml-opencl.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl.cl COPYONLY)
|
||||
configure_file(kernels/ggml-opencl_mm.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mm.cl COPYONLY)
|
||||
configure_file(kernels/ggml-opencl_cvt.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_cvt.cl COPYONLY)
|
||||
|
||||
configure_file(kernels/ggml-opencl_gemv_noshuffle.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle.cl COPYONLY)
|
||||
configure_file(kernels/ggml-opencl_gemv_noshuffle_general.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle_general.cl COPYONLY)
|
||||
configure_file(kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mul_mat_Ab_Bi_8x4.cl COPYONLY)
|
||||
configure_file(kernels/ggml-opencl_transpose_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_16.cl COPYONLY)
|
||||
configure_file(kernels/ggml-opencl_transpose_32.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32.cl COPYONLY)
|
||||
configure_file(kernels/ggml-opencl_transpose_32_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32_16.cl COPYONLY)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
|
||||
endif ()
|
||||
|
||||
function(ggml_opencl_add_kernel KNAME)
|
||||
set(KERN_HDR ${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h)
|
||||
set(KERN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl)
|
||||
|
||||
if (GGML_OPENCL_EMBED_KERNELS)
|
||||
message(STATUS "opencl: embedding kernel ${KNAME}")
|
||||
|
||||
# Python must be accessible from command line
|
||||
add_custom_command(
|
||||
OUTPUT ${KERN_HDR}
|
||||
COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} ${KERN_SRC} ${KERN_HDR}
|
||||
DEPENDS ${KERN_SRC} ${EMBED_KERNEL_SCRIPT}
|
||||
COMMENT "Generate ${KERN_HDR}"
|
||||
)
|
||||
|
||||
target_sources(${TARGET_NAME} PRIVATE ${KERN_HDR})
|
||||
else ()
|
||||
message(STATUS "opencl: adding kernel ${KNAME}")
|
||||
configure_file(${KERN_SRC} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${KNAME}.cl COPYONLY)
|
||||
endif ()
|
||||
endfunction()
|
||||
|
||||
set(GGML_OPENCL_KERNELS
|
||||
ggml-opencl
|
||||
ggml-opencl_mm
|
||||
ggml-opencl_cvt
|
||||
ggml-opencl_gemv_noshuffle
|
||||
ggml-opencl_gemv_noshuffle_general
|
||||
ggml-opencl_mul_mat_Ab_Bi_8x4
|
||||
ggml-opencl_transpose_16
|
||||
ggml-opencl_transpose_32
|
||||
ggml-opencl_transpose_32_16
|
||||
)
|
||||
|
||||
foreach (K ${GGML_OPENCL_KERNELS})
|
||||
ggml_opencl_add_kernel(${K})
|
||||
endforeach()
|
||||
|
||||
@@ -297,8 +297,27 @@ static int ggml_backend_opencl_n_devices = 0;
|
||||
struct ProfilingInfo {
|
||||
std::string op_name;
|
||||
std::string kernel_name;
|
||||
// Kernel execution time in nanoseconds.
|
||||
cl_ulong duration_ns;
|
||||
|
||||
cl_kernel kernel;
|
||||
cl_event evt;
|
||||
|
||||
cl_ulong cmd_queued;
|
||||
cl_ulong cmd_submit;
|
||||
cl_ulong cmd_start;
|
||||
cl_ulong cmd_end;
|
||||
cl_ulong overhead_start;
|
||||
cl_ulong overhead_end;
|
||||
// For the times below, see spec for clGetEventProfilingInfo
|
||||
// The time kernel spent in cmd queue - SUBMIT - QUEUED
|
||||
cl_ulong cmd_queued_duration_ns;
|
||||
// The time kernel spent for submission - START - SUBMIT
|
||||
cl_ulong cmd_submit_duration_ns;
|
||||
// Kernel execution time in nanoseconds - END - START
|
||||
cl_ulong cmd_duration_ns;
|
||||
// The time for the kernel to complete - COMPLETE - END
|
||||
cl_ulong cmd_complete_duration_ns;
|
||||
// Total time to finish the kernel - COMPELTE - QUEUED
|
||||
cl_ulong cmd_total_duration_ns;
|
||||
// Global and local work sizes.
|
||||
size_t global_size[3];
|
||||
size_t local_size[3];
|
||||
@@ -903,12 +922,56 @@ static void ggml_cl2_free(void) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Populate profiling info
|
||||
for (ProfilingInfo & info : g_profiling_info) {
|
||||
cl_ulong cmd_queued;
|
||||
cl_ulong cmd_submit;
|
||||
cl_ulong cmd_start;
|
||||
cl_ulong cmd_end;
|
||||
cl_ulong cmd_complete;
|
||||
|
||||
CL_CHECK(clWaitForEvents(1, &info.evt));
|
||||
CL_CHECK(clGetEventProfilingInfo(
|
||||
info.evt, CL_PROFILING_COMMAND_QUEUED, sizeof(cl_ulong), &cmd_queued, NULL));
|
||||
CL_CHECK(clGetEventProfilingInfo(
|
||||
info.evt, CL_PROFILING_COMMAND_SUBMIT, sizeof(cl_ulong), &cmd_submit, NULL));
|
||||
CL_CHECK(clGetEventProfilingInfo(
|
||||
info.evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &cmd_start, NULL));
|
||||
CL_CHECK(clGetEventProfilingInfo(
|
||||
info.evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &cmd_end, NULL));
|
||||
CL_CHECK(clGetEventProfilingInfo(
|
||||
info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL));
|
||||
CL_CHECK(clReleaseEvent(info.evt));
|
||||
|
||||
char kernel_name[512];
|
||||
CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME,
|
||||
sizeof(kernel_name), kernel_name, NULL));
|
||||
info.kernel_name = kernel_name;
|
||||
|
||||
info.cmd_queued = cmd_queued;
|
||||
info.cmd_submit = cmd_submit;
|
||||
info.cmd_start = cmd_start;
|
||||
info.cmd_end = cmd_end;
|
||||
|
||||
info.cmd_queued_duration_ns = cmd_submit - cmd_queued;
|
||||
info.cmd_submit_duration_ns = cmd_start - cmd_submit;
|
||||
info.cmd_duration_ns = cmd_end - cmd_start;
|
||||
info.cmd_complete_duration_ns = cmd_complete - cmd_end;
|
||||
info.cmd_total_duration_ns = cmd_complete - cmd_queued;
|
||||
}
|
||||
|
||||
// Dump a csv
|
||||
float total_kernel_time = 0;
|
||||
fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n");
|
||||
fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n");
|
||||
for (const ProfilingInfo & info : g_profiling_info) {
|
||||
total_kernel_time += info.duration_ns/1.e6f;
|
||||
fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
|
||||
info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f,
|
||||
total_kernel_time += info.cmd_duration_ns/1.e6f;
|
||||
fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
|
||||
info.op_name.c_str(), info.kernel_name.c_str(),
|
||||
info.cmd_queued_duration_ns/1.e6f,
|
||||
info.cmd_submit_duration_ns/1.e6f,
|
||||
info.cmd_duration_ns/1.e6f,
|
||||
info.cmd_complete_duration_ns/1.e6f,
|
||||
info.cmd_total_duration_ns/1.e6f,
|
||||
info.global_size[0], info.global_size[1], info.global_size[2],
|
||||
info.local_size[0], info.local_size[2], info.local_size[2],
|
||||
info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
|
||||
@@ -916,6 +979,27 @@ static void ggml_cl2_free(void) {
|
||||
fclose(fperf);
|
||||
|
||||
GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
|
||||
|
||||
// Dump a simple chrome trace
|
||||
FILE* ftrace = fopen("cl_trace.json", "w");
|
||||
if (!ftrace) {
|
||||
GGML_LOG_ERROR("Failed to open cl_trace.json\n");
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(ftrace, "[\n");
|
||||
for (const ProfilingInfo & info : g_profiling_info) {
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
|
||||
info.kernel_name.c_str(), info.cmd_queued/1000);
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
|
||||
info.kernel_name.c_str(), info.cmd_submit/1000);
|
||||
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
|
||||
info.kernel_name.c_str(), info.cmd_start/1000);
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
|
||||
info.kernel_name.c_str(), info.cmd_end/1000);
|
||||
}
|
||||
fclose(ftrace);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -2062,25 +2146,14 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
|
||||
// Profiling utility
|
||||
//------------------------------------------------------------------------------
|
||||
#ifdef GGML_OPENCL_PROFILING
|
||||
void populateProfilingInfo(
|
||||
static void populateProfilingInfo(
|
||||
ProfilingInfo& info, cl_event evt, cl_kernel kernel,
|
||||
size_t global_size[3], size_t local_size[3],
|
||||
const ggml_tensor * tensor) {
|
||||
cl_ulong start;
|
||||
cl_ulong end;
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clGetEventProfilingInfo(
|
||||
evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL));
|
||||
CL_CHECK(clGetEventProfilingInfo(
|
||||
evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL));
|
||||
info.op_name = tensor->name;
|
||||
info.kernel = kernel;
|
||||
info.evt = evt;
|
||||
|
||||
char kernel_name[512];
|
||||
CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME,
|
||||
sizeof(kernel_name), kernel_name, NULL));
|
||||
|
||||
info.duration_ns = end - start;
|
||||
info.op_name = tensor->name;
|
||||
info.kernel_name = kernel_name;
|
||||
info.local_size[0] = local_size[0];
|
||||
info.local_size[1] = local_size[1];
|
||||
info.local_size[2] = local_size[2];
|
||||
|
||||
@@ -23,6 +23,38 @@ ggml_add_backend_library(ggml-sycl
|
||||
../../include/ggml-sycl.h
|
||||
)
|
||||
|
||||
find_package(DNNL)
|
||||
set(GGML_SYCL_DNNL 0)
|
||||
if(DNNL_FOUND)
|
||||
if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
|
||||
# Assuming oneDNN packaged with oneapi release is used which
|
||||
# supports only intel target
|
||||
set(DNNL_GPU_VENDOR "INTEL")
|
||||
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
|
||||
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Verify oneDNN was compiled for the same target as llama
|
||||
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
|
||||
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
|
||||
set(GGML_SYCL_DNNL 1)
|
||||
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
|
||||
foreach(CONFIG ${CONFIGS})
|
||||
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
|
||||
message(STATUS "Found oneDNN: ${DNNL_LIB}")
|
||||
endforeach()
|
||||
else()
|
||||
message(WARNING
|
||||
"oneDNN must be compiled for the same target as llama.cpp.
|
||||
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
|
||||
Disabling oneDNN support.")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "oneDNN not found, disabling oneDNN support")
|
||||
endif()
|
||||
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
|
||||
|
||||
if (GGML_SYCL_F16)
|
||||
if (GGML_SYCL_TARGET STREQUAL "AMD")
|
||||
message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
|
||||
@@ -48,24 +80,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
|
||||
file(GLOB GGML_SOURCES_SYCL "*.cpp")
|
||||
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
|
||||
|
||||
find_package(DNNL)
|
||||
message("-- DNNL found:" ${DNNL_FOUND})
|
||||
|
||||
if (GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
|
||||
else()
|
||||
add_compile_definitions(GGML_SYCL_DNNL=0)
|
||||
endif()
|
||||
|
||||
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
|
||||
endif()
|
||||
|
||||
if (WIN32)
|
||||
find_package(IntelSYCL REQUIRED)
|
||||
find_package(MKL REQUIRED)
|
||||
target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
|
||||
else()
|
||||
if (GGML_SYCL_GRAPH)
|
||||
add_compile_definitions(GGML_SYCL_GRAPH)
|
||||
endif()
|
||||
if (GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
|
||||
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "wkv6.hpp"
|
||||
#include "wkv.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "cpy.hpp"
|
||||
|
||||
@@ -170,7 +170,6 @@ static size_t g_scratch_offset = 0;
|
||||
int get_current_device_id();
|
||||
|
||||
inline dpct::err0 ggml_sycl_set_device(const int device) try {
|
||||
|
||||
int current_device_id;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
|
||||
|
||||
@@ -242,6 +241,14 @@ struct ggml_sycl_pool_alloc {
|
||||
}
|
||||
}
|
||||
|
||||
T * realloc(size_t size) {
|
||||
GGML_ASSERT(pool != nullptr);
|
||||
if (ptr)
|
||||
pool->free(ptr, actual_size);
|
||||
ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
// size is in number of elements
|
||||
T * alloc(size_t size) {
|
||||
GGML_ASSERT(pool != nullptr);
|
||||
@@ -301,6 +308,7 @@ inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
|
||||
return opt;
|
||||
}
|
||||
|
||||
namespace sycl_ex = sycl::ext::oneapi::experimental;
|
||||
struct ggml_backend_sycl_context {
|
||||
int device;
|
||||
std::string name;
|
||||
@@ -370,10 +378,29 @@ struct ggml_backend_sycl_context {
|
||||
dnnl::stream stream_dnnl() {
|
||||
return stream_dnnl(device, 0);
|
||||
}
|
||||
dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
|
||||
const dnnl::engine & eng, const queue_ptr q) {
|
||||
ggml_sycl_pool_alloc<uint8_t> * pool;
|
||||
auto it = scratchpad_map.find(q);
|
||||
if (it == scratchpad_map.end()) {
|
||||
scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
|
||||
pool = scratchpad_map[q].get();
|
||||
} else {
|
||||
pool = it->second.get();
|
||||
}
|
||||
|
||||
size_t scratchpad_size = scratchpad_md.get_size();
|
||||
if (scratchpad_size > pool->actual_size) {
|
||||
pool->realloc(scratchpad_size);
|
||||
}
|
||||
void * mem_ptr = pool->get();
|
||||
return dnnl::memory(scratchpad_md, eng, mem_ptr);
|
||||
}
|
||||
#endif
|
||||
|
||||
// pool
|
||||
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
||||
std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;
|
||||
|
||||
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
|
||||
|
||||
@@ -392,6 +419,10 @@ struct ggml_backend_sycl_context {
|
||||
return pool(device);
|
||||
}
|
||||
|
||||
#ifdef GGML_SYCL_GRAPH
|
||||
std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
|
||||
#endif
|
||||
|
||||
ggml_sycl_pool & host_pool(int device) {
|
||||
if (host_pools[device] == nullptr) {
|
||||
host_pools[device] = new_pool_for_host(stream(device, 0), device);
|
||||
|
||||
@@ -138,7 +138,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
|
||||
sycl::range<3>(1, 1, WARP_SIZE),
|
||||
sycl::range<3>(1, 1, WARP_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
|
||||
});
|
||||
|
||||
|
||||
+12
-13
@@ -210,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
|
||||
nrows, item_ct1);
|
||||
});
|
||||
@@ -879,7 +879,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -902,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -923,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -944,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -965,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -986,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -1004,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1020,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1036,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1049,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1065,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1143,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
||||
default:
|
||||
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "common.hpp"
|
||||
#include "element_wise.hpp"
|
||||
|
||||
void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
@@ -20,7 +20,7 @@ void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
}
|
||||
}
|
||||
|
||||
void gelu_f32(const float * x, float * dst, const int k,
|
||||
static void gelu_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
@@ -37,7 +37,7 @@ void gelu_f32(const float * x, float * dst, const int k,
|
||||
sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
|
||||
}
|
||||
|
||||
void silu_f32(const float * x, float * dst, const int k,
|
||||
static void silu_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -48,7 +48,7 @@ void silu_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
|
||||
}
|
||||
|
||||
void gelu_quick_f32(const float *x, float *dst, int k,
|
||||
static void gelu_quick_f32(const float *x, float *dst, int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const float GELU_QUICK_COEF = -1.702f;
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
@@ -59,7 +59,7 @@ void gelu_quick_f32(const float *x, float *dst, int k,
|
||||
dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
|
||||
}
|
||||
|
||||
void tanh_f32(const float *x, float *dst, int k,
|
||||
static void tanh_f32(const float *x, float *dst, int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -69,7 +69,7 @@ void tanh_f32(const float *x, float *dst, int k,
|
||||
dst[i] = sycl::tanh((float)(x[i]));
|
||||
}
|
||||
|
||||
void relu_f32(const float * x, float * dst, const int k,
|
||||
static void relu_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -80,7 +80,7 @@ void relu_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::fmax((float)(x[i]), (float)0);
|
||||
}
|
||||
|
||||
void sigmoid_f32(const float * x, float * dst, const int k,
|
||||
static void sigmoid_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -91,7 +91,7 @@ void sigmoid_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
|
||||
}
|
||||
|
||||
void sqrt_f32(const float * x, float * dst, const int k,
|
||||
static void sqrt_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -102,7 +102,7 @@ void sqrt_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::sqrt(x[i]);
|
||||
}
|
||||
|
||||
void sin_f32(const float * x, float * dst, const int k,
|
||||
static void sin_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -113,7 +113,7 @@ void sin_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::sin(x[i]);
|
||||
}
|
||||
|
||||
void cos_f32(const float * x, float * dst, const int k,
|
||||
static void cos_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -124,7 +124,7 @@ void cos_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::cos(x[i]);
|
||||
}
|
||||
|
||||
void hardsigmoid_f32(const float * x, float * dst, const int k,
|
||||
static void hardsigmoid_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -135,7 +135,7 @@ void hardsigmoid_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
void hardswish_f32(const float * x, float * dst, const int k,
|
||||
static void hardswish_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -146,7 +146,7 @@ void hardswish_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
void exp_f32(const float * x, float * dst, const int k,
|
||||
static void exp_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -157,7 +157,7 @@ void exp_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::exp(x[i]);
|
||||
}
|
||||
|
||||
void log_f32(const float * x, float * dst, const int k,
|
||||
static void log_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -173,7 +173,7 @@ void log_f32(const float * x, float * dst, const int k,
|
||||
}
|
||||
}
|
||||
|
||||
void neg_f32(const float * x, float * dst, const int k,
|
||||
static void neg_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -184,7 +184,7 @@ void neg_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = -x[i];
|
||||
}
|
||||
|
||||
void step_f32(const float * x, float * dst, const int k,
|
||||
static void step_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -195,7 +195,7 @@ void step_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = x[i] > 0.0f;
|
||||
}
|
||||
|
||||
void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
|
||||
static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -206,7 +206,7 @@ void leaky_relu_f32(const float *x, float *dst, const int k, const float negativ
|
||||
sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
|
||||
}
|
||||
|
||||
void sqr_f32(const float * x, float * dst, const int k,
|
||||
static void sqr_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -217,7 +217,7 @@ void sqr_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = x[i] * x[i];
|
||||
}
|
||||
|
||||
void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
||||
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
|
||||
@@ -240,7 +240,7 @@ void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
||||
dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
||||
}
|
||||
|
||||
void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
||||
static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
@@ -262,7 +262,7 @@ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const i
|
||||
|
||||
|
||||
|
||||
void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int n_elements, const int ne10, const int ne11,
|
||||
const int ne12, const int nb1, const int nb2,
|
||||
const int offset, queue_ptr stream) {
|
||||
@@ -277,7 +277,7 @@ void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void gelu_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -289,7 +289,7 @@ void gelu_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void silu_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void silu_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -301,7 +301,7 @@ void silu_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -313,7 +313,7 @@ void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void tanh_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void tanh_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -325,7 +325,7 @@ void tanh_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -337,7 +337,7 @@ void relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -349,7 +349,7 @@ void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -361,7 +361,7 @@ void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void exp_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void exp_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -373,7 +373,7 @@ void exp_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void log_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void log_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -385,7 +385,7 @@ void log_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void neg_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void neg_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -397,7 +397,7 @@ void neg_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void step_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void step_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -409,7 +409,7 @@ void step_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -421,7 +421,7 @@ void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -433,7 +433,7 @@ void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void sin_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void sin_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -445,7 +445,7 @@ void sin_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void cos_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void cos_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -457,7 +457,7 @@ void cos_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
const float negative_slope,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
||||
@@ -470,7 +470,7 @@ void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void sqr_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void sqr_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -482,7 +482,7 @@ void sqr_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
||||
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
const float sf2, const float sf3, queue_ptr stream) {
|
||||
@@ -496,7 +496,7 @@ void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01
|
||||
});
|
||||
}
|
||||
|
||||
void pad_f32_sycl(const float *x, float *dst, const int ne00,
|
||||
static void pad_f32_sycl(const float *x, float *dst, const int ne00,
|
||||
const int ne01, const int ne02, const int ne0,
|
||||
const int ne1, const int ne2, queue_ptr stream) {
|
||||
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
||||
|
||||
+14
-45
@@ -13,9 +13,6 @@
|
||||
#ifndef GGML_SYCL_GEMM_HPP
|
||||
#define GGML_SYCL_GEMM_HPP
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "ggml-sycl.h"
|
||||
|
||||
#if GGML_SYCL_DNNL
|
||||
@@ -35,62 +32,34 @@ public:
|
||||
else static_assert(0);
|
||||
}
|
||||
|
||||
static inline void row_gemm(sycl::queue& q, bool a_trans,
|
||||
bool b_trans, int m, int n, int k,
|
||||
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
||||
{
|
||||
// Get the device associated with the queue
|
||||
sycl::device dev = q.get_device();
|
||||
// Get the context associated with the queue
|
||||
sycl::context ctx = q.get_context();
|
||||
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
|
||||
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
|
||||
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
|
||||
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
|
||||
auto stream = ctx.stream_dnnl(q);
|
||||
auto eng = ctx.engine_dnnl(q);
|
||||
dnnl::memory::dims a_dims = { m, k };
|
||||
dnnl::memory::dims b_dims = { k, n };
|
||||
dnnl::memory::dims c_dims = { m, n };
|
||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||
|
||||
dnnl::primitive_attr primitive_attr;
|
||||
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
|
||||
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
|
||||
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
||||
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
|
||||
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
||||
|
||||
// Create the primitive.
|
||||
auto scratchpad_md = matmul_pd.scratchpad_desc();
|
||||
auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
|
||||
auto matmul_prim = dnnl::matmul(matmul_pd);
|
||||
// Primitive arguments.
|
||||
std::unordered_map<int, dnnl::memory> matmul_args;
|
||||
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
||||
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
||||
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
||||
|
||||
matmul_prim.execute(stream, matmul_args);
|
||||
}
|
||||
|
||||
|
||||
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
|
||||
bool b_trans, int m, int n, int k,
|
||||
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
||||
{
|
||||
auto const eng = stream.get_engine();
|
||||
dnnl::memory::dims a_dims = { m, k };
|
||||
dnnl::memory::dims b_dims = { k, n };
|
||||
dnnl::memory::dims c_dims = { m, n };
|
||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
|
||||
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
|
||||
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
||||
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
||||
|
||||
// Create the primitive.
|
||||
auto matmul_prim = dnnl::matmul(matmul_pd);
|
||||
// Primitive arguments.
|
||||
|
||||
std::unordered_map<int, dnnl::memory> matmul_args;
|
||||
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
||||
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
||||
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
||||
matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
|
||||
|
||||
matmul_prim.execute(stream, matmul_args);
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_te
|
||||
const size_t nrows = ne01;
|
||||
const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
k_get_rows_reorder<qk, qr, dq_reorder>(
|
||||
src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
|
||||
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
|
||||
@@ -302,7 +302,6 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *s
|
||||
// TODO: k-quants
|
||||
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@
|
||||
static bool g_sycl_loaded = false;
|
||||
int g_ggml_sycl_debug = 0;
|
||||
int g_ggml_sycl_disable_optimize = 0;
|
||||
int g_ggml_sycl_disable_graph = 0;
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
ggml_sycl_device_info info = {};
|
||||
@@ -95,7 +96,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
|
||||
return info;
|
||||
}
|
||||
|
||||
void print_device_detail(int id, sycl::device &device, std::string device_type) {
|
||||
static void print_device_detail(int id, sycl::device &device, std::string device_type) {
|
||||
|
||||
dpct::device_info prop;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(
|
||||
@@ -118,7 +119,7 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
|
||||
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
|
||||
}
|
||||
|
||||
void print_device_opt_feature(int device_count) {
|
||||
static void print_device_opt_feature(int device_count) {
|
||||
GGML_LOG_INFO("SYCL Optimization Feature:\n");
|
||||
GGML_LOG_INFO(
|
||||
"|ID| Device Type|Reorder|\n");
|
||||
@@ -190,11 +191,13 @@ static void ggml_check_sycl() try {
|
||||
|
||||
if (!initialized) {
|
||||
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
||||
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
|
||||
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
|
||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
|
||||
GGML_LOG_INFO("Build with Macros:\n");
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
|
||||
@@ -401,7 +404,7 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
|
||||
static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
|
||||
const void *ptr_src, size_t size) {
|
||||
char *host_buf = (char *)malloc(size);
|
||||
q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
|
||||
@@ -620,7 +623,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
|
||||
return &ggml_backend_sycl_buffer_types[device];
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
|
||||
static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
|
||||
|
||||
int device = ctx->device;
|
||||
@@ -1682,7 +1685,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(num_blocks * block_size, block_size),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1703,7 +1706,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
|
||||
nchannels_y, item_ct1);
|
||||
});
|
||||
@@ -1723,7 +1726,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
|
||||
row_stride_x, channel_stride_x,
|
||||
nchannels_y / nchannels_x, item_ct1);
|
||||
@@ -1764,7 +1767,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
||||
const sycl::range<3> block_nums(1, nrows, 1);
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
k_sum_rows_f32(x, dst, ncols, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -2055,9 +2058,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
#else
|
||||
auto dnnl_stream = ctx.stream_dnnl(stream);
|
||||
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
|
||||
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
||||
#endif
|
||||
@@ -2096,9 +2099,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
dst_dd_i, ldc)));
|
||||
# endif
|
||||
#else
|
||||
auto dnnl_stream = ctx.stream_dnnl(stream);
|
||||
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
|
||||
src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
|
||||
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
||||
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
||||
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
||||
#endif
|
||||
}
|
||||
GGML_UNUSED(dst);
|
||||
@@ -2696,6 +2699,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
|
||||
@@ -2914,7 +2923,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
@@ -3287,7 +3296,7 @@ static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
|
||||
|
||||
void ggml_sycl_set_main_device(const int main_device) try {
|
||||
static void ggml_sycl_set_main_device(const int main_device) try {
|
||||
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
|
||||
return;
|
||||
}
|
||||
@@ -3308,7 +3317,7 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
|
||||
static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
|
||||
if (!g_sycl_loaded) return false;
|
||||
|
||||
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
|
||||
@@ -3410,6 +3419,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||
case GGML_OP_RMS_NORM:
|
||||
ggml_sycl_rms_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_sycl_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
|
||||
return false;
|
||||
@@ -3487,6 +3499,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
ggml_sycl_op_rwkv_wkv6(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_sycl_op_rwkv_wkv7(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
@@ -3626,7 +3641,7 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
static void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
|
||||
SYCL_CHECK(
|
||||
@@ -3640,7 +3655,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
|
||||
stream->parallel_for(
|
||||
size / sizeof(block_q4_0),
|
||||
[=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const block_q4_0* x = (const block_q4_0*)tmp_buf;
|
||||
const int ib = i;
|
||||
|
||||
@@ -3654,7 +3669,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
char*data_device = (char*)src0->data;
|
||||
size_t ncols = src0->ne[0];
|
||||
size_t nrows = src0->ne[1];
|
||||
@@ -3663,7 +3678,7 @@ void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
reorder_qw(data_device, ncols, nrows, size, 0, stream);
|
||||
}
|
||||
|
||||
void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
|
||||
static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
|
||||
ggml_tensor *src0 = dst->src[0];
|
||||
ggml_tensor *src1 = dst->src[1];
|
||||
|
||||
@@ -3676,7 +3691,7 @@ void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
|
||||
}
|
||||
}
|
||||
|
||||
void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
|
||||
static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
|
||||
dpct::queue_ptr stream = ctx->stream();
|
||||
if (ctx->optimized_graph) {
|
||||
return;
|
||||
@@ -3687,10 +3702,9 @@ void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx)
|
||||
if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
|
||||
}
|
||||
}
|
||||
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
|
||||
ggml_sycl_set_main_device(sycl_ctx->device);
|
||||
|
||||
static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
|
||||
ggml_sycl_set_main_device(sycl_ctx->device);
|
||||
if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
@@ -3712,7 +3726,46 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
|
||||
}
|
||||
GGML_ASSERT(ok);
|
||||
}
|
||||
}
|
||||
|
||||
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
|
||||
|
||||
#ifdef GGML_SYCL_GRAPH
|
||||
if (!g_ggml_sycl_disable_graph) {
|
||||
if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) {
|
||||
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
|
||||
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
model_sycl_graph.end_recording();
|
||||
|
||||
if (!sycl_ctx->exec_graph) {
|
||||
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
|
||||
sycl_ctx->exec_graph = std::make_unique<
|
||||
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
|
||||
} else {
|
||||
try {
|
||||
sycl_ctx->exec_graph->update(model_sycl_graph);
|
||||
GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
|
||||
} catch (sycl::exception const & e) {
|
||||
GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
|
||||
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
|
||||
sycl_ctx->exec_graph = std::make_unique<
|
||||
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
|
||||
}
|
||||
}
|
||||
|
||||
sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
}
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -3866,7 +3919,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_NEG:
|
||||
@@ -3884,7 +3937,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
@@ -3915,7 +3967,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_OUT_PROD:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
|
||||
case GGML_OP_GET_ROWS:
|
||||
@@ -3932,7 +3984,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
@@ -3983,12 +4035,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_NONE:
|
||||
@@ -4012,6 +4064,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return (op->src[0]->type == GGML_TYPE_F32);
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SCALE:
|
||||
@@ -4045,6 +4098,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
return true;
|
||||
default:
|
||||
|
||||
@@ -3017,7 +3017,6 @@ void ggml_sycl_op_mul_mat_q(
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
|
||||
+19
-20
@@ -495,7 +495,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
|
||||
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -519,7 +519,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
|
||||
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -543,7 +543,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
|
||||
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -567,7 +567,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
|
||||
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -591,7 +591,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
|
||||
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -615,7 +615,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
|
||||
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -639,7 +639,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
|
||||
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -663,7 +663,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
|
||||
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -687,7 +687,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
|
||||
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -711,7 +711,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
|
||||
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -734,7 +734,7 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -755,7 +755,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -777,7 +777,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -799,7 +799,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -821,7 +821,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -843,7 +843,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -864,7 +864,7 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -886,7 +886,7 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -908,7 +908,7 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -1003,7 +1003,6 @@ void ggml_sycl_op_mul_mat_vec_q(
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
GGML_UNUSED(src1);
|
||||
|
||||
+114
-6
@@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
|
||||
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row * ncols + col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
if (block_size > WARP_SIZE) {
|
||||
|
||||
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
/*
|
||||
DPCT1118:3: SYCL group functions and algorithms must be encountered in
|
||||
converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
size_t nreduce = nwarps / WARP_SIZE;
|
||||
tmp = 0.f;
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row * ncols + col] = scale * x[row * ncols + col];
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
@@ -191,7 +235,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
@@ -214,7 +258,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
@@ -233,7 +277,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(
|
||||
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
@@ -260,7 +304,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(x, dst, group_size, ne_elements,
|
||||
eps_ct4, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
@@ -281,7 +325,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
@@ -303,7 +347,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
@@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||
const sycl::range<3> block_dims(1, 1, work_group_size);
|
||||
/*
|
||||
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
||||
cgh);
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
||||
ggml_tensor* dst, const float* src0_dd,
|
||||
const float* src1_dd, float* dst_dd,
|
||||
@@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
|
||||
(void)dst;
|
||||
(void)src1_dd;
|
||||
}
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
(void)src1_dd;
|
||||
}
|
||||
|
||||
@@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
|
||||
#endif // GGML_SYCL_NORM_HPP
|
||||
|
||||
@@ -132,7 +132,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
||||
nrows_y, scale, max_bias, m0,
|
||||
m1, n_head_log2, item_ct1,
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "wkv.hpp"
|
||||
|
||||
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
||||
|
||||
// Helper function for the main kernel
|
||||
template <int block_size>
|
||||
static void rwkv_wkv6_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* k, const float* v, const float* r,
|
||||
const float* tf, const float* td, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
// Set up shared memory pointers
|
||||
float* _k = shared_mem;
|
||||
float* _r = _k + head_size;
|
||||
float* _tf = _r + head_size;
|
||||
float* _td = _tf + head_size;
|
||||
|
||||
// Local state array
|
||||
float state[block_size];
|
||||
|
||||
// Load initial state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
// Sync threads before shared memory operations
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load time-mixing parameters
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Main sequence processing loop
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load current timestep data to shared memory
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
|
||||
// Process in chunks of 4 for better vectorization
|
||||
sycl::float4 k4, r4, tf4, td4, s4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
// Load data in vec4 chunks
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
// Compute key-value product
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
// Accumulate weighted sum
|
||||
y += sycl::dot(r4, tf4 * kv4 + s4);
|
||||
|
||||
// Update state
|
||||
s4 = s4 * td4 + kv4;
|
||||
|
||||
// Store updated state
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
// Save final state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static void rwkv_wkv7_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* r, const float* w, const float* k, const float* v,
|
||||
const float* a, const float* b, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float* _r = shared_mem;
|
||||
float* _w = _r + head_size;
|
||||
float* _k = _w + head_size;
|
||||
float* _a = _k + head_size;
|
||||
float* _b = _a + head_size;
|
||||
|
||||
float state[block_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
|
||||
}
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0, sa = 0;
|
||||
sycl::float4 a4, s4;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
sa += sycl::dot(a4, s4);
|
||||
}
|
||||
|
||||
sycl::float4 r4, w4, k4, b4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
s4 = s4 * w4 + kv4 + sa * b4;
|
||||
y += sycl::dot(r4, s4);
|
||||
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
|
||||
const float* k_d = (const float*)dst->src[0]->data;
|
||||
const float* v_d = (const float*)dst->src[1]->data;
|
||||
const float* r_d = (const float*)dst->src[2]->data;
|
||||
const float* tf_d = (const float*)dst->src[3]->data;
|
||||
const float* td_d = (const float*)dst->src[4]->data;
|
||||
const float* s_d = (const float*)dst->src[5]->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// Calculate execution configuration
|
||||
const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
|
||||
sycl::range<3> block_dims(1, 1, C / H);
|
||||
sycl::range<3> grid_dims(1, 1, B * H);
|
||||
|
||||
// Submit kernel
|
||||
if (C / H == WKV_BLOCK_SIZE) {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
|
||||
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
||||
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
|
||||
const float* r_d = (const float*)dst->src[0]->data;
|
||||
const float* w_d = (const float*)dst->src[1]->data;
|
||||
const float* k_d = (const float*)dst->src[2]->data;
|
||||
const float* v_d = (const float*)dst->src[3]->data;
|
||||
const float* a_d = (const float*)dst->src[4]->data;
|
||||
const float* b_d = (const float*)dst->src[5]->data;
|
||||
const float* s_d = (const float*)dst->src[6]->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
const int64_t B = dst->src[6]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// Calculate execution configuration
|
||||
const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
|
||||
sycl::range<3> block_dims(1, 1, C / H);
|
||||
sycl::range<3> grid_dims(1, 1, B * H);
|
||||
|
||||
// Submit kernel
|
||||
if (C / H == WKV_BLOCK_SIZE) {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
|
||||
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
||||
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
#ifndef GGML_SYCL_WKV_HPP
|
||||
#define GGML_SYCL_WKV_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_WKV_HPP
|
||||
@@ -1,143 +0,0 @@
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "wkv6.hpp"
|
||||
|
||||
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
||||
|
||||
// Helper function for the main kernel
|
||||
static void rwkv_wkv_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* k, const float* v, const float* r,
|
||||
const float* tf, const float* td, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = WKV_BLOCK_SIZE;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
// Set up shared memory pointers
|
||||
float* _k = shared_mem;
|
||||
float* _r = _k + head_size;
|
||||
float* _tf = _r + head_size;
|
||||
float* _td = _tf + head_size;
|
||||
|
||||
// Local state array
|
||||
float state[WKV_BLOCK_SIZE];
|
||||
|
||||
// Load initial state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
// Sync threads before shared memory operations
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load time-mixing parameters
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Main sequence processing loop
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load current timestep data to shared memory
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
|
||||
// Process in chunks of 4 for better vectorization
|
||||
sycl::float4 k4, r4, tf4, td4, s4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
// Load data in vec4 chunks
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
// Compute key-value product
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
// Accumulate weighted sum
|
||||
y += sycl::dot(r4, tf4 * kv4 + s4);
|
||||
|
||||
// Update state
|
||||
s4 = s4 * td4 + kv4;
|
||||
|
||||
// Store updated state
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
// Save final state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
|
||||
const float* k_d = (const float*)dst->src[0]->data;
|
||||
const float* v_d = (const float*)dst->src[1]->data;
|
||||
const float* r_d = (const float*)dst->src[2]->data;
|
||||
const float* tf_d = (const float*)dst->src[3]->data;
|
||||
const float* td_d = (const float*)dst->src[4]->data;
|
||||
const float* s_d = (const float*)dst->src[5]->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// Calculate execution configuration
|
||||
const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
|
||||
sycl::range<3> block_dims(1, 1, C / H);
|
||||
sycl::range<3> grid_dims(1, 1, B * H);
|
||||
|
||||
// Submit kernel
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv_f32_kernel(
|
||||
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
#ifndef GGML_SYCL_WKV6_HPP
|
||||
#define GGML_SYCL_WKV6_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
|
||||
#endif // GGML_SYCL_WKV6_HPP
|
||||
@@ -149,6 +149,67 @@ class vk_perf_logger;
|
||||
static void ggml_vk_destroy_buffer(vk_buffer& buf);
|
||||
|
||||
static constexpr uint32_t mul_mat_vec_max_cols = 8;
|
||||
static constexpr uint32_t p021_max_gqa_ratio = 8;
|
||||
|
||||
enum vk_device_architecture {
|
||||
OTHER,
|
||||
AMD_GCN,
|
||||
AMD_RDNA1,
|
||||
AMD_RDNA2,
|
||||
AMD_RDNA3,
|
||||
};
|
||||
|
||||
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
|
||||
vk::PhysicalDeviceProperties props = device.getProperties();
|
||||
|
||||
if (props.vendorID == VK_VENDOR_ID_AMD) {
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
||||
|
||||
bool amd_shader_core_properties = false;
|
||||
bool integer_dot_product = false;
|
||||
bool subgroup_size_control = false;
|
||||
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
|
||||
amd_shader_core_properties = true;
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
|
||||
integer_dot_product = true;
|
||||
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
|
||||
subgroup_size_control = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
|
||||
return vk_device_architecture::OTHER;
|
||||
}
|
||||
|
||||
vk::PhysicalDeviceProperties2 props2;
|
||||
vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
|
||||
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
|
||||
props2.pNext = &shader_core_props_amd;
|
||||
shader_core_props_amd.pNext = &integer_dot_props;
|
||||
integer_dot_props.pNext = &subgroup_size_control_props;
|
||||
|
||||
device.getProperties2(&props2);
|
||||
|
||||
if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
|
||||
return vk_device_architecture::AMD_GCN;
|
||||
}
|
||||
if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
|
||||
// RDNA
|
||||
if (shader_core_props_amd.wavefrontsPerSimd == 20) {
|
||||
return vk_device_architecture::AMD_RDNA1;
|
||||
}
|
||||
if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
|
||||
return vk_device_architecture::AMD_RDNA3;
|
||||
}
|
||||
return vk_device_architecture::AMD_RDNA2;
|
||||
}
|
||||
}
|
||||
return vk_device_architecture::OTHER;
|
||||
}
|
||||
|
||||
struct vk_device_struct {
|
||||
std::mutex mutex;
|
||||
@@ -162,6 +223,7 @@ struct vk_device_struct {
|
||||
bool pipeline_robustness;
|
||||
vk::Device device;
|
||||
uint32_t vendor_id;
|
||||
vk_device_architecture architecture;
|
||||
vk_queue compute_queue;
|
||||
vk_queue transfer_queue;
|
||||
bool single_queue;
|
||||
@@ -170,6 +232,7 @@ struct vk_device_struct {
|
||||
bool uma;
|
||||
bool prefer_host_memory;
|
||||
bool float_controls_rte_fp16;
|
||||
bool subgroup_add;
|
||||
|
||||
bool subgroup_size_control;
|
||||
uint32_t subgroup_min_size;
|
||||
@@ -216,7 +279,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
|
||||
|
||||
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
|
||||
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
|
||||
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
@@ -243,6 +306,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_group_norm_f32;
|
||||
vk_pipeline pipeline_rms_norm_f32;
|
||||
vk_pipeline pipeline_rms_norm_back_f32;
|
||||
vk_pipeline pipeline_l2_norm_f32;
|
||||
vk_pipeline pipeline_gelu_f32;
|
||||
vk_pipeline pipeline_gelu_quick_f32;
|
||||
vk_pipeline pipeline_silu_f32;
|
||||
@@ -267,6 +331,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_timestep_embedding_f32;
|
||||
vk_pipeline pipeline_pool2d_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv7_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
|
||||
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
||||
@@ -568,6 +633,13 @@ struct vk_op_rwkv_wkv6_push_constants {
|
||||
uint32_t H;
|
||||
};
|
||||
|
||||
struct vk_op_rwkv_wkv7_push_constants {
|
||||
uint32_t B;
|
||||
uint32_t T;
|
||||
uint32_t C;
|
||||
uint32_t H;
|
||||
};
|
||||
|
||||
// Allow pre-recording command buffers
|
||||
struct vk_staging_memcpy {
|
||||
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
||||
@@ -1448,6 +1520,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||
return supported;
|
||||
}
|
||||
|
||||
struct GpuPipelineConfig {
|
||||
// GPU architecture identifier.
|
||||
// Example: vk_device_architecture::AMD_GCN
|
||||
vk_device_architecture arch;
|
||||
|
||||
// Mapping of pipeline names to their specific subgroup sizes.
|
||||
// Example: {"soft_max_f32", 64}
|
||||
std::unordered_map<std::string, uint32_t> pipelines;
|
||||
|
||||
// Default subgroup size for this GPU.
|
||||
// Defaults to 0 if not explicitly provided.
|
||||
uint32_t default_subgroup_size = 0;
|
||||
};
|
||||
|
||||
// Pipeline configuration for RDNA1 GPUs.
|
||||
static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
|
||||
{"soft_max", 64}, {"im2col", 64},
|
||||
{"argmax", 64}, {"mul_mat_vec", 64},
|
||||
{"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
|
||||
};
|
||||
|
||||
// Pipeline configuration for RDNA2 GPUs.
|
||||
static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
|
||||
{"soft_max", 64}, {"im2col", 64},
|
||||
};
|
||||
|
||||
static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
|
||||
|
||||
// Define configurations for different GPUs.
|
||||
static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
|
||||
{
|
||||
vk_device_architecture::AMD_RDNA1,
|
||||
{
|
||||
rdna1_pipelines,
|
||||
},
|
||||
RDNA_DEFAULT_SUBGROUP_SIZE
|
||||
},
|
||||
{
|
||||
vk_device_architecture::AMD_RDNA2,
|
||||
{
|
||||
rdna2_pipelines,
|
||||
},
|
||||
RDNA_DEFAULT_SUBGROUP_SIZE
|
||||
},
|
||||
};
|
||||
|
||||
static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
|
||||
for (const auto &config : gpu_pipeline_configs) {
|
||||
if (config.arch == arch) {
|
||||
auto pipIt = config.pipelines.find(pipeline_name);
|
||||
if (pipIt != config.pipelines.end()) {
|
||||
return pipIt->second;
|
||||
}
|
||||
std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
|
||||
std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
|
||||
[](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
|
||||
for (const auto &entry : sorted_pipelines) {
|
||||
if (pipeline_name.find(entry.first) != std::string::npos) {
|
||||
return entry.second;
|
||||
}
|
||||
}
|
||||
return config.default_subgroup_size;
|
||||
}
|
||||
}
|
||||
return 0; // If no matching configuration is found
|
||||
}
|
||||
|
||||
static void ggml_vk_load_shaders(vk_device& device) {
|
||||
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
||||
|
||||
@@ -1469,33 +1608,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
uint32_t l_align, m_align, s_align;
|
||||
if (device->coopmat2) {
|
||||
// spec constants and tile sizes for non-quant matmul/matmul_id
|
||||
l_warptile = { 256, 128, 256, 64 };
|
||||
m_warptile = { 256, 128, 128, 64 };
|
||||
s_warptile = { 128, 64, 64, 64 };
|
||||
l_warptile = { 256, 128, 256, 64, 1 };
|
||||
m_warptile = { 256, 128, 128, 64, 0 };
|
||||
s_warptile = { 128, 64, 64, 64, 0 };
|
||||
l_wg_denoms = {128, 256, 1 };
|
||||
m_wg_denoms = {128, 128, 1 };
|
||||
s_wg_denoms = { 64, 64, 1 };
|
||||
|
||||
// spec constants and tile sizes for quant matmul (non-Qi_K)
|
||||
l_warptile_mmq = { 256, 128, 256, 64 };
|
||||
m_warptile_mmq = { 256, 128, 128, 64 };
|
||||
s_warptile_mmq = { 256, 32, 64, 128 };
|
||||
l_warptile_mmq = { 256, 128, 256, 64, 1 };
|
||||
m_warptile_mmq = { 256, 128, 128, 64, 1 };
|
||||
s_warptile_mmq = { 256, 32, 64, 128, 0 };
|
||||
l_mmq_wg_denoms = { 128, 256, 1 };
|
||||
m_mmq_wg_denoms = { 128, 128, 1 };
|
||||
s_mmq_wg_denoms = { 32, 64, 1 };
|
||||
|
||||
// spec constants and tile sizes for quant matmul (Qi_K)
|
||||
l_warptile_mmq_k = { 256, 64, 128, 64 };
|
||||
m_warptile_mmq_k = { 256, 32, 64, 64 };
|
||||
s_warptile_mmq_k = { 256, 32, 32, 128 };
|
||||
l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
|
||||
m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
|
||||
s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
|
||||
l_mmq_wg_denoms_k = { 64, 128, 1 };
|
||||
m_mmq_wg_denoms_k = { 32, 64, 1 };
|
||||
s_mmq_wg_denoms_k = { 32, 32, 1 };
|
||||
|
||||
// spec constants and tile sizes for quant matmul_id
|
||||
l_warptile_mmqid = { 256, 128, 64, 16 };
|
||||
m_warptile_mmqid = { 256, 128, 64, 16 };
|
||||
s_warptile_mmqid = { 256, 128, 64, 16 };
|
||||
l_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
||||
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
||||
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
||||
l_mmqid_wg_denoms = { 128, 64, 1 };
|
||||
m_mmqid_wg_denoms = { 128, 64, 1 };
|
||||
s_mmqid_wg_denoms = { 128, 64, 1 };
|
||||
@@ -1574,6 +1713,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
|
||||
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
|
||||
|
||||
if (!require_full_subgroups && required_subgroup_size == 0) {
|
||||
required_subgroup_size = get_subgroup_size(name, device->architecture);
|
||||
}
|
||||
|
||||
if (!pipeline) {
|
||||
pipeline = std::make_shared<vk_pipeline_struct>();
|
||||
pipeline->name = name;
|
||||
@@ -2124,13 +2267,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
||||
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
||||
if (device->subgroup_add && device->subgroup_require_full_support) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
|
||||
}
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -2139,13 +2289,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
||||
@@ -2242,6 +2400,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
for (auto &c : compiles) {
|
||||
@@ -2250,7 +2410,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
device->need_compiles = false;
|
||||
}
|
||||
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
|
||||
|
||||
static vk_device ggml_vk_get_device(size_t idx) {
|
||||
VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
|
||||
@@ -2279,6 +2439,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->physical_device = physical_devices[dev_num];
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
||||
|
||||
device->architecture = get_device_architecture(device->physical_device);
|
||||
|
||||
const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
|
||||
device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
|
||||
|
||||
@@ -2291,7 +2453,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
bool coopmat2_support = false;
|
||||
device->coopmat_support = false;
|
||||
|
||||
// Check if maintenance4 is supported
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
||||
maintenance4_support = true;
|
||||
@@ -2326,13 +2487,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
vk::PhysicalDeviceDriverProperties driver_props;
|
||||
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
|
||||
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
|
||||
vk::PhysicalDeviceVulkan11Properties vk11_props;
|
||||
vk::PhysicalDeviceVulkan12Properties vk12_props;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
|
||||
props2.pNext = &props3;
|
||||
props3.pNext = &subgroup_props;
|
||||
subgroup_props.pNext = &driver_props;
|
||||
driver_props.pNext = &vk12_props;
|
||||
driver_props.pNext = &vk11_props;
|
||||
vk11_props.pNext = &vk12_props;
|
||||
|
||||
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
|
||||
|
||||
@@ -2379,13 +2542,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
|
||||
device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
|
||||
#if defined(_WIN32)
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
|
||||
} else {
|
||||
// Limit batching of allocations to 1GB by default to avoid fragmentation issues
|
||||
device->suballocation_block_size = 1024*1024*1024;
|
||||
#endif
|
||||
} else {
|
||||
device->suballocation_block_size = device->max_memory_allocation_size;
|
||||
}
|
||||
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
|
||||
|
||||
@@ -2400,11 +2559,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
}
|
||||
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
|
||||
|
||||
device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
||||
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
|
||||
|
||||
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
|
||||
|
||||
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
||||
|
||||
if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
|
||||
if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
|
||||
device->coopmat_support = false;
|
||||
}
|
||||
|
||||
@@ -2782,7 +2944,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
subgroup_props.pNext = &driver_props;
|
||||
physical_device.getProperties2(&props2);
|
||||
|
||||
const size_t subgroup_size = subgroup_props.subgroupSize;
|
||||
vk_device_architecture arch = get_device_architecture(physical_device);
|
||||
uint32_t default_subgroup_size = get_subgroup_size("", arch);
|
||||
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
|
||||
|
||||
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
||||
|
||||
bool fp16_storage = false;
|
||||
@@ -2808,7 +2973,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
|
||||
const vk_device_architecture device_architecture = get_device_architecture(physical_device);
|
||||
|
||||
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
|
||||
coopmat_support = false;
|
||||
}
|
||||
|
||||
@@ -4481,9 +4648,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
|
||||
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
||||
const uint64_t d_sz = sizeof(float) * d_ne;
|
||||
|
||||
// With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
|
||||
if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
|
||||
gqa_ratio = 1;
|
||||
}
|
||||
|
||||
if (dryrun) {
|
||||
// Request descriptor sets
|
||||
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
|
||||
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4507,8 +4680,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
|
||||
|
||||
// compute
|
||||
const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
|
||||
|
||||
uint32_t workgroups_z = (uint32_t)ne12;
|
||||
// When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
|
||||
if (gqa_ratio > 1) {
|
||||
workgroups_z /= gqa_ratio;
|
||||
}
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
@@ -5335,6 +5515,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_rms_norm_back_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_L2_NORM:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_l2_norm_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(dst)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
@@ -5474,6 +5659,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_rwkv_wkv6_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_rwkv_wkv7_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_opt_step_adamw_f32;
|
||||
@@ -5721,6 +5911,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
@@ -5970,23 +6161,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
|
||||
const ggml_tensor * k = dst->src[0];
|
||||
const ggml_tensor * v = dst->src[1];
|
||||
const ggml_tensor * r = dst->src[2];
|
||||
const ggml_tensor * tf = dst->src[3];
|
||||
const ggml_tensor * td = dst->src[4];
|
||||
const ggml_tensor * state = dst->src[5];
|
||||
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
|
||||
GGML_ASSERT(version == 6 || version == 7);
|
||||
int num_srcs = version == 6 ? 6 : 7;
|
||||
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
|
||||
}
|
||||
|
||||
GGML_ASSERT(!ggml_is_quantized(k->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(v->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(r->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(tf->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(td->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(state->type));
|
||||
GGML_ASSERT(dst->buffer != nullptr);
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
|
||||
GGML_ASSERT(pipeline != nullptr);
|
||||
|
||||
if (dryrun) {
|
||||
@@ -5995,89 +6180,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
|
||||
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
|
||||
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
|
||||
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
|
||||
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
|
||||
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
|
||||
}
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
|
||||
vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
|
||||
size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
|
||||
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
|
||||
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
|
||||
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
|
||||
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
|
||||
|
||||
if (ctx->device->uma) {
|
||||
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
|
||||
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
|
||||
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
|
||||
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
|
||||
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
|
||||
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
|
||||
srcs_uma[i] = d_srcs[i] != nullptr;
|
||||
}
|
||||
|
||||
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
|
||||
|
||||
K_uma = d_K != nullptr;
|
||||
V_uma = d_V != nullptr;
|
||||
R_uma = d_R != nullptr;
|
||||
TF_uma = d_TF != nullptr;
|
||||
TD_uma = d_TD != nullptr;
|
||||
STATE_uma = d_State != nullptr;
|
||||
DST_uma = d_D != nullptr;
|
||||
dst_uma = d_D != nullptr;
|
||||
}
|
||||
|
||||
if (!K_uma) {
|
||||
d_K = k_buf_ctx->dev_buffer;
|
||||
k_offset = vk_tensor_offset(k) + k->view_offs;
|
||||
uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
src_sizes[i] = ggml_nbytes(dst->src[i]);
|
||||
if (!srcs_uma[i]) {
|
||||
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
|
||||
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
|
||||
}
|
||||
}
|
||||
if (!V_uma) {
|
||||
d_V = v_buf_ctx->dev_buffer;
|
||||
v_offset = vk_tensor_offset(v) + v->view_offs;
|
||||
}
|
||||
if (!R_uma) {
|
||||
d_R = r_buf_ctx->dev_buffer;
|
||||
r_offset = vk_tensor_offset(r) + r->view_offs;
|
||||
}
|
||||
if (!TF_uma) {
|
||||
d_TF = tf_buf_ctx->dev_buffer;
|
||||
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
|
||||
}
|
||||
if (!TD_uma) {
|
||||
d_TD = td_buf_ctx->dev_buffer;
|
||||
td_offset = vk_tensor_offset(td) + td->view_offs;
|
||||
}
|
||||
if (!STATE_uma) {
|
||||
d_State = state_buf_ctx->dev_buffer;
|
||||
state_offset = vk_tensor_offset(state) + state->view_offs;
|
||||
}
|
||||
if (!DST_uma) {
|
||||
|
||||
const uint64_t dst_size = ggml_nbytes(dst);
|
||||
if (!dst_uma) {
|
||||
d_D = dst_buf_ctx->dev_buffer;
|
||||
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
}
|
||||
|
||||
const uint64_t k_size = ggml_nbytes(k);
|
||||
const uint64_t v_size = ggml_nbytes(v);
|
||||
const uint64_t r_size = ggml_nbytes(r);
|
||||
const uint64_t tf_size = ggml_nbytes(tf);
|
||||
const uint64_t td_size = ggml_nbytes(td);
|
||||
const uint64_t state_size = ggml_nbytes(state);
|
||||
const uint64_t dst_size = ggml_nbytes(dst);
|
||||
|
||||
std::array<uint32_t, 3> elements = {
|
||||
(uint32_t)(pc.B * pc.H),
|
||||
1,
|
||||
1
|
||||
};
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_K, k_offset, k_size },
|
||||
vk_subbuffer{ d_V, v_offset, v_size },
|
||||
vk_subbuffer{ d_R, r_offset, r_size },
|
||||
vk_subbuffer{ d_TF, tf_offset, tf_size },
|
||||
vk_subbuffer{ d_TD, td_offset, td_size },
|
||||
vk_subbuffer{ d_State, state_offset, state_size },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
||||
if (version == 6) {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
||||
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
|
||||
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
|
||||
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
|
||||
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
||||
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
||||
} else if (version == 7) {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
||||
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
|
||||
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
|
||||
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
|
||||
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
||||
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
||||
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
|
||||
} else {
|
||||
// shouldn't happen
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||
@@ -6086,7 +6255,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
const size_t n_heads = dst->src[0]->ne[1];
|
||||
const size_t n_seqs = dst->src[5]->ne[1];
|
||||
|
||||
ggml_vk_op_f32_rwkv6(
|
||||
ggml_vk_op_f32_wkv(
|
||||
ctx, subctx, dst,
|
||||
{
|
||||
(uint32_t)n_seqs,
|
||||
@@ -6094,6 +6263,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
(uint32_t)n_embed,
|
||||
(uint32_t)n_heads,
|
||||
},
|
||||
6,
|
||||
dryrun
|
||||
);
|
||||
}
|
||||
|
||||
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||
const size_t seq_length = dst->src[0]->ne[2];
|
||||
const size_t n_embed = dst->ne[0];
|
||||
const size_t n_heads = dst->src[0]->ne[1];
|
||||
const size_t n_seqs = dst->src[6]->ne[1];
|
||||
|
||||
ggml_vk_op_f32_wkv(
|
||||
ctx, subctx, dst,
|
||||
{
|
||||
(uint32_t)n_seqs,
|
||||
(uint32_t)seq_length,
|
||||
(uint32_t)n_embed,
|
||||
(uint32_t)n_heads,
|
||||
},
|
||||
7,
|
||||
dryrun
|
||||
);
|
||||
}
|
||||
@@ -6395,6 +6584,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
||||
}
|
||||
@@ -7390,6 +7584,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
@@ -7406,6 +7601,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
@@ -7452,6 +7648,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
@@ -7569,6 +7766,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
@@ -7659,6 +7860,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
@@ -7732,6 +7938,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
@@ -7751,6 +7958,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
@@ -8262,8 +8470,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
|
||||
uint64_t total_mat_mul_bytes = 0;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
|
||||
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
|
||||
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
||||
}
|
||||
}
|
||||
if (ctx->device->need_compiles) {
|
||||
ggml_vk_load_shaders(ctx->device);
|
||||
@@ -8284,17 +8496,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
bool first_node_in_batch = true; // true if next node will be first node in a batch
|
||||
int submit_node_idx = 0; // index to first node in a batch
|
||||
|
||||
// Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
|
||||
// Start with a smaller count to get work submitted right away, and increase it after each submit.
|
||||
int nodes_per_submit = 20;
|
||||
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
|
||||
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
|
||||
// (and scaled down based on model size, so smaller models submit earlier).
|
||||
// Also submit at least every 100 nodes, in case there are workloads without as much matmul.
|
||||
int nodes_per_submit = 100;
|
||||
int submitted_nodes = 0;
|
||||
int submit_count = 0;
|
||||
uint64_t mul_mat_bytes = 0;
|
||||
uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (first_node_in_batch) {
|
||||
submit_node_idx = i;
|
||||
}
|
||||
|
||||
bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
|
||||
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
|
||||
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
||||
}
|
||||
|
||||
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
||||
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
||||
(i == last_node);
|
||||
|
||||
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
|
||||
|
||||
@@ -8311,13 +8533,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
if (submit) {
|
||||
first_node_in_batch = true;
|
||||
submitted_nodes = 0;
|
||||
switch (submit_count) {
|
||||
case 0:
|
||||
nodes_per_submit = 50;
|
||||
break;
|
||||
default:
|
||||
nodes_per_submit = 100;
|
||||
break;
|
||||
mul_mat_bytes = 0;
|
||||
if (submit_count < 3) {
|
||||
mul_mat_bytes_per_submit *= 2;
|
||||
}
|
||||
submit_count++;
|
||||
}
|
||||
@@ -8668,6 +8886,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
@@ -8697,6 +8916,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
return true;
|
||||
@@ -8843,7 +9063,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
|
||||
UNUSED(instance_extensions);
|
||||
}
|
||||
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
||||
switch (props.vendorID) {
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
// Intel drivers don't support coopmat properly yet
|
||||
@@ -8851,10 +9071,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
|
||||
case VK_VENDOR_ID_AMD:
|
||||
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
|
||||
// Workaround for AMD proprietary driver reporting support on all GPUs
|
||||
const std::string name = props.deviceName;
|
||||
return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
|
||||
name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
|
||||
name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
|
||||
return arch == vk_device_architecture::AMD_RDNA3;
|
||||
}
|
||||
return true;
|
||||
default:
|
||||
@@ -9084,6 +9301,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
|
||||
} else if (tensor->op == GGML_OP_SILU_BACK) {
|
||||
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_L2_NORM) {
|
||||
const float eps = ((float *) tensor->op_params)[0];
|
||||
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
|
||||
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
||||
if (src1 != nullptr) {
|
||||
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
||||
@@ -9203,6 +9423,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
||||
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
|
||||
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
|
||||
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
|
||||
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
|
||||
src_clone[4], src_clone[5], src_clone[6]);
|
||||
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
||||
src_clone[0]->flags = src0->flags;
|
||||
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
find_package (Threads REQUIRED)
|
||||
find_program(GLSLC_EXECUTABLE glslc)
|
||||
if(NOT GLSLC_EXECUTABLE)
|
||||
message(FATAL_ERROR "glslc not found.")
|
||||
endif()
|
||||
|
||||
set(TARGET vulkan-shaders-gen)
|
||||
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
#version 450
|
||||
|
||||
#if RTE16
|
||||
#extension GL_EXT_spirv_intrinsics : enable
|
||||
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
||||
#endif // RTE16
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
|
||||
@@ -82,8 +82,8 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||
const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]);
|
||||
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;
|
||||
return vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -178,7 +178,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
|
||||
|
||||
uvec4 v = bl128.block.q4k[0];
|
||||
|
||||
const f16vec2 loadd = unpackFloat2x16(v.x);
|
||||
const vec2 loadd = vec2(unpackFloat2x16(v.x));
|
||||
|
||||
uint32_t sc;
|
||||
uint32_t mbyte;
|
||||
@@ -199,15 +199,15 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
|
||||
sc &= 0x3F;
|
||||
mbyte &= 0x3F;
|
||||
|
||||
const float16_t d = loadd.x * float16_t(sc);
|
||||
const float16_t m = loadd.y * float16_t(mbyte);
|
||||
const float d = loadd.x * float(sc);
|
||||
const float m = loadd.y * float(mbyte);
|
||||
|
||||
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
|
||||
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
|
||||
|
||||
float16_t ret = d * float16_t(qs) - m;
|
||||
float ret = d * float(qs) - m;
|
||||
|
||||
return ret;
|
||||
return float16_t(ret);
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
|
||||
@@ -311,8 +311,8 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib32 = idx / 32;
|
||||
const uint ib8 = idx / 8;
|
||||
const uint ib32 = (idx & 0xE0) >> 5;
|
||||
const uint ib8 = (idx & 0xF8) >> 3;
|
||||
|
||||
const uint qh = bl.block.qh[ib32];
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
@@ -330,14 +330,20 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1
|
||||
block_iq1_m block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
|
||||
block_iq1_m_packed64 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12;
|
||||
const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12));
|
||||
decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib8 = idx / 8;
|
||||
const uint ib16 = idx / 16;
|
||||
uvec2 scales = unpack32(bl64.block.scales);
|
||||
const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
|
||||
|
||||
const uint ib8 = (idx & 0xF8) >> 3;
|
||||
const uint ib16 = (idx & 0xF0) >> 4;
|
||||
const int i8 = int(idx % 8);
|
||||
const uint sc = bl.block.scales[ib8 / 8];
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
|
||||
sum[tid] += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
sum[tid] += sum[tid + s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
|
||||
}
|
||||
}
|
||||
@@ -105,6 +105,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
int unroll_count = 4;
|
||||
uint unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
#if K_PER_ITER == 2
|
||||
// If the K dimension is odd, we need lastiter==true on the last iteration
|
||||
// so OOB is computed correctly. Skip some unrolling to make that happen.
|
||||
if ((p.ncols & 1) != 0 &&
|
||||
unrolled_iters == num_iters &&
|
||||
unrolled_iters > 0) {
|
||||
unrolled_iters -= unroll_count;
|
||||
}
|
||||
#endif
|
||||
|
||||
uint i = 0;
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
@@ -113,8 +123,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
unroll_count = 2;
|
||||
unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
#if K_PER_ITER == 2
|
||||
if ((p.ncols & 1) != 0 &&
|
||||
unrolled_iters == num_iters &&
|
||||
unrolled_iters > 0) {
|
||||
unrolled_iters -= unroll_count;
|
||||
}
|
||||
#endif
|
||||
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
|
||||
@@ -19,8 +19,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
const float db = d * (0.5 + scale) * 0.25;
|
||||
|
||||
const uint qh = data_a[ibi].qh[ib32];
|
||||
const u8vec2 qs16 = unpack8(data_a_packed16[ibi].qs[itid]);
|
||||
const u8vec2 sign16 = unpack8(data_a_packed16[ibi].qs[QUANT_K / 16 + itid]);
|
||||
const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147
|
||||
const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy;
|
||||
[[unroll]] for (uint l = 0; l < 2; ++l) {
|
||||
const uint8_t sign = sign16[l];
|
||||
const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300);
|
||||
|
||||
@@ -21,7 +21,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
|
||||
sum[j] = 0.0;
|
||||
}
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const u8vec2 qs = unpack8(data_a_packed16[ibi].qs[4 * ib32 + l]);
|
||||
const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147
|
||||
const uint sign = data_a[ibi].signs[4 * ib32 + l];
|
||||
const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)]));
|
||||
const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)]));
|
||||
|
||||
@@ -12,6 +12,9 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
|
||||
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ncols_x;
|
||||
@@ -37,25 +40,66 @@ void main() {
|
||||
|
||||
const uint idst = channel*nrows_dst + row_dst;
|
||||
|
||||
tmp[tid] = 0.0f;
|
||||
FLOAT_TYPE temp = 0.0f;
|
||||
|
||||
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
|
||||
const uint col_x = col_x0 + tid;
|
||||
// Detect alignment for vector loads
|
||||
bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
|
||||
|
||||
if (col_x >= p.ncols_x) {
|
||||
break;
|
||||
for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
|
||||
|
||||
// Unroll 2x and do vec4 loads if aligned
|
||||
const uint unroll_count = 2;
|
||||
if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
|
||||
[[unroll]] for (uint i = 0; i < unroll_count; ++i) {
|
||||
const uint col_x = col_x0 + 4*tid;
|
||||
|
||||
const uint row_y = col_x;
|
||||
|
||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||
const uint iy = channel*nrows_y + row_y;
|
||||
|
||||
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
||||
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
||||
|
||||
temp += dot(av4, bv4);
|
||||
|
||||
col_x0 += 4*BLOCK_SIZE;
|
||||
}
|
||||
// do vec4 loads if aligned
|
||||
} else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
|
||||
const uint col_x = col_x0 + 4*tid;
|
||||
|
||||
const uint row_y = col_x;
|
||||
|
||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||
const uint iy = channel*nrows_y + row_y;
|
||||
|
||||
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
||||
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
||||
|
||||
temp += dot(av4, bv4);
|
||||
|
||||
col_x0 += 4*BLOCK_SIZE;
|
||||
} else {
|
||||
const uint col_x = col_x0 + tid;
|
||||
if (col_x >= p.ncols_x) {
|
||||
break;
|
||||
}
|
||||
|
||||
const uint row_y = col_x;
|
||||
|
||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||
const uint iy = channel*nrows_y + row_y;
|
||||
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
|
||||
|
||||
temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
|
||||
col_x0 += BLOCK_SIZE;
|
||||
}
|
||||
|
||||
const uint row_y = col_x;
|
||||
|
||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||
const uint iy = channel*nrows_y + row_y;
|
||||
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
|
||||
|
||||
tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
|
||||
}
|
||||
|
||||
tmp[tid] = temp;
|
||||
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
|
||||
@@ -2,16 +2,25 @@
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#if USE_SUBGROUP_ADD
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#endif
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
#define FLOAT_TYPE float
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
|
||||
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
|
||||
layout(constant_id = 0) const int BLOCK_SIZE = 32;
|
||||
// gqa_ratio is in the range [1,8]
|
||||
layout(constant_id = 1) const uint gqa_ratio = 1;
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ncols_x;
|
||||
@@ -22,52 +31,124 @@ layout (push_constant) uniform parameter
|
||||
uint d_offset;
|
||||
} p;
|
||||
|
||||
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
||||
#if !USE_SUBGROUP_ADD
|
||||
shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint row_x = gl_GlobalInvocationID.y;
|
||||
const uint channel = gl_GlobalInvocationID.z;
|
||||
const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
|
||||
|
||||
uint channel, channel_x;
|
||||
|
||||
// When gqa_ratio > 1, each invocation does multiple rows.
|
||||
// The row in the A matrix is starting from channel / gqa_ratio and the
|
||||
// rows in the B matrix are [channel, channel+gqa_ratio).
|
||||
// When gpa_ratio is 1, each invocation does one row.
|
||||
if (gqa_ratio > 1) {
|
||||
channel_x = gl_GlobalInvocationID.z;
|
||||
channel = channel_x * gqa_ratio;
|
||||
} else {
|
||||
channel = gl_GlobalInvocationID.z;
|
||||
channel_x = channel / (p.nchannels_y / p.nchannels_x);;
|
||||
}
|
||||
|
||||
const uint nrows_y = p.ncols_x;
|
||||
const uint nrows_dst = p.nrows_x;
|
||||
const uint row_dst = row_x;
|
||||
|
||||
tmp[tid] = FLOAT_TYPE(0.0f);
|
||||
|
||||
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
|
||||
const uint col_x = col_x0 + tid;
|
||||
|
||||
if (col_x >= p.ncols_x) {
|
||||
break;
|
||||
}
|
||||
|
||||
// x is transposed and permuted
|
||||
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
|
||||
|
||||
const uint row_y = col_x;
|
||||
|
||||
// y is not transposed but permuted
|
||||
const uint iy = channel*nrows_y + row_y;
|
||||
|
||||
tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
|
||||
FLOAT_TYPE temp[8];
|
||||
[[unroll]] for (uint i = 0; i < 8; ++i) {
|
||||
temp[i] = FLOAT_TYPE(0.0f);
|
||||
}
|
||||
|
||||
// dst is not transposed and not permuted
|
||||
const uint idst = channel*nrows_dst + row_dst;
|
||||
// Detect alignment for vector loads
|
||||
bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
|
||||
|
||||
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
|
||||
|
||||
// Use vec4 loads if aligned
|
||||
if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
|
||||
|
||||
uint col_x = col_x0 + 4*tid;
|
||||
const uint row_y = col_x;
|
||||
|
||||
// x is transposed and permuted
|
||||
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
|
||||
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
||||
|
||||
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
|
||||
// y is not transposed but permuted
|
||||
const uint iy = (channel + c)*nrows_y + row_y;
|
||||
|
||||
vec4 bv4 = data_b_v4[iy / 4];
|
||||
temp[c] += dot(av4, bv4);
|
||||
}
|
||||
|
||||
col_x0 += 3*BLOCK_SIZE;
|
||||
} else {
|
||||
const uint col_x = col_x0 + tid;
|
||||
|
||||
if (col_x >= p.ncols_x) {
|
||||
break;
|
||||
}
|
||||
|
||||
// x is transposed and permuted
|
||||
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
|
||||
|
||||
const uint row_y = col_x;
|
||||
|
||||
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
|
||||
// y is not transposed but permuted
|
||||
const uint iy = (channel + c)*nrows_y + row_y;
|
||||
|
||||
temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if USE_SUBGROUP_ADD
|
||||
// reduce vec4 at a time
|
||||
vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
|
||||
t = subgroupAdd(t);
|
||||
temp[0] = t[0];
|
||||
temp[1] = t[1];
|
||||
temp[2] = t[2];
|
||||
temp[3] = t[3];
|
||||
if (gqa_ratio > 4) {
|
||||
t = vec4(temp[4], temp[5], temp[6], temp[7]);
|
||||
t = subgroupAdd(t);
|
||||
temp[4] = t[0];
|
||||
temp[5] = t[1];
|
||||
temp[6] = t[2];
|
||||
temp[7] = t[3];
|
||||
}
|
||||
#else
|
||||
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
|
||||
tmp[c][tid] = temp[c];
|
||||
}
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
tmp[tid] += tmp[tid + s];
|
||||
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
|
||||
temp[c] += tmp[c][tid + s];
|
||||
tmp[c][tid] = temp[c];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
|
||||
temp[c] = tmp[c][tid];
|
||||
}
|
||||
#endif
|
||||
|
||||
if (tid == 0) {
|
||||
dst[idst] = tmp[0];
|
||||
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
|
||||
// dst is not transposed and not permuted
|
||||
const uint idst = (channel + c)*nrows_dst + row_dst;
|
||||
dst[idst] = temp[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -336,8 +336,8 @@ void main() {
|
||||
const uint iqs = idx & 0x07;
|
||||
|
||||
const float d = float(data_a_packed16[ib].d);
|
||||
const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
|
||||
const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
|
||||
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
|
||||
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
@@ -544,7 +544,7 @@ void main() {
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
@@ -564,7 +564,7 @@ void main() {
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
@@ -586,7 +586,7 @@ void main() {
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
@@ -611,7 +611,7 @@ void main() {
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
@@ -631,7 +631,7 @@ void main() {
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
||||
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
|
||||
@@ -23,6 +23,10 @@ layout (constant_id = 1) const uint BM = 64;
|
||||
layout (constant_id = 2) const uint BN = 64;
|
||||
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
|
||||
|
||||
layout (constant_id = 4) const bool enable_smaller_matrices = false;
|
||||
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
|
||||
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint M;
|
||||
@@ -168,15 +172,13 @@ void main() {
|
||||
const uint end_k = min(p.K, (ik + 1) * p.k_split);
|
||||
#endif
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
|
||||
uint pos_b = 0;
|
||||
#else
|
||||
uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
|
||||
uint pos_b = batch_idx * p.batch_stride_b;
|
||||
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||
#endif
|
||||
|
||||
uint stride_a = p.stride_a / QUANT_K;
|
||||
@@ -197,6 +199,7 @@ void main() {
|
||||
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
|
||||
|
||||
#if QUANT_K > 1
|
||||
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
|
||||
@@ -232,16 +235,54 @@ void main() {
|
||||
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
|
||||
|
||||
uint k_iters = (end_k - start_k + BK - 1) / BK;
|
||||
if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
|
||||
return;
|
||||
} else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
|
||||
return;
|
||||
} else {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
|
||||
return;
|
||||
}
|
||||
} else
|
||||
#endif // !defined(MUL_MAT_ID)
|
||||
@@ -254,6 +295,9 @@ void main() {
|
||||
|
||||
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
|
||||
|
||||
@@ -296,19 +340,16 @@ void main() {
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
|
||||
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
|
||||
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
// Call callback to store each element, remapping row through shared memory
|
||||
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
|
||||
// Call callback to store each element, remapping row through shared memory
|
||||
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
|
||||
#else
|
||||
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
|
||||
|
||||
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
|
||||
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#if !defined(GGML_TYPES_COMP)
|
||||
#define GGML_TYPES_COMP
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||
@@ -312,6 +313,12 @@ struct block_iq1_m {
|
||||
uint16_t scales[QUANT_K_IQ1_M/64];
|
||||
};
|
||||
|
||||
struct block_iq1_m_packed64 {
|
||||
uint64_t qs[QUANT_K_IQ1_M/8/8];
|
||||
uint64_t qh[QUANT_K_IQ1_M/16/8];
|
||||
uint64_t scales;
|
||||
};
|
||||
|
||||
#if defined(DATA_A_IQ1_S)
|
||||
#define QUANT_K QUANT_K_IQ1_S
|
||||
#define QUANT_R QUANT_R_IQ1_S
|
||||
|
||||
@@ -426,14 +426,16 @@ void process_shaders() {
|
||||
}
|
||||
}
|
||||
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
|
||||
|
||||
// Norms
|
||||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||
@@ -444,6 +446,7 @@ void process_shaders() {
|
||||
|
||||
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
}
|
||||
|
||||
@@ -528,6 +531,8 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
for (auto &c : compiles) {
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#define BLOCK_SIZE 64
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
uint B;
|
||||
uint T;
|
||||
uint C;
|
||||
uint H;
|
||||
};
|
||||
|
||||
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
|
||||
layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
|
||||
layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
|
||||
layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
|
||||
layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
|
||||
layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
|
||||
layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
|
||||
layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
|
||||
|
||||
shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint head_size = BLOCK_SIZE;
|
||||
const uint batch_id = gl_WorkGroupID.x / H;
|
||||
const uint head_id = gl_WorkGroupID.x % H;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
A_TYPE state[BLOCK_SIZE];
|
||||
[[unroll]] for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i];
|
||||
}
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
barrier();
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
barrier();
|
||||
|
||||
A_TYPE sa = 0.0;
|
||||
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
|
||||
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
sa += dot(s_vec, a_vec);
|
||||
}
|
||||
|
||||
const A_TYPE v_val = v[t];
|
||||
A_TYPE y = 0.0;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
|
||||
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
vec4 kv = k_vec * v_val;
|
||||
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
||||
y += dot(r_vec, s_vec);
|
||||
|
||||
state[j] = s_vec.x;
|
||||
state[j+1] = s_vec.y;
|
||||
state[j+2] = s_vec.z;
|
||||
state[j+3] = s_vec.w;
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
+85
-2
@@ -929,6 +929,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"RMS_NORM",
|
||||
"RMS_NORM_BACK",
|
||||
"GROUP_NORM",
|
||||
"L2_NORM",
|
||||
|
||||
"MUL_MAT",
|
||||
"MUL_MAT_ID",
|
||||
@@ -977,6 +978,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"ADD_REL_POS",
|
||||
"RWKV_WKV6",
|
||||
"GATED_LINEAR_ATTN",
|
||||
"RWKV_WKV7",
|
||||
|
||||
"UNARY",
|
||||
|
||||
@@ -996,7 +998,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1026,6 +1028,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"rms_norm(x)",
|
||||
"rms_norm_back(x)",
|
||||
"group_norm(x)",
|
||||
"l2_norm(x)",
|
||||
|
||||
"X*Y",
|
||||
"X[i]*Y",
|
||||
@@ -1074,6 +1077,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"add_rel_pos(x)",
|
||||
"rwkv_wkv6(k, v, r, tf, td, s)",
|
||||
"gated_linear_attn(k, v, q, gate, s)",
|
||||
"rwkv_wkv7(r, w, k, v, a, b, s)",
|
||||
|
||||
"unary(x)",
|
||||
|
||||
@@ -1093,7 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -2686,6 +2690,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
|
||||
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
|
||||
}
|
||||
|
||||
// ggml_l2_norm
|
||||
|
||||
static struct ggml_tensor * ggml_l2_norm_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps,
|
||||
bool inplace) {
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_f32(result, 0, eps);
|
||||
|
||||
result->op = GGML_OP_L2_NORM;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_l2_norm(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps) {
|
||||
return ggml_l2_norm_impl(ctx, a, eps, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_l2_norm_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps) {
|
||||
return ggml_l2_norm_impl(ctx, a, eps, true);
|
||||
}
|
||||
|
||||
// ggml_mul_mat
|
||||
|
||||
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||
@@ -4720,6 +4755,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_rwkv_wkv7
|
||||
|
||||
struct ggml_tensor * ggml_rwkv_wkv7(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * r,
|
||||
struct ggml_tensor * w,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * state) {
|
||||
GGML_ASSERT(ggml_is_contiguous(r));
|
||||
GGML_ASSERT(ggml_is_contiguous(w));
|
||||
GGML_ASSERT(ggml_is_contiguous(k));
|
||||
GGML_ASSERT(ggml_is_contiguous(v));
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
GGML_ASSERT(ggml_is_contiguous(b));
|
||||
GGML_ASSERT(ggml_is_contiguous(state));
|
||||
|
||||
const int64_t S = k->ne[0];
|
||||
const int64_t H = k->ne[1];
|
||||
const int64_t n_tokens = k->ne[2];
|
||||
const int64_t n_seqs = state->ne[1];
|
||||
{
|
||||
GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
|
||||
GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
|
||||
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
|
||||
GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
|
||||
GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
|
||||
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
|
||||
}
|
||||
|
||||
// concat output and new_state
|
||||
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
result->op = GGML_OP_RWKV_WKV7;
|
||||
result->src[0] = r;
|
||||
result->src[1] = w;
|
||||
result->src[2] = k;
|
||||
result->src[3] = v;
|
||||
result->src[4] = a;
|
||||
result->src[5] = b;
|
||||
result->src[6] = state;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_unary
|
||||
|
||||
static struct ggml_tensor * ggml_unary_impl(
|
||||
|
||||
+111
-16
@@ -118,22 +118,26 @@ class Keys:
|
||||
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
|
||||
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
|
||||
CLAMP_KQV = "{arch}.attention.clamp_kqv"
|
||||
KEY_LENGTH = "{arch}.attention.key_length"
|
||||
VALUE_LENGTH = "{arch}.attention.value_length"
|
||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
|
||||
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||
SCALE = "{arch}.attention.scale"
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
|
||||
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
|
||||
CLAMP_KQV = "{arch}.attention.clamp_kqv"
|
||||
KEY_LENGTH = "{arch}.attention.key_length"
|
||||
VALUE_LENGTH = "{arch}.attention.value_length"
|
||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
|
||||
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank"
|
||||
ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank"
|
||||
VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
|
||||
GATE_LORA_RANK = "{arch}.attention.gate_lora_rank"
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||
SCALE = "{arch}.attention.scale"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
@@ -257,6 +261,8 @@ class MODEL_ARCH(IntEnum):
|
||||
STARCODER2 = auto()
|
||||
RWKV6 = auto()
|
||||
RWKV6QWEN2 = auto()
|
||||
RWKV7 = auto()
|
||||
ARWKV7 = auto()
|
||||
MAMBA = auto()
|
||||
XVERSE = auto()
|
||||
COMMAND_R = auto()
|
||||
@@ -329,8 +335,20 @@ class MODEL_TENSOR(IntEnum):
|
||||
SSM_A = auto()
|
||||
SSM_D = auto()
|
||||
SSM_OUT = auto()
|
||||
TIME_MIX_W0 = auto()
|
||||
TIME_MIX_W1 = auto()
|
||||
TIME_MIX_W2 = auto()
|
||||
TIME_MIX_A0 = auto()
|
||||
TIME_MIX_A1 = auto()
|
||||
TIME_MIX_A2 = auto()
|
||||
TIME_MIX_V0 = auto()
|
||||
TIME_MIX_V1 = auto()
|
||||
TIME_MIX_V2 = auto()
|
||||
TIME_MIX_G1 = auto()
|
||||
TIME_MIX_G2 = auto()
|
||||
TIME_MIX_K_K = auto()
|
||||
TIME_MIX_K_A = auto()
|
||||
TIME_MIX_R_K = auto()
|
||||
TIME_MIX_LERP_X = auto()
|
||||
TIME_MIX_LERP_K = auto()
|
||||
TIME_MIX_LERP_V = auto()
|
||||
@@ -445,6 +463,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.RWKV6: "rwkv6",
|
||||
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
|
||||
MODEL_ARCH.RWKV7: "rwkv7",
|
||||
MODEL_ARCH.ARWKV7: "arwkv7",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
MODEL_ARCH.COMMAND_R: "command-r",
|
||||
@@ -517,8 +537,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
|
||||
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
|
||||
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
|
||||
MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0",
|
||||
MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1",
|
||||
MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2",
|
||||
MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0",
|
||||
MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1",
|
||||
MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2",
|
||||
MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1",
|
||||
MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2",
|
||||
MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k",
|
||||
MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a",
|
||||
MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k",
|
||||
MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
|
||||
MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
|
||||
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
|
||||
@@ -1081,6 +1113,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
],
|
||||
MODEL_ARCH.GEMMA3: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
@@ -1172,6 +1205,68 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.RWKV7: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM_2,
|
||||
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
|
||||
MODEL_TENSOR.TIME_MIX_W0,
|
||||
MODEL_TENSOR.TIME_MIX_W1,
|
||||
MODEL_TENSOR.TIME_MIX_W2,
|
||||
MODEL_TENSOR.TIME_MIX_A0,
|
||||
MODEL_TENSOR.TIME_MIX_A1,
|
||||
MODEL_TENSOR.TIME_MIX_A2,
|
||||
MODEL_TENSOR.TIME_MIX_V0,
|
||||
MODEL_TENSOR.TIME_MIX_V1,
|
||||
MODEL_TENSOR.TIME_MIX_V2,
|
||||
MODEL_TENSOR.TIME_MIX_G1,
|
||||
MODEL_TENSOR.TIME_MIX_G2,
|
||||
MODEL_TENSOR.TIME_MIX_K_K,
|
||||
MODEL_TENSOR.TIME_MIX_K_A,
|
||||
MODEL_TENSOR.TIME_MIX_R_K,
|
||||
MODEL_TENSOR.TIME_MIX_KEY,
|
||||
MODEL_TENSOR.TIME_MIX_VALUE,
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
|
||||
MODEL_TENSOR.TIME_MIX_LN,
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT,
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_K,
|
||||
MODEL_TENSOR.CHANNEL_MIX_KEY,
|
||||
MODEL_TENSOR.CHANNEL_MIX_VALUE,
|
||||
],
|
||||
MODEL_ARCH.ARWKV7: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
|
||||
MODEL_TENSOR.TIME_MIX_W0,
|
||||
MODEL_TENSOR.TIME_MIX_W1,
|
||||
MODEL_TENSOR.TIME_MIX_W2,
|
||||
MODEL_TENSOR.TIME_MIX_A0,
|
||||
MODEL_TENSOR.TIME_MIX_A1,
|
||||
MODEL_TENSOR.TIME_MIX_A2,
|
||||
MODEL_TENSOR.TIME_MIX_V0,
|
||||
MODEL_TENSOR.TIME_MIX_V1,
|
||||
MODEL_TENSOR.TIME_MIX_V2,
|
||||
MODEL_TENSOR.TIME_MIX_G1,
|
||||
MODEL_TENSOR.TIME_MIX_G2,
|
||||
MODEL_TENSOR.TIME_MIX_K_K,
|
||||
MODEL_TENSOR.TIME_MIX_K_A,
|
||||
MODEL_TENSOR.TIME_MIX_R_K,
|
||||
MODEL_TENSOR.TIME_MIX_KEY,
|
||||
MODEL_TENSOR.TIME_MIX_VALUE,
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
|
||||
MODEL_TENSOR.TIME_MIX_LN,
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.MAMBA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
||||
@@ -767,6 +767,18 @@ class GGUFWriter:
|
||||
def add_kv_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_decay_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_iclr_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_value_residual_mix_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_gate_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_relative_attn_buckets_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
|
||||
|
||||
|
||||
+100
-31
@@ -27,7 +27,8 @@ class TensorNameMap:
|
||||
"embedding.word_embeddings", # chatglm
|
||||
"transformer.token_embeddings", # openelm
|
||||
"shared", # t5
|
||||
"rwkv.embeddings", # rwkv
|
||||
"rwkv.embeddings", # rwkv6
|
||||
"model.embeddings", # rwkv7
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@@ -42,6 +43,9 @@ class TensorNameMap:
|
||||
"emb_ln", # nomic-bert
|
||||
"transformer.norm", # openelm
|
||||
"rwkv.blocks.0.pre_ln", # rwkv
|
||||
"rwkv.blocks.0.pre_ln", # rwkv6
|
||||
"model.pre_ln", # rwkv7
|
||||
"model.layers.0.pre_norm", # rwkv7
|
||||
"backbone.norm", # wavtokenizer
|
||||
),
|
||||
|
||||
@@ -81,7 +85,8 @@ class TensorNameMap:
|
||||
"encoder.final_layernorm", # chatglm
|
||||
"transformer.norm", # openelm
|
||||
"model.norm", # nemotron
|
||||
"rwkv.ln_out", # rwkv
|
||||
"rwkv.ln_out", # rwkv6
|
||||
"model.ln_out", # rwkv7
|
||||
"backbone.final_layer_norm", # wavtokenizer
|
||||
),
|
||||
|
||||
@@ -122,14 +127,16 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
|
||||
"encoder.layers.{bid}.input_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv6
|
||||
"model.layers.{bid}.ln1", # rwkv7
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
MODEL_TENSOR.ATTN_NORM_2: (
|
||||
"transformer.h.{bid}.ln_attn", # falcon40b
|
||||
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
|
||||
"rwkv.blocks.{bid}.ln2", # rwkv
|
||||
"rwkv.blocks.{bid}.ln2", # rwkv6
|
||||
"model.layers.{bid}.ln2", # rwkv7
|
||||
),
|
||||
|
||||
# Attention query-key-value
|
||||
@@ -462,112 +469,174 @@ class TensorNameMap:
|
||||
"backbone.layers.{bid}.mixer.out_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W0: (
|
||||
"model.layers.{bid}.attention.w0", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W1: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.w1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W2: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.w2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_A0: (
|
||||
"model.layers.{bid}.attention.a0", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_A1: (
|
||||
"model.layers.{bid}.attention.a1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_A2: (
|
||||
"model.layers.{bid}.attention.a2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_V0: (
|
||||
"model.layers.{bid}.attention.v0", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_V1: (
|
||||
"model.layers.{bid}.attention.v1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_V2: (
|
||||
"model.layers.{bid}.attention.v2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_G1: (
|
||||
"model.layers.{bid}.attention.g1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_G2: (
|
||||
"model.layers.{bid}.attention.g2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_K_K: (
|
||||
"model.layers.{bid}.attention.k_k", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_K_A: (
|
||||
"model.layers.{bid}.attention.k_a", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_R_K: (
|
||||
"model.layers.{bid}.attention.r_k", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_X: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_K: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_V: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_R: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_G: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_W: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_FIRST: (
|
||||
"rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_decay", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY_W1: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY_W2: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_KEY: (
|
||||
"rwkv.blocks.{bid}.attention.key", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.key", # rwkv6
|
||||
"model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.key", # rwkv7
|
||||
"model.layers.{bid}.attention.k_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_VALUE: (
|
||||
"rwkv.blocks.{bid}.attention.value", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.value", # rwkv6
|
||||
"model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.value", # rwkv7
|
||||
"model.layers.{bid}.attention.v_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
|
||||
"rwkv.blocks.{bid}.attention.receptance", # rwkv
|
||||
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.receptance", # rwkv6
|
||||
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.receptance", # rwkv7
|
||||
"model.layers.{bid}.attention.r_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_GATE: (
|
||||
"rwkv.blocks.{bid}.attention.gate", # rwkv
|
||||
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.gate", # rwkv6
|
||||
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LN: (
|
||||
"rwkv.blocks.{bid}.attention.ln_x", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.ln_x", # rwkv6
|
||||
"model.layers.{bid}.attention.ln_x" # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT: (
|
||||
"rwkv.blocks.{bid}.attention.output", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.output", # rwkv6
|
||||
"model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.output", # rwkv7
|
||||
"model.layers.{bid}.attention.o_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6
|
||||
"model.layers.{bid}.feed_forward.x_k", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_KEY: (
|
||||
"rwkv.blocks.{bid}.feed_forward.key", # rwkv
|
||||
"rwkv.blocks.{bid}.feed_forward.key", # rwkv6
|
||||
"model.layers.{bid}.feed_forward.key", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
|
||||
"rwkv.blocks.{bid}.feed_forward.receptance", # rwkv
|
||||
"rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_VALUE: (
|
||||
"rwkv.blocks.{bid}.feed_forward.value", # rwkv
|
||||
"rwkv.blocks.{bid}.feed_forward.value", # rwkv6
|
||||
"model.layers.{bid}.feed_forward.value", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A: (
|
||||
|
||||
@@ -154,7 +154,12 @@ class SpecialVocab:
|
||||
return True
|
||||
with open(tokenizer_config_file, encoding = 'utf-8') as f:
|
||||
tokenizer_config = json.load(f)
|
||||
chat_template = tokenizer_config.get('chat_template')
|
||||
chat_template_alt = None
|
||||
chat_template_file = path / 'chat_template.json'
|
||||
if chat_template_file.is_file():
|
||||
with open(chat_template_file, encoding = 'utf-8') as f:
|
||||
chat_template_alt = json.load(f).get('chat_template')
|
||||
chat_template = tokenizer_config.get('chat_template', chat_template_alt)
|
||||
if chat_template is None or isinstance(chat_template, (str, list)):
|
||||
self.chat_template = chat_template
|
||||
else:
|
||||
|
||||
@@ -107,6 +107,7 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
||||
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
|
||||
+103
-16
@@ -59,6 +59,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
|
||||
{ LLM_ARCH_RWKV7, "rwkv7" },
|
||||
{ LLM_ARCH_ARWKV7, "arwkv7" },
|
||||
{ LLM_ARCH_GRANITE, "granite" },
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
@@ -110,22 +112,26 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
|
||||
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
|
||||
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
|
||||
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
@@ -772,6 +778,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
@@ -1238,6 +1245,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_RWKV7,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
|
||||
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
|
||||
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
|
||||
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
|
||||
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
|
||||
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
|
||||
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
|
||||
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
|
||||
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
|
||||
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
|
||||
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
||||
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
||||
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
||||
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
|
||||
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ARWKV7,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
|
||||
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
|
||||
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
|
||||
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
|
||||
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
|
||||
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
|
||||
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
|
||||
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
|
||||
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
|
||||
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
||||
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
||||
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
||||
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
|
||||
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GRANITE,
|
||||
{
|
||||
@@ -1397,6 +1472,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
@@ -1415,6 +1496,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
@@ -1422,6 +1506,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
|
||||
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
||||
@@ -63,6 +63,8 @@ enum llm_arch {
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_RWKV6,
|
||||
LLM_ARCH_RWKV6QWEN2,
|
||||
LLM_ARCH_RWKV7,
|
||||
LLM_ARCH_ARWKV7,
|
||||
LLM_ARCH_GRANITE,
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_CHAMELEON,
|
||||
@@ -127,6 +129,10 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_CAUSAL,
|
||||
LLM_KV_ATTENTION_Q_LORA_RANK,
|
||||
LLM_KV_ATTENTION_KV_LORA_RANK,
|
||||
LLM_KV_ATTENTION_DECAY_LORA_RANK,
|
||||
LLM_KV_ATTENTION_ICLR_LORA_RANK,
|
||||
LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
|
||||
LLM_KV_ATTENTION_GATE_LORA_RANK,
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
@@ -250,8 +256,20 @@ enum llm_tensor {
|
||||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_TIME_MIX_W0,
|
||||
LLM_TENSOR_TIME_MIX_W1,
|
||||
LLM_TENSOR_TIME_MIX_W2,
|
||||
LLM_TENSOR_TIME_MIX_A0,
|
||||
LLM_TENSOR_TIME_MIX_A1,
|
||||
LLM_TENSOR_TIME_MIX_A2,
|
||||
LLM_TENSOR_TIME_MIX_V0,
|
||||
LLM_TENSOR_TIME_MIX_V1,
|
||||
LLM_TENSOR_TIME_MIX_V2,
|
||||
LLM_TENSOR_TIME_MIX_G1,
|
||||
LLM_TENSOR_TIME_MIX_G2,
|
||||
LLM_TENSOR_TIME_MIX_K_K,
|
||||
LLM_TENSOR_TIME_MIX_K_A,
|
||||
LLM_TENSOR_TIME_MIX_R_K,
|
||||
LLM_TENSOR_TIME_MIX_LERP_X,
|
||||
LLM_TENSOR_TIME_MIX_LERP_W,
|
||||
LLM_TENSOR_TIME_MIX_LERP_K,
|
||||
|
||||
+33
-4
@@ -294,10 +294,7 @@ llama_context::llama_context(
|
||||
// TODO: something cleaner
|
||||
const auto n_outputs_save = n_outputs;
|
||||
|
||||
// max number of outputs
|
||||
n_outputs = n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
|
||||
int n_splits_pp = -1;
|
||||
int n_nodes_pp = -1;
|
||||
@@ -313,8 +310,15 @@ llama_context::llama_context(
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
{
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
// max number of outputs
|
||||
n_outputs = ubatch_pp.n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
@@ -326,11 +330,18 @@ llama_context::llama_context(
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
{
|
||||
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
n_outputs = ubatch_tg.n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
|
||||
n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
||||
n_nodes_tg = ggml_graph_n_nodes(gf);
|
||||
}
|
||||
@@ -338,8 +349,14 @@ llama_context::llama_context(
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
{
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
n_outputs = ubatch_pp.n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
@@ -1057,6 +1074,13 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
const auto causal_attn_org = cparams.causal_attn;
|
||||
|
||||
// always use non-causal attention for encoder graphs
|
||||
// TODO: this is a tmp solution until we have a proper way to support enc-dec models
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
||||
cparams.causal_attn = false;
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
|
||||
|
||||
@@ -1064,6 +1088,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
|
||||
cparams.causal_attn = causal_attn_org;
|
||||
|
||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_SUCCESS:
|
||||
@@ -1134,6 +1160,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||
if (model.arch == LLM_ARCH_T5 && t_embd) {
|
||||
//cross.t_embd = t_embd;
|
||||
|
||||
synchronize();
|
||||
|
||||
cross.n_embd = t_embd->ne[0];
|
||||
cross.n_enc = t_embd->ne[1];
|
||||
cross.v_embd.resize(cross.n_embd*cross.n_enc);
|
||||
@@ -1142,6 +1170,7 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||
// remember the sequence ids used during the encoding - needed for cross attention later
|
||||
cross.seq_ids_enc.resize(n_tokens);
|
||||
for (int32_t i = 0; i < n_tokens; i++) {
|
||||
cross.seq_ids_enc[i].clear();
|
||||
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
|
||||
llama_seq_id seq_id = ubatch.seq_id[i][s];
|
||||
cross.seq_ids_enc[i].insert(seq_id);
|
||||
|
||||
+1
-1
@@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
// note: storing RoPE-ed version of K in the KV cache
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
|
||||
|
||||
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
|
||||
v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
|
||||
|
||||
ggml_tensor * v_cache_view = nullptr;
|
||||
|
||||
|
||||
+12
-12
@@ -487,9 +487,9 @@ struct llm_graph_context {
|
||||
|
||||
ggml_tensor * build_attn_mha(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
|
||||
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
|
||||
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
bool v_trans,
|
||||
@@ -502,9 +502,9 @@ struct llm_graph_context {
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
@@ -516,9 +516,9 @@ struct llm_graph_context {
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
@@ -530,9 +530,9 @@ struct llm_graph_context {
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
@@ -76,6 +76,10 @@ struct llama_hparams {
|
||||
uint32_t time_decay_extra_dim = 0;
|
||||
uint32_t wkv_head_size = 0;
|
||||
uint32_t token_shift_count = 2;
|
||||
uint32_t n_lora_decay = 0;
|
||||
uint32_t n_lora_iclr = 0;
|
||||
uint32_t n_lora_value_res_mix = 0;
|
||||
uint32_t n_lora_gate = 0;
|
||||
|
||||
float rope_attn_factor = 1.0f;
|
||||
float rope_freq_base_train;
|
||||
|
||||
+1
-1
@@ -476,7 +476,7 @@ struct llama_mlock::impl {
|
||||
|
||||
char* errmsg = std::strerror(errno);
|
||||
bool suggest = (errno == ENOMEM);
|
||||
#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV)
|
||||
#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
|
||||
// visionOS/tvOS dont't support RLIMIT_MEMLOCK
|
||||
// Skip resource limit checks on visionOS/tvOS
|
||||
suggest = false;
|
||||
|
||||
+1068
-318
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user