forked from wylab/llama.cpp
Compare commits
47 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7a3c178d78 | |||
| dc4bb64290 | |||
| 8551c44d84 | |||
| 35cae5ba05 | |||
| 810e0af3f5 | |||
| eba92d64c3 | |||
| d9a14523bb | |||
| fd123cfead | |||
| a53f7f7b88 | |||
| 7dfad387e3 | |||
| 60c902926c | |||
| b1b132efcb | |||
| 01e8f2138b | |||
| eab5606d7b | |||
| de788e071b | |||
| 624a683c6f | |||
| 116b9a1662 | |||
| eaffba0f2e | |||
| 8e7714fa77 | |||
| a363251fac | |||
| ba79369615 | |||
| 07d84fa3c2 | |||
| 32940369d3 | |||
| 5e6a6d4e1c | |||
| bfdddbc150 | |||
| 54566ad95d | |||
| 04f8641815 | |||
| c3dd79007b | |||
| 65f0184517 | |||
| 9fb2d81eab | |||
| 47086fa82d | |||
| 4aabf4e8f4 | |||
| 86973cb14a | |||
| 17f954c8e2 | |||
| 46596caf6d | |||
| 1d6ba97789 | |||
| 1170135dfb | |||
| 40989f4116 | |||
| 9e75c49d35 | |||
| f0ffd81130 | |||
| a1b1dea33b | |||
| 4bf7ca3943 | |||
| aed4a8e980 | |||
| 85ef80cbe9 | |||
| 17d3658b5f | |||
| f2e59a8eb9 | |||
| 4ed4fe75ed |
+6
-37
@@ -582,41 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
|
||||
std::stringstream buf;
|
||||
|
||||
buf << "[ ";
|
||||
|
||||
bool first = true;
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
if (!first) {
|
||||
buf << ", ";
|
||||
} else {
|
||||
first = false;
|
||||
}
|
||||
|
||||
auto detokenized = common_token_to_piece(ctx, batch.token[i]);
|
||||
|
||||
detokenized.erase(
|
||||
std::remove_if(
|
||||
detokenized.begin(),
|
||||
detokenized.end(),
|
||||
[](const unsigned char c) { return !std::isprint(c); }),
|
||||
detokenized.end());
|
||||
|
||||
buf << "\n" << std::to_string(i)
|
||||
<< ", token '" << detokenized << "'"
|
||||
<< ", pos " << std::to_string(batch.pos[i])
|
||||
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
|
||||
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
|
||||
<< ", logits " << std::to_string(batch.logits[i]);
|
||||
}
|
||||
|
||||
buf << " ]";
|
||||
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
void string_process_escapes(std::string & input) {
|
||||
std::size_t input_len = input.length();
|
||||
std::size_t output_idx = 0;
|
||||
@@ -1051,7 +1016,8 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
}
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true));
|
||||
llama_encode_ext(lctx, batch.get());
|
||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
|
||||
decoder_start_token_id = bos;
|
||||
@@ -1060,7 +1026,8 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
tmp.push_back(decoder_start_token_id);
|
||||
}
|
||||
if (llama_model_has_decoder(model)) {
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
|
||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true));
|
||||
llama_decode_ext(lctx, batch.get());
|
||||
}
|
||||
llama_kv_self_clear(lctx);
|
||||
llama_synchronize(lctx);
|
||||
@@ -1613,10 +1580,12 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
|
||||
// Batch utils
|
||||
//
|
||||
|
||||
// DEPRECATED
|
||||
void common_batch_clear(struct llama_batch & batch) {
|
||||
batch.n_tokens = 0;
|
||||
}
|
||||
|
||||
// DEPRECATED
|
||||
void common_batch_add(
|
||||
struct llama_batch & batch,
|
||||
llama_token id,
|
||||
|
||||
+62
-1
@@ -516,7 +516,6 @@ void string_process_escapes(std::string & input);
|
||||
std::string string_from(bool value);
|
||||
std::string string_from(const std::vector<int> & values);
|
||||
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
|
||||
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
|
||||
|
||||
//
|
||||
// Filesystem utils
|
||||
@@ -570,8 +569,10 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
|
||||
// Batch utils
|
||||
//
|
||||
|
||||
// DEPRECATED
|
||||
void common_batch_clear(struct llama_batch & batch);
|
||||
|
||||
// DEPRECATED
|
||||
void common_batch_add(
|
||||
struct llama_batch & batch,
|
||||
llama_token id,
|
||||
@@ -579,6 +580,66 @@ void common_batch_add(
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits);
|
||||
|
||||
// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
|
||||
// this is meant to be temporary
|
||||
struct common_batch {
|
||||
llama_batch_ext_ptr batch;
|
||||
struct batch_token {
|
||||
llama_token token;
|
||||
llama_seq_id seq_id; // only support single seq for now
|
||||
bool logits;
|
||||
};
|
||||
std::vector<batch_token> tokens;
|
||||
int n_outputs = 0;
|
||||
common_batch() = default;
|
||||
common_batch(int32_t n_tokens, int32_t n_seq_max) {
|
||||
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
|
||||
tokens.reserve(n_tokens);
|
||||
}
|
||||
void clear() {
|
||||
llama_batch_ext_clear(batch.get());
|
||||
tokens.clear();
|
||||
}
|
||||
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
|
||||
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
|
||||
tokens.push_back({token, seq_id, logits});
|
||||
if (logits) {
|
||||
n_outputs++;
|
||||
}
|
||||
}
|
||||
void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
|
||||
llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
|
||||
tokens.push_back({token, seq_ids[0], logits});
|
||||
if (logits) {
|
||||
n_outputs++;
|
||||
}
|
||||
}
|
||||
void set_logits_last() {
|
||||
if (!tokens.empty()) {
|
||||
llama_batch_ext_set_output_last(batch.get());
|
||||
tokens.back().logits = true;
|
||||
}
|
||||
}
|
||||
int32_t get_n_tokens() const {
|
||||
return (int32_t)tokens.size();
|
||||
}
|
||||
llama_batch_ext * get() {
|
||||
return batch.get();
|
||||
}
|
||||
common_batch get_view(int32_t offset, int32_t n_tokens) {
|
||||
common_batch view;
|
||||
view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
|
||||
view.tokens.reserve(n_tokens);
|
||||
for (int32_t i = 0; i < n_tokens; i++) {
|
||||
view.tokens.push_back(tokens[offset + i]);
|
||||
if (tokens[offset + i].logits) {
|
||||
view.n_outputs++;
|
||||
}
|
||||
}
|
||||
return view;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Token utils
|
||||
//
|
||||
|
||||
+14
-14
@@ -14,7 +14,7 @@ struct common_speculative {
|
||||
struct llama_context * ctx;
|
||||
struct common_sampler * smpl;
|
||||
|
||||
llama_batch batch;
|
||||
llama_batch_ext_ptr batch;
|
||||
llama_tokens prompt;
|
||||
};
|
||||
|
||||
@@ -23,7 +23,7 @@ struct common_speculative * common_speculative_init(
|
||||
auto * result = new common_speculative {
|
||||
/* .ctx = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .batch = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
|
||||
/* .prompt = */ {},
|
||||
};
|
||||
|
||||
@@ -69,8 +69,6 @@ void common_speculative_free(struct common_speculative * spec) {
|
||||
|
||||
common_sampler_free(spec->smpl);
|
||||
|
||||
llama_batch_free(spec->batch);
|
||||
|
||||
delete spec;
|
||||
}
|
||||
|
||||
@@ -151,6 +149,8 @@ llama_tokens common_speculative_gen_draft(
|
||||
|
||||
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
|
||||
|
||||
const llama_seq_id seq_id = 0;
|
||||
|
||||
// reuse as much as possible from the old draft context
|
||||
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
|
||||
for (int i = 0; i < (int) prompt.size(); ++i) {
|
||||
@@ -206,40 +206,40 @@ llama_tokens common_speculative_gen_draft(
|
||||
}
|
||||
|
||||
// prepare a batch to evaluate any new tokens in the prompt
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch.get());
|
||||
|
||||
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
|
||||
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
|
||||
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
|
||||
llama_batch_ext_add_text(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
|
||||
|
||||
prompt.push_back(prompt_tgt[i]);
|
||||
}
|
||||
|
||||
// we should rarely end-up here during normal decoding
|
||||
if (batch.n_tokens > 0) {
|
||||
if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
|
||||
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
||||
|
||||
llama_decode(ctx, batch);
|
||||
llama_decode_ext(ctx, batch.get());
|
||||
}
|
||||
|
||||
const llama_pos n_past = prompt.size();
|
||||
|
||||
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add (batch, id_last, n_past, { 0 }, true);
|
||||
llama_batch_ext_clear(batch.get());
|
||||
llama_batch_ext_add_text(batch.get(), id_last, n_past, &seq_id, 1, true);
|
||||
|
||||
prompt.push_back(id_last);
|
||||
|
||||
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
|
||||
|
||||
llama_decode(ctx, batch);
|
||||
llama_decode_ext(ctx, batch.get());
|
||||
|
||||
common_sampler_reset(smpl);
|
||||
|
||||
// sample n_draft tokens from the draft model
|
||||
for (int i = 0; i < params.n_draft; ++i) {
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch.get());
|
||||
|
||||
common_sampler_sample(smpl, ctx, 0, true);
|
||||
|
||||
@@ -266,10 +266,10 @@ llama_tokens common_speculative_gen_draft(
|
||||
break;
|
||||
}
|
||||
|
||||
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
|
||||
llama_batch_ext_add_text(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
llama_decode(ctx, batch);
|
||||
llama_decode_ext(ctx, batch.get());
|
||||
|
||||
prompt.push_back(id);
|
||||
}
|
||||
|
||||
+197
-32
@@ -908,6 +908,40 @@ class Model:
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_rwkv_world(self):
|
||||
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
|
||||
vocab_size = self.hparams.get("vocab_size", 65536)
|
||||
|
||||
tokens: list[bytes] = ['<s>'.encode("utf-8")]
|
||||
toktypes: list[int] = [gguf.TokenType.CONTROL]
|
||||
|
||||
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.split(' ')
|
||||
assert len(parts) >= 3
|
||||
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
|
||||
token = token.encode("utf-8") if isinstance(token, str) else token
|
||||
assert isinstance(token, bytes)
|
||||
assert len(token) == token_len
|
||||
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
|
||||
tokens.append(token_text.encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
remainder = vocab_size - len(tokens)
|
||||
assert remainder >= 0
|
||||
for i in range(len(tokens), vocab_size):
|
||||
tokens.append(f"[PAD{i}]".encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("rwkv")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
|
||||
special_vocab.chat_template = "rwkv-world"
|
||||
# hack: Add '\n\n' as the EOT token to make it chat normally
|
||||
special_vocab._set_special_token("eot", 261)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
|
||||
tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
|
||||
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
|
||||
@@ -3412,38 +3446,7 @@ class Rwkv6Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.RWKV6
|
||||
|
||||
def set_vocab(self):
|
||||
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
|
||||
vocab_size = self.hparams.get("vocab_size", 65536)
|
||||
|
||||
tokens: list[bytes] = ['<s>'.encode("utf-8")]
|
||||
toktypes: list[int] = [gguf.TokenType.CONTROL]
|
||||
|
||||
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.split(' ')
|
||||
assert len(parts) >= 3
|
||||
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
|
||||
token = token.encode("utf-8") if isinstance(token, str) else token
|
||||
assert isinstance(token, bytes)
|
||||
assert len(token) == token_len
|
||||
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
|
||||
tokens.append(token_text.encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
remainder = vocab_size - len(tokens)
|
||||
assert remainder >= 0
|
||||
for i in range(len(tokens), vocab_size):
|
||||
tokens.append(f"[PAD{i}]".encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("rwkv")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
|
||||
special_vocab.chat_template = "rwkv-world"
|
||||
# hack: Add '\n\n' as the EOT token to make it chat normally
|
||||
special_vocab._set_special_token("eot", 261)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
self._set_vocab_rwkv_world()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
@@ -3565,6 +3568,168 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
||||
yield (new_name, data)
|
||||
|
||||
|
||||
@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
|
||||
class Rwkv7Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.RWKV7
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_rwkv_world()
|
||||
|
||||
def calc_lora_rank(self, hidden_size, exponent, multiplier):
|
||||
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
try:
|
||||
head_size = self.hparams["head_size"]
|
||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||
except KeyError:
|
||||
head_size = self.hparams["head_dim"]
|
||||
layer_norm_eps = self.hparams["norm_eps"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
|
||||
|
||||
# ICLR: In-Context-Learning-Rate
|
||||
try:
|
||||
lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
|
||||
lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
|
||||
lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
|
||||
lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
|
||||
except KeyError:
|
||||
lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
|
||||
lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
|
||||
lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
|
||||
lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
|
||||
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
|
||||
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
|
||||
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
# required by llama.cpp, unused
|
||||
self.gguf_writer.add_head_count(0)
|
||||
|
||||
lerp_weights: dict[int, dict[str, Tensor]] = {}
|
||||
lora_needs_transpose: bool = True
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# unify tensor names here to make life easier
|
||||
name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
|
||||
name = name.replace("self_attn", "attention").replace("attn", "attention")
|
||||
name = name.replace("time_mixer.", "")
|
||||
# lora layer names in fla-hub's impl
|
||||
if "_lora.lora" in name:
|
||||
self.lora_needs_transpose = False
|
||||
name = name.replace("_lora.lora.0.weight", "1.weight")
|
||||
name = name.replace("_lora.lora.2.weight", "2.weight")
|
||||
name = name.replace("_lora.lora.2.bias", "0.weight")
|
||||
|
||||
name = name.replace("feed_forward_norm", "ln2")
|
||||
name = name.replace("g_norm", "ln_x")
|
||||
|
||||
if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
|
||||
# some models have dummy v0/v1/v2 on first layer while others don't
|
||||
# ignore them all since they are not used
|
||||
return
|
||||
|
||||
wkv_has_gate = self.hparams.get("wkv_has_gate", True)
|
||||
lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
|
||||
|
||||
if bid is not None and "attention.x_" in name:
|
||||
if "attention.x_x" in name:
|
||||
# already concatenated
|
||||
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
|
||||
data = data_torch.reshape(len(lerp_list), 1, 1, -1)
|
||||
yield (new_name, data)
|
||||
else:
|
||||
try:
|
||||
self.lerp_weights[bid][name] = data_torch
|
||||
except KeyError:
|
||||
self.lerp_weights[bid] = {name: data_torch}
|
||||
if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
|
||||
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
|
||||
data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
|
||||
yield (new_name, data)
|
||||
return
|
||||
else:
|
||||
data_torch = data_torch.squeeze()
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
|
||||
new_name += ".weight"
|
||||
|
||||
if self.lora_needs_transpose and any(
|
||||
new_name.endswith(t) for t in [
|
||||
"time_mix_w1.weight", "time_mix_w2.weight",
|
||||
"time_mix_a1.weight", "time_mix_a2.weight",
|
||||
"time_mix_v1.weight", "time_mix_v2.weight",
|
||||
"time_mix_g1.weight", "time_mix_g2.weight",
|
||||
]
|
||||
):
|
||||
data_torch = data_torch.transpose(0, 1)
|
||||
|
||||
if 'r_k' in new_name:
|
||||
data_torch = data_torch.flatten()
|
||||
|
||||
if bid == 0 and "time_mix_a" in new_name:
|
||||
# dummy v0/v1/v2 on first layer
|
||||
# easist way to make llama happy
|
||||
yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
|
||||
@Model.register("RwkvHybridForCausalLM")
|
||||
class ARwkv7Model(Rwkv7Model):
|
||||
model_arch = gguf.MODEL_ARCH.ARWKV7
|
||||
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
head_size = self.hparams["head_size"]
|
||||
rms_norm_eps = self.hparams["rms_norm_eps"]
|
||||
intermediate_size = self.hparams["intermediate_size"]
|
||||
wkv_has_gate = self.hparams["wkv_has_gate"]
|
||||
assert self.hparams["wkv_version"] == 7
|
||||
|
||||
# ICLR: In-Context-Learning-Rate
|
||||
lora_rank_decay = 64
|
||||
lora_rank_iclr = 64
|
||||
lora_rank_value_residual_mix = 32
|
||||
lora_rank_gate = 128 if wkv_has_gate else 0
|
||||
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
|
||||
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
|
||||
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_token_shift_count(1)
|
||||
|
||||
# required by llama.cpp, unused
|
||||
self.gguf_writer.add_head_count(0)
|
||||
|
||||
|
||||
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
|
||||
class MambaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MAMBA
|
||||
|
||||
@@ -660,8 +660,9 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
|--------------------|---------------------------------------|---------------------------------------------|
|
||||
| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.<br>FP32 path - recommended for better perforemance than FP16 on quantized model|
|
||||
| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. |
|
||||
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
|
||||
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
|
||||
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
|
||||
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
|
||||
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
|
||||
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
|
||||
|
||||
@@ -671,6 +672,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
|
||||
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
|
||||
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
|
||||
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
|
||||
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
|
||||
|
||||
|
||||
|
||||
@@ -59,24 +59,17 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const int32_t n_kv_max = llama_n_ctx(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
|
||||
llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
|
||||
|
||||
// decode in batches of ctx_params.n_batch tokens
|
||||
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
|
||||
const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch);
|
||||
for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i));
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
};
|
||||
llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens));
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
const int ret = llama_decode_ext(ctx, batch_view.get());
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
|
||||
return false;
|
||||
@@ -91,7 +84,8 @@ int main(int argc, char ** argv) {
|
||||
// warm up
|
||||
{
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
common_batch_add(batch, 0, i, { 0 }, false);
|
||||
const llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
|
||||
}
|
||||
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
@@ -121,14 +115,14 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
|
||||
for (int i = 0; i < pp; ++i) {
|
||||
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
|
||||
common_batch_add(batch, 0, i, { j }, false);
|
||||
llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
|
||||
}
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
llama_batch_ext_set_output_last(batch);
|
||||
|
||||
const auto t_pp_start = ggml_time_us();
|
||||
|
||||
@@ -150,10 +144,10 @@ int main(int argc, char ** argv) {
|
||||
const auto t_tg_start = ggml_time_us();
|
||||
|
||||
for (int i = 0; i < tg; ++i) {
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
|
||||
for (int j = 0; j < pl; ++j) {
|
||||
common_batch_add(batch, 0, pp + i, { j }, true);
|
||||
llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, true);
|
||||
}
|
||||
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
@@ -191,7 +185,7 @@ int main(int argc, char ** argv) {
|
||||
LOG("\n");
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
|
||||
@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// create a llama_batch
|
||||
// we use this object to submit token data for decoding
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
|
||||
llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
|
||||
|
||||
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
@@ -111,12 +111,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// evaluate the initial prompt
|
||||
for (size_t i = 0; i < tokens_list.size(); ++i) {
|
||||
common_batch_add(batch, tokens_list[i], i, seq_ids, false);
|
||||
llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false);
|
||||
}
|
||||
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
|
||||
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size());
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
if (llama_encode(ctx, batch)) {
|
||||
if (llama_encode_ext(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -126,14 +126,14 @@ int main(int argc, char ** argv) {
|
||||
decoder_start_token_id = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
|
||||
llama_batch_ext_clear(batch);
|
||||
llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false);
|
||||
}
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
llama_batch_ext_set_output_last(batch);
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
if (llama_decode_ext(ctx, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -155,16 +155,16 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// remember the batch index of the last token for each parallel sequence
|
||||
// we need this to determine which logits to sample from
|
||||
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
|
||||
std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
|
||||
|
||||
int n_cur = batch.n_tokens;
|
||||
int n_cur = llama_batch_ext_get_n_tokens(batch);
|
||||
int n_decode = 0;
|
||||
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
while (n_cur <= n_predict) {
|
||||
// prepare the next batch
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
|
||||
// sample the next token for each parallel sequence / stream
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
@@ -193,23 +193,23 @@ int main(int argc, char ** argv) {
|
||||
|
||||
streams[i] += common_token_to_piece(ctx, new_token_id);
|
||||
|
||||
i_batch[i] = batch.n_tokens;
|
||||
i_batch[i] = llama_batch_ext_get_n_tokens(batch);
|
||||
|
||||
// push this new token for next evaluation
|
||||
common_batch_add(batch, new_token_id, n_cur, { i }, true);
|
||||
llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, true);
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
|
||||
// all streams are finished
|
||||
if (batch.n_tokens == 0) {
|
||||
if (llama_batch_ext_get_n_tokens(batch) == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
n_cur += 1;
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch)) {
|
||||
if (llama_decode_ext(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
@@ -234,7 +234,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
llama_free(ctx);
|
||||
|
||||
@@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
|
||||
|
||||
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
|
||||
llama_kv_self_clear(ctx);
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
|
||||
if (llama_decode_ext(ctx, batch.get())) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -26,14 +26,14 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
|
||||
return lines;
|
||||
}
|
||||
|
||||
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
|
||||
static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
|
||||
size_t n_tokens = tokens.size();
|
||||
for (size_t i = 0; i < n_tokens; i++) {
|
||||
common_batch_add(batch, tokens[i], i, { seq_id }, true);
|
||||
batch.add_text(tokens[i], i, seq_id, true);
|
||||
}
|
||||
}
|
||||
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
|
||||
static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
const struct llama_model * model = llama_get_model(ctx);
|
||||
|
||||
@@ -41,21 +41,21 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
// run model
|
||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
|
||||
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
|
||||
// encoder-only model
|
||||
if (llama_encode(ctx, batch) < 0) {
|
||||
if (llama_encode_ext(ctx, batch.get()) < 0) {
|
||||
LOG_ERR("%s : failed to encode\n", __func__);
|
||||
}
|
||||
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
|
||||
// decoder-only model
|
||||
if (llama_decode(ctx, batch) < 0) {
|
||||
if (llama_decode_ext(ctx, batch.get()) < 0) {
|
||||
LOG_ERR("%s : failed to decode\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch.logits[i]) {
|
||||
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
|
||||
if (!batch.tokens[i].logits) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -69,8 +69,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
|
||||
} else {
|
||||
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
||||
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
embd_pos = batch.seq_id[i][0];
|
||||
embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
|
||||
embd_pos = batch.tokens[i].seq_id;
|
||||
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
|
||||
}
|
||||
|
||||
@@ -171,7 +171,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// initialize batch
|
||||
const int n_prompts = prompts.size();
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
struct common_batch batch = common_batch(n_batch, 1);
|
||||
|
||||
// count number of embeddings
|
||||
int n_embd_count = 0;
|
||||
@@ -198,12 +198,12 @@ int main(int argc, char ** argv) {
|
||||
const uint64_t n_toks = inp.size();
|
||||
|
||||
// encode if at capacity
|
||||
if (batch.n_tokens + n_toks > n_batch) {
|
||||
if (batch.get_n_tokens() + n_toks > n_batch) {
|
||||
float * out = emb + e * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
|
||||
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
|
||||
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s;
|
||||
s = 0;
|
||||
common_batch_clear(batch);
|
||||
batch.clear();
|
||||
}
|
||||
|
||||
// add to batch
|
||||
@@ -319,7 +319,6 @@ int main(int argc, char ** argv) {
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
// clean up
|
||||
llama_batch_free(batch);
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) {
|
||||
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
|
||||
if (llama_decode_ext(ctx, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
+13
-11
@@ -13,10 +13,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
|
||||
llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1);
|
||||
|
||||
for (uint64_t i = 0; i < sentences.size(); i++) {
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
|
||||
const std::string input_string = instruction + sentences[i];
|
||||
|
||||
@@ -41,7 +41,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||
|
||||
// add input to batch (this increments n_tokens)
|
||||
for (int32_t j = 0; j < n_toks; j++) {
|
||||
common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
|
||||
const llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst);
|
||||
}
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
@@ -50,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
// run model
|
||||
llama_decode(ctx, batch);
|
||||
llama_decode_ext(ctx, batch);
|
||||
|
||||
// get embedding dimensions
|
||||
uint64_t n_embd = llama_model_n_embd(model);
|
||||
@@ -89,7 +90,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||
#endif
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -106,25 +107,26 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
|
||||
llama_set_embeddings(ctx, false);
|
||||
llama_set_causal_attn(ctx, true);
|
||||
|
||||
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
|
||||
llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1);
|
||||
|
||||
std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
|
||||
int32_t i_current_token = 0;
|
||||
|
||||
while (true) {
|
||||
common_batch_clear(bat);
|
||||
llama_batch_ext_clear(bat);
|
||||
{
|
||||
const int32_t n_inputs = inputs.size();
|
||||
|
||||
for (int32_t i = 0; i < n_inputs; i++) {
|
||||
common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
|
||||
const llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1);
|
||||
}
|
||||
}
|
||||
inputs.clear();
|
||||
|
||||
llama_decode(ctx, bat);
|
||||
llama_decode_ext(ctx, bat);
|
||||
|
||||
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
|
||||
llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1);
|
||||
|
||||
if (token == eos_token) {
|
||||
break;
|
||||
@@ -145,7 +147,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
|
||||
std::printf("\n");
|
||||
}
|
||||
|
||||
llama_batch_free(bat);
|
||||
llama_batch_ext_free(bat);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -497,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
|
||||
// clear the KV cache
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
@@ -511,14 +511,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
|
||||
tokens[batch_start] = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
|
||||
const llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
if (llama_decode_ext(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -531,7 +532,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
|
||||
@@ -353,7 +353,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
|
||||
if (llama_decode_ext(ctx, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -1427,7 +1427,7 @@ struct sql_printer : public printer {
|
||||
}
|
||||
};
|
||||
|
||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
|
||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
@@ -1444,14 +1444,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
|
||||
for (int i = 1; i < n_tokens; i++) {
|
||||
tokens[i] = std::rand() % n_vocab;
|
||||
}
|
||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true);
|
||||
llama_decode_ext(ctx, batch.get());
|
||||
n_processed += n_tokens;
|
||||
}
|
||||
|
||||
llama_synchronize(ctx);
|
||||
}
|
||||
|
||||
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
@@ -1461,7 +1462,8 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
|
||||
|
||||
for (int i = 0; i < n_gen; i++) {
|
||||
llama_decode(ctx, llama_batch_get_one(&token, 1));
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(&token, 1, n_past + i, 0, true);
|
||||
llama_decode_ext(ctx, batch.get());
|
||||
llama_synchronize(ctx);
|
||||
token = std::rand() % n_vocab;
|
||||
}
|
||||
@@ -1608,13 +1610,13 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
|
||||
}
|
||||
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
|
||||
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
|
||||
}
|
||||
if (t.n_gen > 0) {
|
||||
if (params.progress) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
|
||||
}
|
||||
test_gen(ctx, 1, t.n_threads);
|
||||
test_gen(ctx, 1, 0, t.n_threads);
|
||||
}
|
||||
|
||||
for (int i = 0; i < params.reps; i++) {
|
||||
@@ -1627,14 +1629,14 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
|
||||
i + 1, params.reps);
|
||||
}
|
||||
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
|
||||
}
|
||||
if (t.n_gen > 0) {
|
||||
if (params.progress) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
|
||||
i + 1, params.reps);
|
||||
}
|
||||
test_gen(ctx, t.n_gen, t.n_threads);
|
||||
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
|
||||
}
|
||||
|
||||
uint64_t t_ns = get_time_ns() - t_start;
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "clip.h"
|
||||
#include "stb_image.h"
|
||||
#include "llama.h"
|
||||
#include "llama-cpp.h"
|
||||
#include "ggml.h"
|
||||
#include "console.h"
|
||||
|
||||
@@ -63,7 +64,7 @@ struct gemma3_context {
|
||||
llama_model * model;
|
||||
llama_context * lctx;
|
||||
const llama_vocab * vocab;
|
||||
llama_batch batch;
|
||||
llama_batch_ext_ptr batch;
|
||||
|
||||
int n_threads = 1;
|
||||
llama_pos n_past = 0;
|
||||
@@ -73,7 +74,7 @@ struct gemma3_context {
|
||||
lctx = llama_init.context.get();
|
||||
vocab = llama_model_get_vocab(model);
|
||||
n_threads = params.cpuparams.n_threads;
|
||||
batch = llama_batch_init(params.n_batch, 0, 1);
|
||||
batch.reset(llama_batch_ext_init(params.n_batch, 1));
|
||||
init_clip_model(params);
|
||||
}
|
||||
|
||||
@@ -87,50 +88,18 @@ struct gemma3_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct decode_embd_batch {
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id> seq_id_0;
|
||||
std::vector<llama_seq_id *> seq_ids;
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
pos .resize(n_tokens);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
logits .resize(n_tokens);
|
||||
seq_id_0.resize(1);
|
||||
seq_id_0[0] = seq_id;
|
||||
seq_ids [n_tokens] = nullptr;
|
||||
batch = {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ embd,
|
||||
/*pos =*/ pos.data(),
|
||||
/*n_seq_id =*/ n_seq_id.data(),
|
||||
/*seq_id =*/ seq_ids.data(),
|
||||
/*logits =*/ logits.data(),
|
||||
};
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
batch.pos [i] = pos_0 + i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id [i] = seq_id_0.data();
|
||||
batch.logits [i] = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
|
||||
llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
|
||||
common_batch_clear(ctx.batch);
|
||||
llama_batch_ext_clear(ctx.batch.get());
|
||||
for (llama_token & t : tokens) {
|
||||
common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(ctx.batch.get(), t, ctx.n_past++, &seq_id, 1, false);
|
||||
}
|
||||
if (logits_last) {
|
||||
ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
|
||||
llama_batch_ext_set_output_last(ctx.batch.get());
|
||||
}
|
||||
// LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
|
||||
if (llama_decode(ctx.lctx, ctx.batch)) {
|
||||
if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
|
||||
LOG_ERR("Failed to decode text\n");
|
||||
return 1;
|
||||
}
|
||||
@@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
|
||||
int64_t t1 = ggml_time_ms();
|
||||
eval_text(ctx, "<start_of_image>");
|
||||
llama_set_causal_attn(ctx.lctx, false);
|
||||
decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
|
||||
if (llama_decode(ctx.lctx, batch_img.batch)) {
|
||||
llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0));
|
||||
if (llama_decode_ext(ctx.lctx, batch_img.get())) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
return 1;
|
||||
}
|
||||
@@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
|
||||
fflush(stdout);
|
||||
|
||||
// eval the token
|
||||
common_batch_clear(ctx.batch);
|
||||
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
|
||||
if (llama_decode(ctx.lctx, ctx.batch)) {
|
||||
llama_batch_ext_clear(ctx.batch.get());
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
|
||||
if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
|
||||
LOG_ERR("failed to decode token\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -20,7 +20,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
|
||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "llava.h"
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cerrno>
|
||||
@@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
|
||||
return true;
|
||||
}
|
||||
|
||||
struct llava_embd_batch {
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id> seq_id_0;
|
||||
std::vector<llama_seq_id *> seq_ids;
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
pos .resize(n_tokens);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
logits .resize(n_tokens);
|
||||
seq_id_0.resize(1);
|
||||
seq_id_0[0] = seq_id;
|
||||
seq_ids [n_tokens] = nullptr;
|
||||
batch = {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ embd,
|
||||
/*pos =*/ pos.data(),
|
||||
/*n_seq_id =*/ n_seq_id.data(),
|
||||
/*seq_id =*/ seq_ids.data(),
|
||||
/*logits =*/ logits.data(),
|
||||
};
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
batch.pos [i] = pos_0 + i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id [i] = seq_id_0.data();
|
||||
batch.logits [i] = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
|
||||
int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
|
||||
|
||||
@@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
||||
n_eval = n_batch;
|
||||
}
|
||||
float * embd = image_embed->embed+i*n_embd;
|
||||
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
|
||||
if (llama_decode(ctx_llama, llava_batch.batch)) {
|
||||
auto batch = llama_batch_ext_ptr::init_from_embd(embd, n_eval, n_embd, 0, 0);
|
||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -101,7 +101,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
|
||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -66,17 +66,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
|
||||
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
|
||||
|
||||
llama_batch batch = {
|
||||
int32_t(n_eval), // n_tokens
|
||||
nullptr, // token
|
||||
(image_embed->embed+i*n_embd), // embed
|
||||
batch_mrope_pos.data(), // pos
|
||||
nullptr, // n_seq_id
|
||||
nullptr, // seq_id
|
||||
nullptr, // logits
|
||||
};
|
||||
float * batch_embd = image_embed->embed+i*n_embd;
|
||||
auto batch = llama_batch_ext_ptr::init_from_embd(batch_embd, n_eval, n_embd, 0, 0);
|
||||
llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);
|
||||
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
@@ -95,16 +89,24 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
auto batch = llama_batch_get_one(&tokens[i], n_eval);
|
||||
// TODO: add mrope pos ids somewhere else
|
||||
pos.resize(batch.n_tokens * 4);
|
||||
std::fill(pos.begin(), pos.end(), 0);
|
||||
for (int j = 0; j < batch.n_tokens * 3; j ++) {
|
||||
pos[j] = *st_pos_id + (j % batch.n_tokens);
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
// TODO: add mrope pos ids somewhere else
|
||||
int n_tokens = n_eval;
|
||||
pos.resize(n_tokens * 4);
|
||||
std::fill(pos.begin(), pos.end(), 0);
|
||||
for (int j = 0; j < n_tokens * 3; j ++) {
|
||||
pos[j] = *st_pos_id + (j % n_tokens);
|
||||
}
|
||||
|
||||
llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
|
||||
for (int j = 0; j < n_eval; j++) {
|
||||
llama_token token = tokens[i + j];
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch.get(), token, pos[j], &seq_id, 1, false);
|
||||
}
|
||||
llama_batch_ext_set_output_last(batch.get());
|
||||
|
||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -92,8 +92,10 @@ int main(int argc, char ** argv) {
|
||||
const auto t_enc_start = ggml_time_us();
|
||||
|
||||
// eval the prompt
|
||||
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
|
||||
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
|
||||
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
|
||||
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
|
||||
llama_decode_ext(ctx, batch0.get());
|
||||
llama_decode_ext(ctx, batch1.get());
|
||||
|
||||
for (int s = 1; s < W + G + 1; ++s) {
|
||||
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
|
||||
@@ -115,7 +117,7 @@ int main(int argc, char ** argv) {
|
||||
// seq_id == 0 : the current input token
|
||||
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
|
||||
// seq_id [W + 1, W + G] : verification n-grams
|
||||
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
|
||||
llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);
|
||||
|
||||
// target model sampling context
|
||||
struct common_sampler * smpl = common_sampler_init(model, params.sampling);
|
||||
@@ -204,10 +206,10 @@ int main(int argc, char ** argv) {
|
||||
// V V V V V V
|
||||
// id
|
||||
{
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
|
||||
// current token - first token of the first level
|
||||
common_batch_add(batch, id, n_past, seq_id_all, true);
|
||||
llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true);
|
||||
|
||||
// verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
|
||||
{
|
||||
@@ -230,9 +232,10 @@ int main(int argc, char ** argv) {
|
||||
const llama_token t = ngrams_observed.tokens[idx + j];
|
||||
|
||||
ngrams_cur[g].tokens [j + 1] = t;
|
||||
ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
|
||||
ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch);
|
||||
|
||||
common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
|
||||
llama_seq_id seq_id = W + 1 + g;
|
||||
llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -244,18 +247,20 @@ int main(int argc, char ** argv) {
|
||||
seq_id_look[j] = i + j + 1;
|
||||
}
|
||||
|
||||
common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
|
||||
llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i,
|
||||
seq_id_look.data(), seq_id_look.size(), false);
|
||||
}
|
||||
|
||||
// fill the rest of the levels
|
||||
for (int j = 1; j < N - 1; j++) {
|
||||
for (int i = 0; i < W; i++) {
|
||||
common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
|
||||
llama_seq_id seq_id = i + 1;
|
||||
llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
if (llama_decode_ext(ctx, batch) != 0) {
|
||||
LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -475,7 +480,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_kv_cache_view_free(&kvc_view);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
|
||||
@@ -91,8 +91,10 @@ int main(int argc, char ** argv){
|
||||
|
||||
const auto t_enc_start = ggml_time_us();
|
||||
|
||||
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
|
||||
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
|
||||
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
|
||||
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
|
||||
llama_decode_ext(ctx, batch0.get());
|
||||
llama_decode_ext(ctx, batch1.get());
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
|
||||
@@ -108,7 +110,7 @@ int main(int argc, char ** argv){
|
||||
|
||||
std::vector<llama_token> draft;
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
|
||||
llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1);
|
||||
|
||||
// debug
|
||||
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
|
||||
@@ -194,8 +196,9 @@ int main(int argc, char ** argv){
|
||||
// clean the cache of draft tokens that weren't accepted
|
||||
llama_kv_self_seq_rm(ctx, 0, n_past, -1);
|
||||
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
|
||||
const llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_clear(batch_tgt);
|
||||
llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true);
|
||||
|
||||
// Draft already contains a single token sampled from the model:
|
||||
GGML_ASSERT(draft.size() == 1);
|
||||
@@ -205,13 +208,13 @@ int main(int argc, char ** argv){
|
||||
common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
|
||||
|
||||
for (size_t i = 1; i < draft.size(); ++i) {
|
||||
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
|
||||
llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
|
||||
}
|
||||
|
||||
t_draft_us += ggml_time_us() - t_start_draft_us;
|
||||
n_drafted += draft.size() - 1;
|
||||
|
||||
llama_decode(ctx, batch_tgt);
|
||||
llama_decode_ext(ctx, batch_tgt);
|
||||
++n_past;
|
||||
|
||||
draft.erase(draft.begin());
|
||||
@@ -243,7 +246,7 @@ int main(int argc, char ** argv){
|
||||
|
||||
common_sampler_free(smpl);
|
||||
|
||||
llama_batch_free(batch_tgt);
|
||||
llama_batch_ext_free(batch_tgt);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
|
||||
+36
-5
@@ -27,12 +27,24 @@ Once downloaded, place your model in the models folder in llama.cpp.
|
||||
##### Input prompt (One-and-done)
|
||||
|
||||
```bash
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --prompt "Once upon a time"
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
|
||||
```
|
||||
##### Conversation mode (Allow for continuous interaction with the model)
|
||||
|
||||
```bash
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -cnv --chat-template gemma
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
|
||||
```
|
||||
|
||||
##### Conversation mode using built-in jinja chat template
|
||||
|
||||
```bash
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja
|
||||
```
|
||||
|
||||
##### One-and-done query using jinja with custom system prompt and a starting prompt
|
||||
|
||||
```bash
|
||||
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
|
||||
```
|
||||
|
||||
##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
|
||||
@@ -44,12 +56,24 @@ Once downloaded, place your model in the models folder in llama.cpp.
|
||||
|
||||
##### Input prompt (One-and-done)
|
||||
```powershell
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --prompt "Once upon a time"
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
|
||||
```
|
||||
##### Conversation mode (Allow for continuous interaction with the model)
|
||||
|
||||
```powershell
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -cnv --chat-template gemma
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
|
||||
```
|
||||
|
||||
##### Conversation mode using built-in jinja chat template
|
||||
|
||||
```powershell
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja
|
||||
```
|
||||
|
||||
##### One-and-done query using jinja with custom system prompt and a starting prompt
|
||||
|
||||
```powershell
|
||||
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
|
||||
```
|
||||
|
||||
#### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
|
||||
@@ -77,6 +101,8 @@ The `llama-cli` program provides several ways to interact with the LLaMA models
|
||||
|
||||
- `--prompt PROMPT`: Provide a prompt directly as a command-line option.
|
||||
- `--file FNAME`: Provide a file containing a prompt or multiple prompts.
|
||||
- `--system-prompt PROMPT`: Provide a system prompt (will otherwise use the default one in the chat template (if provided)).
|
||||
- `--system-prompt-file FNAME`: Provide a file containing a system prompt.
|
||||
- `--interactive-first`: Run the program in interactive mode and wait for input right away. (More on this below.)
|
||||
|
||||
## Interaction
|
||||
@@ -89,7 +115,10 @@ In interactive mode, users can participate in text generation by injecting their
|
||||
|
||||
- `-i, --interactive`: Run the program in interactive mode, allowing users to engage in real-time conversations or provide specific instructions to the model.
|
||||
- `--interactive-first`: Run the program in interactive mode and immediately wait for user input before starting the text generation.
|
||||
- `-cnv, --conversation`: Run the program in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: false)
|
||||
- `-cnv, --conversation`: Run the program in conversation mode (does not print special tokens and suffix/prefix, use default or provided chat template) (default: true if chat template found)
|
||||
- `-no-cnv`: Disable conversation mode (default: false)
|
||||
- `-st, --single-turn`: Only process a single conversation turn (user input) and then exit.
|
||||
- `--jinja`: Enable jinja chat template parser, will use the model's built-in template or a user-provided one (default: false)
|
||||
- `--color`: Enable colorized output to differentiate visually distinguishing between prompts, user input, and generated text.
|
||||
|
||||
By understanding and utilizing these interaction options, you can create engaging and dynamic experiences with the LLaMA models, tailoring the text generation process to your specific needs.
|
||||
@@ -125,6 +154,8 @@ When --in-prefix or --in-suffix options are enabled the chat template ( --chat-t
|
||||
|
||||
Example usage: `--chat-template gemma`
|
||||
|
||||
`--chat-template-file FNAME`: Load a custom jinja chat template from an external file, useful if the model contains outdated or incompatible template, some examples can be found in models/templates. Up-to-date chat templates can be downloaded from Hugging Face using scripts/get_chat_template.py
|
||||
|
||||
## Context Management
|
||||
|
||||
During text generation, LLaMA models have a limited context size, which means they can only consider a certain number of tokens from the input and generated text. When the context fills up, the model resets internally, potentially losing some information from the beginning of the conversation or instructions. Context management options help maintain continuity and coherence in these situations.
|
||||
|
||||
@@ -548,7 +548,8 @@ int main(int argc, char ** argv) {
|
||||
int enc_input_size = embd_inp.size();
|
||||
llama_token * enc_input_buf = embd_inp.data();
|
||||
|
||||
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(enc_input_buf, enc_input_size, 0, 0, true);
|
||||
if (llama_decode_ext(ctx, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -668,7 +669,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
|
||||
if (llama_decode_ext(ctx, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -174,7 +174,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
|
||||
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
|
||||
llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);
|
||||
|
||||
int32_t n_total_prompt = 0;
|
||||
int32_t n_total_gen = 0;
|
||||
@@ -192,10 +192,11 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
|
||||
|
||||
for (int32_t i = 0; i < n_tokens_system; ++i) {
|
||||
common_batch_add(batch, tokens_system[i], i, { 0 }, false);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
if (llama_decode_ext(ctx, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -216,7 +217,7 @@ int main(int argc, char ** argv) {
|
||||
common_kv_cache_dump_view_seqs(kvc_view, 40);
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
|
||||
// decode any currently ongoing sequences
|
||||
for (auto & client : clients) {
|
||||
@@ -224,14 +225,15 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
client.i_batch = batch.n_tokens;
|
||||
client.i_batch = llama_batch_ext_get_n_tokens(batch);
|
||||
|
||||
common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
|
||||
llama_seq_id seq_id = client.id + 1;
|
||||
llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true);
|
||||
|
||||
client.n_decoded += 1;
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
if (llama_batch_ext_get_n_tokens(batch) == 0) {
|
||||
// all sequences have ended - clear the entire KV cache
|
||||
for (int i = 1; i <= n_clients; ++i) {
|
||||
llama_kv_self_seq_rm(ctx, i, -1, -1);
|
||||
@@ -243,7 +245,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// insert new sequences for decoding
|
||||
if (cont_batching || batch.n_tokens == 0) {
|
||||
if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) {
|
||||
for (auto & client : clients) {
|
||||
if (client.seq_id == -1 && g_seq_id < n_seq) {
|
||||
client.seq_id = g_seq_id;
|
||||
@@ -262,17 +264,18 @@ int main(int argc, char ** argv) {
|
||||
tokens_prompt = common_tokenize(ctx, client.prompt, false);
|
||||
|
||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||
common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
|
||||
llama_seq_id seq_id = client.id + 1;
|
||||
llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false);
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
if (batch.n_tokens > 0) {
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
if (llama_batch_ext_get_n_tokens(batch) > 0) {
|
||||
llama_batch_ext_set_output_last(batch);
|
||||
}
|
||||
|
||||
client.n_prompt = tokens_prompt.size();
|
||||
client.n_decoded = 0;
|
||||
client.i_batch = batch.n_tokens - 1;
|
||||
client.i_batch = llama_batch_ext_get_n_tokens(batch) - 1;
|
||||
|
||||
LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
|
||||
|
||||
@@ -286,14 +289,15 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
if (llama_batch_ext_get_n_tokens(batch) == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// process in chunks of params.n_batch
|
||||
int32_t n_batch = params.n_batch;
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch);
|
||||
for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
|
||||
// experiment: process in powers of 2
|
||||
//if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
|
||||
// n_batch /= 2;
|
||||
@@ -301,19 +305,11 @@ int main(int argc, char ** argv) {
|
||||
// continue;
|
||||
//}
|
||||
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i));
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
|
||||
const int ret = llama_decode_ext(ctx, batch_view);
|
||||
llama_batch_ext_free(batch_view);
|
||||
if (ret != 0) {
|
||||
if (n_batch == 1 || ret < 0) {
|
||||
// if you get here, it means the KV cache is full - try increasing it via the context size
|
||||
@@ -417,7 +413,7 @@ int main(int argc, char ** argv) {
|
||||
// TODO: print sampling/grammar timings for all clients
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
@@ -122,7 +123,7 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("prompt tokens: %d\n", n_tokens_all);
|
||||
//LOG_INF("prompt: %s\n", params.prompt.c_str());
|
||||
|
||||
llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
|
||||
llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
|
||||
|
||||
int n_past = 0;
|
||||
|
||||
@@ -140,17 +141,18 @@ int main(int argc, char ** argv) {
|
||||
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch.get());
|
||||
|
||||
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
||||
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
|
||||
}
|
||||
|
||||
if (i + n_batch >= n_tokens_all) {
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
llama_batch_ext_set_output_last(batch.get());
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
if (llama_decode_ext(ctx, batch.get()) != 0) {
|
||||
LOG_INF("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -174,17 +176,18 @@ int main(int argc, char ** argv) {
|
||||
|
||||
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
|
||||
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch.get());
|
||||
|
||||
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
|
||||
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
|
||||
}
|
||||
|
||||
if (i + n_batch >= n_tokens_all) {
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
llama_batch_ext_set_output_last(batch.get());
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
if (llama_decode_ext(ctx, batch.get()) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -223,7 +226,7 @@ int main(int argc, char ** argv) {
|
||||
while (n_cur <= n_len) {
|
||||
// sample the next token
|
||||
{
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
|
||||
|
||||
// is it an end of generation?
|
||||
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
|
||||
@@ -237,16 +240,17 @@ int main(int argc, char ** argv) {
|
||||
n_decode += 1;
|
||||
|
||||
// prepare the next batch
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch.get());
|
||||
|
||||
// push this new token for next evaluation
|
||||
common_batch_add(batch, new_token_id, n_past++, { 0 }, true);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true);
|
||||
}
|
||||
|
||||
n_cur += 1;
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch)) {
|
||||
if (llama_decode_ext(ctx, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
@@ -266,8 +270,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
|
||||
|
||||
@@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
||||
// clear the KV cache
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
common_batch batch(n_batch, 1);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
||||
common_batch_clear(batch);
|
||||
batch.clear();
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
|
||||
batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
|
||||
}
|
||||
|
||||
//LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
||||
if (llama_decode(ctx, batch)) {
|
||||
if (llama_decode_ext(ctx, batch.get())) {
|
||||
//LOG_ERR("%s : failed to eval\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return {tokens, -1, logit_history, prob_history};
|
||||
}
|
||||
|
||||
@@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
if (i == 0) {
|
||||
@@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
||||
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
|
||||
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
|
||||
common_batch batch(std::min(n_batch, n_ctx*n_seq), 1);
|
||||
|
||||
std::vector<float> logits;
|
||||
if (num_batches > 1) {
|
||||
@@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
||||
|
||||
int n_outputs = 0;
|
||||
|
||||
batch.n_tokens = 0;
|
||||
batch.clear();
|
||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||
int seq_start = batch_start + seq*n_ctx;
|
||||
|
||||
@@ -568,22 +565,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
||||
}
|
||||
|
||||
for (int k = 0; k < batch_size; ++k) {
|
||||
const int idx = seq*n_ctx + k;
|
||||
batch.token [idx] = tokens[seq_start + k];
|
||||
batch.pos [idx] = j*n_batch + k;
|
||||
batch.n_seq_id[idx] = 1;
|
||||
batch.seq_id [idx][0] = seq;
|
||||
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
|
||||
const llama_pos pos = j*n_batch + k;
|
||||
bool output = pos >= first;
|
||||
batch.add_text(tokens[seq_start + k], pos, seq, output);
|
||||
|
||||
n_outputs += batch.logits[idx] != 0;
|
||||
n_outputs += output ? 1 : 0;
|
||||
}
|
||||
batch.n_tokens += batch_size;
|
||||
|
||||
// restore the original token in case it was set to BOS
|
||||
tokens[seq_start] = token_org;
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
if (llama_decode_ext(ctx, batch.get())) {
|
||||
LOG_INF("%s : failed to eval\n", __func__);
|
||||
return {tokens, -1, logit_history, prob_history};
|
||||
}
|
||||
@@ -653,36 +646,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
||||
LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
return {tokens, ppl, logit_history, prob_history};
|
||||
}
|
||||
|
||||
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
|
||||
static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
|
||||
int prev_outputs = 0;
|
||||
for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
|
||||
const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
|
||||
for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
|
||||
const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
};
|
||||
common_batch batch_view = batch.get_view(i, n_tokens);
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
const int ret = llama_decode_ext(ctx, batch_view.get());
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
|
||||
return false;
|
||||
}
|
||||
|
||||
int n_outputs = 0;
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
n_outputs += batch_view.logits[i] != 0;
|
||||
}
|
||||
int n_outputs = batch_view.n_outputs;
|
||||
|
||||
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
|
||||
|
||||
@@ -863,7 +843,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
||||
const int max_tasks_per_batch = 32;
|
||||
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
||||
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0, 4);
|
||||
common_batch batch(n_ctx, 4);
|
||||
|
||||
std::vector<float> tok_logits(n_vocab);
|
||||
// TODO: this could be made smaller; it's currently the worst-case size
|
||||
@@ -879,7 +859,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
||||
size_t i1 = i0;
|
||||
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
|
||||
|
||||
common_batch_clear(batch);
|
||||
batch.clear();
|
||||
|
||||
// batch as much tasks as possible into the available context
|
||||
// each task has 4 unique sequence ids - one for each ending
|
||||
@@ -895,9 +875,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
|
||||
common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
|
||||
batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
||||
llama_batch_ext_set_output_last(batch.get());
|
||||
n_logits += 1;
|
||||
|
||||
for (int s = 0; s < 4; ++s) {
|
||||
@@ -905,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
||||
// TODO: don't evaluate the last token of each sequence
|
||||
for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
|
||||
const bool needs_logits = i < seq_tokens_size - 1;
|
||||
common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
|
||||
batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
|
||||
n_logits += needs_logits;
|
||||
}
|
||||
}
|
||||
@@ -992,8 +972,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
||||
i0 = i1 - 1;
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
@@ -1147,7 +1125,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
|
||||
const int max_tasks_per_batch = 128;
|
||||
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
||||
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0, 2);
|
||||
common_batch batch(n_ctx, 2);
|
||||
|
||||
std::vector<float> tok_logits(n_vocab);
|
||||
// TODO: this could be made smaller; it's currently the worst-case size
|
||||
@@ -1166,7 +1144,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
|
||||
size_t i1 = i0;
|
||||
size_t i_logits = 0;
|
||||
|
||||
common_batch_clear(batch);
|
||||
batch.clear();
|
||||
|
||||
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
|
||||
int n_logits = 0;
|
||||
@@ -1176,15 +1154,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
|
||||
common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
|
||||
batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
llama_batch_ext_set_output_last(batch.get());
|
||||
n_logits += 1;
|
||||
|
||||
for (int s = 0; s < 2; ++s) {
|
||||
// TODO: end before the last token, no need to predict past the end of the sequences
|
||||
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
|
||||
common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
|
||||
batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
|
||||
n_logits += 1;
|
||||
}
|
||||
}
|
||||
@@ -1501,7 +1479,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
||||
const int max_tasks_per_batch = 32;
|
||||
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
||||
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
|
||||
common_batch batch(n_ctx, max_seq);
|
||||
|
||||
std::vector<float> tok_logits(n_vocab);
|
||||
std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
|
||||
@@ -1521,7 +1499,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
||||
size_t i1 = i0;
|
||||
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
|
||||
|
||||
common_batch_clear(batch);
|
||||
batch.clear();
|
||||
|
||||
// batch as much tasks as possible into the available context
|
||||
// each task has 4 unique sequence ids - one for each ending
|
||||
@@ -1544,9 +1522,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
||||
|
||||
for (size_t i = 0; i < cur_task.common_prefix; ++i) {
|
||||
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
|
||||
common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
|
||||
batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
||||
llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
|
||||
n_logits += 1;
|
||||
|
||||
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
|
||||
@@ -1554,7 +1532,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
||||
// TODO: don't evaluate the last token of each sequence
|
||||
for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
|
||||
const bool needs_logits = i < seq_tokens_size - 1;
|
||||
common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
|
||||
batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
|
||||
n_logits += needs_logits;
|
||||
}
|
||||
}
|
||||
@@ -1653,8 +1631,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
||||
i0 = i1 - 1;
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;
|
||||
|
||||
float p = 1.f*n_correct/n_done;
|
||||
@@ -1767,7 +1743,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||
// clear the KV cache
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
common_batch batch(n_batch, 1);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
@@ -1781,14 +1757,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||
tokens[batch_start] = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
batch.clear();
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
|
||||
batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
if (llama_decode_ext(ctx, batch.get())) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1801,8 +1776,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
if (i == 0) {
|
||||
|
||||
@@ -74,40 +74,56 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
|
||||
return chunks;
|
||||
}
|
||||
|
||||
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
|
||||
static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
|
||||
size_t n_tokens = tokens.size();
|
||||
for (size_t i = 0; i < n_tokens; i++) {
|
||||
common_batch_add(batch, tokens[i], i, { seq_id }, true);
|
||||
batch.add_text(tokens[i], i, seq_id, true);
|
||||
}
|
||||
}
|
||||
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
||||
static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
const struct llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
// run model
|
||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
if (llama_decode(ctx, batch) < 0) {
|
||||
LOG_ERR("%s : failed to decode\n", __func__);
|
||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
|
||||
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
|
||||
// encoder-only model
|
||||
if (llama_encode_ext(ctx, batch.get()) < 0) {
|
||||
LOG_ERR("%s : failed to encode\n", __func__);
|
||||
}
|
||||
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
|
||||
// decoder-only model
|
||||
if (llama_decode_ext(ctx, batch.get()) < 0) {
|
||||
LOG_ERR("%s : failed to decode\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch.logits[i]) {
|
||||
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
|
||||
if (!batch.tokens[i].logits) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
if (embd == NULL) {
|
||||
const float * embd = nullptr;
|
||||
int embd_pos = 0;
|
||||
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
// try to get token embeddings
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
if (embd == NULL) {
|
||||
LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i);
|
||||
continue;
|
||||
}
|
||||
embd_pos = i;
|
||||
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
|
||||
} else {
|
||||
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
||||
embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
|
||||
embd_pos = batch.tokens[i].seq_id;
|
||||
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
|
||||
}
|
||||
|
||||
float * out = output + batch.seq_id[i][0] * n_embd;
|
||||
common_embd_normalize(embd, out, n_embd, 2);
|
||||
float * out = output + embd_pos * n_embd;
|
||||
common_embd_normalize(embd, out, n_embd, embd_norm);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,7 +230,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// initialize batch
|
||||
const int n_chunks = chunks.size();
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
struct common_batch batch = common_batch(n_batch, 1);
|
||||
|
||||
// allocate output
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
@@ -231,10 +247,10 @@ int main(int argc, char ** argv) {
|
||||
const uint64_t n_toks = inp.size();
|
||||
|
||||
// encode if at capacity
|
||||
if (batch.n_tokens + n_toks > n_batch) {
|
||||
if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) {
|
||||
float * out = emb + p * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd);
|
||||
common_batch_clear(batch);
|
||||
batch.clear();
|
||||
p += s;
|
||||
s = 0;
|
||||
}
|
||||
@@ -255,7 +271,7 @@ int main(int argc, char ** argv) {
|
||||
chunks[i].tokens.clear();
|
||||
}
|
||||
|
||||
struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
|
||||
struct common_batch query_batch = common_batch(n_batch, 1);
|
||||
|
||||
// start loop, receive query and return top k similar chunks based on cosine similarity
|
||||
std::string query;
|
||||
@@ -269,7 +285,7 @@ int main(int argc, char ** argv) {
|
||||
std::vector<float> query_emb(n_embd, 0);
|
||||
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
|
||||
|
||||
common_batch_clear(query_batch);
|
||||
query_batch.clear();
|
||||
|
||||
// compute cosine similarities
|
||||
{
|
||||
@@ -299,6 +315,5 @@ int main(int argc, char ** argv) {
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
// clean up
|
||||
llama_batch_free(query_batch);
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
@@ -640,6 +640,7 @@ class LlamaData {
|
||||
std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
|
||||
std::list<std::string> msg_strs;
|
||||
std::vector<char> fmtted;
|
||||
llama_pos n_past = 0;
|
||||
|
||||
int init(Opt & opt) {
|
||||
model = initialize_model(opt);
|
||||
@@ -950,10 +951,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
|
||||
}
|
||||
|
||||
// Check if we have enough space in the context to evaluate this batch
|
||||
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
|
||||
static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) {
|
||||
const int n_ctx = llama_n_ctx(ctx.get());
|
||||
const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
|
||||
if (n_ctx_used + batch.n_tokens > n_ctx) {
|
||||
if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) {
|
||||
printf(LOG_COL_DEFAULT "\n");
|
||||
printe("context size exceeded\n");
|
||||
return 1;
|
||||
@@ -991,15 +992,17 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||
}
|
||||
|
||||
// prepare a batch for the prompt
|
||||
llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true);
|
||||
llama_token new_token_id;
|
||||
while (true) {
|
||||
check_context_size(llama_data.context, batch);
|
||||
if (llama_decode(llama_data.context.get(), batch)) {
|
||||
if (llama_decode_ext(llama_data.context.get(), batch.get())) {
|
||||
printe("failed to decode\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_data.n_past += llama_batch_ext_get_n_tokens(batch.get());
|
||||
|
||||
// sample the next token, check is it an end of generation?
|
||||
new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
|
||||
if (llama_vocab_is_eog(vocab, new_token_id)) {
|
||||
@@ -1014,7 +1017,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||
print_word_and_concatenate_to_response(piece, response);
|
||||
|
||||
// prepare the next batch with the sampled token
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, llama_data.n_past, 0, true));
|
||||
}
|
||||
|
||||
printf(LOG_COL_DEFAULT);
|
||||
|
||||
@@ -48,15 +48,11 @@ int main(int argc, char ** argv) {
|
||||
auto tokens = common_tokenize(ctx, params.prompt, true);
|
||||
|
||||
// prepare the batch
|
||||
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
common_batch_add(batch, tokens[i], i, {0}, false);
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true; // generate next token
|
||||
llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);
|
||||
|
||||
// evaluate prompt
|
||||
llama_decode(ctx, batch);
|
||||
n_past += batch.n_tokens;
|
||||
llama_decode_ext(ctx, batch);
|
||||
n_past += llama_batch_ext_get_n_tokens(batch);
|
||||
|
||||
// save state (rng, logits, embedding and kv_cache) to file
|
||||
{
|
||||
@@ -83,12 +79,13 @@ int main(int argc, char ** argv) {
|
||||
printf("%s", next_token_str.c_str());
|
||||
result0 += next_token_str;
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, n_past, {0}, true);
|
||||
llama_batch_ext_clear(batch);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
if (llama_decode_ext(ctx, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
return 1;
|
||||
}
|
||||
n_past += 1;
|
||||
@@ -135,12 +132,13 @@ int main(int argc, char ** argv) {
|
||||
printf("%s", next_token_str.c_str());
|
||||
result1 += next_token_str;
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, n_past, {0}, true);
|
||||
llama_batch_ext_clear(batch);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
|
||||
|
||||
if (llama_decode(ctx2, batch)) {
|
||||
if (llama_decode_ext(ctx2, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
return 1;
|
||||
}
|
||||
n_past += 1;
|
||||
@@ -216,12 +214,13 @@ int main(int argc, char ** argv) {
|
||||
printf("%s", next_token_str.c_str());
|
||||
result2 += next_token_str;
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, n_past, {1}, true);
|
||||
llama_batch_ext_clear(batch);
|
||||
llama_seq_id seq_id = 1; // seq 1 instead of 0
|
||||
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
|
||||
|
||||
if (llama_decode(ctx3, batch)) {
|
||||
if (llama_decode_ext(ctx3, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
return 1;
|
||||
}
|
||||
n_past += 1;
|
||||
@@ -233,7 +232,7 @@ int main(int argc, char ** argv) {
|
||||
llama_sampler_free(smpl2);
|
||||
llama_sampler_free(smpl3);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
if (result0 != result2) {
|
||||
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||
|
||||
+46
-58
@@ -1224,7 +1224,7 @@ struct server_slot {
|
||||
// only used for completion/embedding/infill/rerank
|
||||
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
llama_batch batch_spec = {};
|
||||
common_batch batch_spec;
|
||||
|
||||
llama_context * ctx = nullptr;
|
||||
llama_context * ctx_dft = nullptr;
|
||||
@@ -1796,7 +1796,7 @@ struct server_context {
|
||||
|
||||
llama_context_params cparams_dft;
|
||||
|
||||
llama_batch batch = {};
|
||||
common_batch batch;
|
||||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
@@ -1829,11 +1829,7 @@ struct server_context {
|
||||
|
||||
common_speculative_free(slot.spec);
|
||||
slot.spec = nullptr;
|
||||
|
||||
llama_batch_free(slot.batch_spec);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
bool load_model(const common_params & params) {
|
||||
@@ -1872,6 +1868,10 @@ struct server_context {
|
||||
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
|
||||
params_dft.n_parallel = 1;
|
||||
|
||||
// force F16 KV cache for the draft model for extra performance
|
||||
params_dft.cache_type_k = GGML_TYPE_F16;
|
||||
params_dft.cache_type_v = GGML_TYPE_F16;
|
||||
|
||||
llama_init_dft = common_init_from_params(params_dft);
|
||||
|
||||
model_dft = llama_init_dft.model.get();
|
||||
@@ -1892,10 +1892,6 @@ struct server_context {
|
||||
cparams_dft = common_context_params_to_llama(params_dft);
|
||||
cparams_dft.n_batch = n_ctx_dft;
|
||||
|
||||
// force F16 KV cache for the draft model for extra performance
|
||||
cparams_dft.type_k = GGML_TYPE_F16;
|
||||
cparams_dft.type_v = GGML_TYPE_F16;
|
||||
|
||||
// the context is not needed - we will create one for each slot
|
||||
llama_init_dft.context.reset();
|
||||
}
|
||||
@@ -1926,7 +1922,7 @@ struct server_context {
|
||||
slot.n_predict = params_base.n_predict;
|
||||
|
||||
if (model_dft) {
|
||||
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
|
||||
slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
|
||||
|
||||
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
|
||||
if (slot.ctx_dft == nullptr) {
|
||||
@@ -1951,7 +1947,7 @@ struct server_context {
|
||||
|
||||
slot.reset();
|
||||
|
||||
slots.push_back(slot);
|
||||
slots.push_back(std::move(slot));
|
||||
}
|
||||
|
||||
default_generation_settings_for_props = slots[0].to_json();
|
||||
@@ -1962,7 +1958,7 @@ struct server_context {
|
||||
const int32_t n_batch = llama_n_batch(ctx);
|
||||
|
||||
// only a single seq_id per token is needed
|
||||
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
|
||||
batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
|
||||
}
|
||||
|
||||
metrics.init();
|
||||
@@ -2097,9 +2093,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
if (slot.ctx_dft) {
|
||||
llama_batch_free(slot.batch_spec);
|
||||
|
||||
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
|
||||
slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
|
||||
}
|
||||
|
||||
slot.state = SLOT_STATE_STARTED;
|
||||
@@ -2407,7 +2401,7 @@ struct server_context {
|
||||
queue_results.send(std::move(res));
|
||||
}
|
||||
|
||||
void send_embedding(const server_slot & slot, const llama_batch & batch) {
|
||||
void send_embedding(const server_slot & slot, common_batch & batch) {
|
||||
auto res = std::make_unique<server_task_result_embd>();
|
||||
res->id = slot.id_task;
|
||||
res->index = slot.index;
|
||||
@@ -2418,18 +2412,19 @@ struct server_context {
|
||||
|
||||
std::vector<float> embd_res(n_embd, 0.0f);
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
||||
for (int i = 0; i < batch.get_n_tokens(); ++i) {
|
||||
auto tok = batch.tokens[i];
|
||||
if (!tok.logits || tok.seq_id != slot.id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
|
||||
if (embd == NULL) {
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
}
|
||||
|
||||
if (embd == NULL) {
|
||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
|
||||
|
||||
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
||||
continue;
|
||||
@@ -2450,24 +2445,25 @@ struct server_context {
|
||||
queue_results.send(std::move(res));
|
||||
}
|
||||
|
||||
void send_rerank(const server_slot & slot, const llama_batch & batch) {
|
||||
void send_rerank(const server_slot & slot, common_batch & batch) {
|
||||
auto res = std::make_unique<server_task_result_rerank>();
|
||||
res->id = slot.id_task;
|
||||
res->index = slot.index;
|
||||
res->n_tokens = slot.n_prompt_tokens;
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
||||
for (int i = 0; i < batch.get_n_tokens(); ++i) {
|
||||
auto tok = batch.tokens[i];
|
||||
if (!tok.logits || tok.seq_id != slot.id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
|
||||
if (embd == NULL) {
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
}
|
||||
|
||||
if (embd == NULL) {
|
||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
|
||||
|
||||
res->score = -1e6;
|
||||
continue;
|
||||
@@ -2858,7 +2854,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
// start populating the batch for this iteration
|
||||
common_batch_clear(batch);
|
||||
batch.clear();
|
||||
|
||||
// track if given slot can be batched with slots already in the batch
|
||||
server_slot * slot_batched = nullptr;
|
||||
@@ -2880,9 +2876,9 @@ struct server_context {
|
||||
continue;
|
||||
}
|
||||
|
||||
slot.i_batch = batch.n_tokens;
|
||||
slot.i_batch = batch.get_n_tokens();
|
||||
|
||||
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
||||
batch.add_text(slot.sampled, slot.n_past, slot.id, true);
|
||||
|
||||
slot.n_past += 1;
|
||||
|
||||
@@ -2899,7 +2895,7 @@ struct server_context {
|
||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||
if (params_base.cont_batching || batch.get_n_tokens() == 0) {
|
||||
for (auto & slot : slots) {
|
||||
// check if we can batch this slot with the previous one
|
||||
if (slot.is_processing()) {
|
||||
@@ -3065,7 +3061,7 @@ struct server_context {
|
||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||
if (slot.is_non_causal()) {
|
||||
// cannot fit the prompt in the current batch - will try next iter
|
||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||
if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -3085,11 +3081,11 @@ struct server_context {
|
||||
slot.cache_tokens.resize(slot.n_past);
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
|
||||
// without pooling, we want to output the embeddings for all the tokens in the batch
|
||||
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
|
||||
batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||
@@ -3099,13 +3095,13 @@ struct server_context {
|
||||
slot.n_past++;
|
||||
}
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.n_past == slot.n_prompt_tokens) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
GGML_ASSERT(batch.get_n_tokens() > 0);
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
|
||||
@@ -3115,27 +3111,27 @@ struct server_context {
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
batch.set_logits_last();
|
||||
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
slot.i_batch = batch.get_n_tokens() - 1;
|
||||
|
||||
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
|
||||
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.n_tokens >= n_batch) {
|
||||
if (batch.get_n_tokens() >= n_batch) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
if (batch.get_n_tokens() == 0) {
|
||||
SRV_WRN("%s", "no tokens to decode\n");
|
||||
return;
|
||||
}
|
||||
|
||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
|
||||
|
||||
if (slot_batched) {
|
||||
// make sure we're in the right embedding mode
|
||||
@@ -3145,20 +3141,12 @@ struct server_context {
|
||||
}
|
||||
|
||||
// process the created batch of tokens
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||
for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
};
|
||||
common_batch batch_view = batch.get_view(i, n_tokens);
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
const int ret = llama_decode_ext(ctx, batch_view.get());
|
||||
metrics.on_decoded(slots);
|
||||
|
||||
if (ret != 0) {
|
||||
@@ -3293,16 +3281,16 @@ struct server_context {
|
||||
}
|
||||
|
||||
// construct the speculation batch
|
||||
common_batch_clear(slot.batch_spec);
|
||||
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
|
||||
slot.batch_spec.clear();
|
||||
slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
|
||||
|
||||
for (size_t i = 0; i < draft.size(); ++i) {
|
||||
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
|
||||
slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
|
||||
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
|
||||
|
||||
llama_decode(ctx, slot.batch_spec);
|
||||
llama_decode_ext(ctx, slot.batch_spec.get());
|
||||
|
||||
// the accepted tokens from the speculation
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
||||
|
||||
@@ -108,19 +108,22 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// prepare a batch for the prompt
|
||||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
||||
llama_pos n_past = 0;
|
||||
llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
|
||||
n_past += llama_batch_ext_get_n_tokens(batch);
|
||||
|
||||
llama_token new_token_id;
|
||||
while (true) {
|
||||
// check if we have enough space in the context to evaluate this batch
|
||||
int n_ctx = llama_n_ctx(ctx);
|
||||
int n_ctx_used = llama_kv_self_used_cells(ctx);
|
||||
if (n_ctx_used + batch.n_tokens > n_ctx) {
|
||||
if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) {
|
||||
printf("\033[0m\n");
|
||||
fprintf(stderr, "context size exceeded\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
if (llama_decode_ext(ctx, batch)) {
|
||||
GGML_ABORT("failed to decode\n");
|
||||
}
|
||||
|
||||
@@ -144,9 +147,14 @@ int main(int argc, char ** argv) {
|
||||
response += piece;
|
||||
|
||||
// prepare the next batch with the sampled token
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
llama_batch_ext_clear(batch);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true);
|
||||
n_past++;
|
||||
}
|
||||
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
return response;
|
||||
};
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// prepare a batch for the prompt
|
||||
|
||||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
||||
llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
|
||||
|
||||
// main loop
|
||||
|
||||
@@ -151,14 +151,14 @@ int main(int argc, char ** argv) {
|
||||
int n_decode = 0;
|
||||
llama_token new_token_id;
|
||||
|
||||
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
|
||||
for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) {
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch)) {
|
||||
if (llama_decode_ext(ctx, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
n_pos += batch.n_tokens;
|
||||
n_pos += llama_batch_ext_get_n_tokens(batch);
|
||||
|
||||
// sample the next token
|
||||
{
|
||||
@@ -180,7 +180,9 @@ int main(int argc, char ** argv) {
|
||||
fflush(stdout);
|
||||
|
||||
// prepare the next batch with the sampled token
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
llama_batch_ext_clear(batch);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true);
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
@@ -198,6 +200,7 @@ int main(int argc, char ** argv) {
|
||||
llama_perf_context_print(ctx);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_ext_free(batch);
|
||||
llama_sampler_free(smpl);
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
|
||||
@@ -113,7 +113,8 @@ int main(int argc, char ** argv) {
|
||||
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
|
||||
|
||||
// eval the prompt
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
|
||||
auto batch = llama_batch_ext_ptr::init_from_text(inp.data(), inp.size() - 1, 0, 0, true);
|
||||
llama_decode_ext(ctx_tgt, batch.get());
|
||||
|
||||
// note: keep the last token separate!
|
||||
llama_token id_last = inp.back();
|
||||
@@ -132,7 +133,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(ctx_dft);
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
||||
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1);
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
|
||||
@@ -151,8 +152,9 @@ int main(int argc, char ** argv) {
|
||||
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
|
||||
|
||||
// always have a token to evaluate from before - id_last
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
|
||||
llama_batch_ext_clear(batch_tgt);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true);
|
||||
|
||||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||
{
|
||||
@@ -162,12 +164,12 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < draft.size(); ++i) {
|
||||
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
|
||||
llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
|
||||
}
|
||||
|
||||
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
|
||||
|
||||
llama_decode(ctx_tgt, batch_tgt);
|
||||
llama_decode_ext(ctx_tgt, batch_tgt);
|
||||
}
|
||||
|
||||
// sample from the full target batch and return the accepted tokens based on the target sampler
|
||||
@@ -253,6 +255,7 @@ int main(int argc, char ** argv) {
|
||||
common_sampler_free(smpl);
|
||||
common_speculative_free(spec);
|
||||
|
||||
llama_batch_ext_free(batch_tgt);
|
||||
llama_backend_free();
|
||||
|
||||
LOG("\n\n");
|
||||
|
||||
@@ -45,7 +45,6 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.speculative.model.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
@@ -166,9 +165,12 @@ int main(int argc, char ** argv) {
|
||||
const auto t_enc_start = ggml_time_us();
|
||||
|
||||
// eval the prompt with both models
|
||||
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
|
||||
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
|
||||
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
|
||||
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
|
||||
llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
|
||||
llama_decode_ext(ctx_tgt, batch0.get());
|
||||
llama_decode_ext(ctx_tgt, batch1.get());
|
||||
llama_decode_ext(ctx_dft, batch2.get());
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
|
||||
@@ -199,8 +201,8 @@ int main(int argc, char ** argv) {
|
||||
drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
|
||||
}
|
||||
|
||||
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
|
||||
llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1);
|
||||
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft);
|
||||
|
||||
const auto t_dec_start = ggml_time_us();
|
||||
|
||||
@@ -335,7 +337,7 @@ int main(int argc, char ** argv) {
|
||||
if (i == s) {
|
||||
continue;
|
||||
}
|
||||
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
||||
if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
||||
// synchronize active status for sequences with the same drafted token
|
||||
drafts[i].active = drafts[i].active && accept;
|
||||
if (!drafts[i].active) {
|
||||
@@ -441,12 +443,13 @@ int main(int argc, char ** argv) {
|
||||
drafts[0].dists.push_back(std::vector<llama_token_data>());
|
||||
drafts[0].i_batch_tgt.push_back(0);
|
||||
|
||||
common_batch_clear(batch_dft);
|
||||
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
||||
llama_batch_ext_clear(batch_dft);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true);
|
||||
|
||||
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||
llama_decode(ctx_dft, batch_dft);
|
||||
llama_decode_ext(ctx_dft, batch_dft);
|
||||
|
||||
++n_past_dft;
|
||||
}
|
||||
@@ -471,12 +474,19 @@ int main(int argc, char ** argv) {
|
||||
drafts[0].drafting = true;
|
||||
drafts[0].i_batch_dft = 0;
|
||||
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
|
||||
struct batch_info {
|
||||
llama_token id;
|
||||
llama_pos pos;
|
||||
std::vector<llama_seq_id> seq_id;
|
||||
};
|
||||
|
||||
std::vector<batch_info> batch_tgt_data;
|
||||
|
||||
batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} });
|
||||
|
||||
// sample n_draft tokens from the draft model using tree-based sampling
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
batch_dft.n_tokens = 0;
|
||||
llama_batch_ext_clear(batch_dft);
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
drafts[s].skip = false;
|
||||
@@ -507,11 +517,10 @@ int main(int argc, char ** argv) {
|
||||
llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
|
||||
|
||||
// all previous tokens from this branch are now also part of the new branch
|
||||
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
|
||||
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
|
||||
if (batch_tgt.seq_id[t][p] == s) {
|
||||
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
|
||||
batch_tgt.n_seq_id[t]++;
|
||||
for (int t = 0; t < (int) batch_tgt_data.size(); ++t) {
|
||||
for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) {
|
||||
if (batch_tgt_data[t].seq_id[p] == s) {
|
||||
batch_tgt_data[t].seq_id.push_back(n_seq_cur);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -553,32 +562,30 @@ int main(int argc, char ** argv) {
|
||||
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
|
||||
|
||||
// add unique drafted tokens to the target batch
|
||||
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||
drafts[s].i_batch_tgt.push_back(batch_tgt_data.size());
|
||||
|
||||
common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
|
||||
batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }});
|
||||
|
||||
// add the token to the batch for batched decoding with the draft model
|
||||
drafts[s].i_batch_dft = batch_dft.n_tokens;
|
||||
drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true);
|
||||
|
||||
common_batch_add(batch_dft, id, n_past_cur, { s }, true);
|
||||
|
||||
if (batch_tgt.n_tokens > n_draft) {
|
||||
if (batch_tgt_data.size() > (size_t) n_draft) {
|
||||
drafts[s].drafting = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// no sequence is drafting anymore
|
||||
if (batch_dft.n_tokens == 0) {
|
||||
if (llama_batch_ext_get_n_tokens(batch_dft) == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
llama_decode(ctx_dft, batch_dft);
|
||||
llama_decode_ext(ctx_dft, batch_dft);
|
||||
++n_past_cur;
|
||||
++n_drafted;
|
||||
|
||||
if (batch_tgt.n_tokens > n_draft) {
|
||||
if (batch_tgt_data.size() > (size_t) n_draft) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -590,8 +597,15 @@ int main(int argc, char ** argv) {
|
||||
llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||
}
|
||||
|
||||
llama_batch_ext_clear(batch_tgt);
|
||||
for (int i = 0; i < (int) batch_tgt_data.size(); ++i) {
|
||||
const auto & data = batch_tgt_data[i];
|
||||
|
||||
llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true);
|
||||
}
|
||||
|
||||
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
|
||||
llama_decode(ctx_tgt, batch_tgt);
|
||||
llama_decode_ext(ctx_tgt, batch_tgt);
|
||||
++n_past_tgt;
|
||||
}
|
||||
|
||||
@@ -634,7 +648,8 @@ int main(int argc, char ** argv) {
|
||||
common_sampler_free(drafts[s].smpl);
|
||||
}
|
||||
|
||||
llama_batch_free(batch_dft);
|
||||
llama_batch_ext_free(batch_dft);
|
||||
llama_batch_ext_free(batch_tgt);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
|
||||
+19
-17
@@ -818,7 +818,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
|
||||
// create a llama_batch
|
||||
// we use this object to submit token data for decoding
|
||||
llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
|
||||
llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel);
|
||||
|
||||
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
@@ -827,14 +827,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
|
||||
// evaluate the initial prompt
|
||||
for (size_t i = 0; i < prompt_inp.size(); ++i) {
|
||||
common_batch_add(batch, prompt_inp[i], i, seq_ids, false);
|
||||
llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false);
|
||||
}
|
||||
GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
|
||||
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size());
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
llama_batch_ext_set_output_last(batch);
|
||||
|
||||
if (llama_decode(ctx_ttc, batch) != 0) {
|
||||
if (llama_decode_ext(ctx_ttc, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -853,16 +853,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
|
||||
// remember the batch index of the last token for each parallel sequence
|
||||
// we need this to determine which logits to sample from
|
||||
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
|
||||
std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
|
||||
|
||||
int n_past = batch.n_tokens;
|
||||
int n_past = llama_batch_ext_get_n_tokens(batch);
|
||||
int n_decode = 0;
|
||||
|
||||
bool next_token_uses_guide_token = true;
|
||||
|
||||
while (n_decode <= n_predict) {
|
||||
// prepare the next batch
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
|
||||
// sample the next token for each parallel sequence / stream
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
@@ -918,14 +918,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
//LOG_CNT("%d", i);
|
||||
}
|
||||
|
||||
i_batch[i] = batch.n_tokens;
|
||||
i_batch[i] = llama_batch_ext_get_n_tokens(batch);
|
||||
|
||||
// push this new token for next evaluation
|
||||
common_batch_add(batch, new_token_id, n_past, { i }, true);
|
||||
llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, true);
|
||||
}
|
||||
|
||||
// all streams are finished
|
||||
if (batch.n_tokens == 0) {
|
||||
if (llama_batch_ext_get_n_tokens(batch) == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -933,13 +933,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
n_past += 1;
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx_ttc, batch)) {
|
||||
if (llama_decode_ext(ctx_ttc, batch)) {
|
||||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
LOG("\n");
|
||||
LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
|
||||
@@ -1008,14 +1008,15 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
|
||||
const int n_codes = codes.size();
|
||||
|
||||
llama_batch batch = llama_batch_init(n_codes, 0, 1);
|
||||
llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1);
|
||||
|
||||
for (size_t i = 0; i < codes.size(); ++i) {
|
||||
common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits?
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits?
|
||||
}
|
||||
GGML_ASSERT(batch.n_tokens == n_codes);
|
||||
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes);
|
||||
|
||||
if (llama_decode(ctx_cts, batch) != 0) {
|
||||
if (llama_decode_ext(ctx_cts, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@@ -1079,6 +1080,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
retval = ENOENT;
|
||||
}
|
||||
|
||||
llama_batch_ext_free(batch);
|
||||
llama_backend_free();
|
||||
|
||||
return retval;
|
||||
|
||||
@@ -186,6 +186,7 @@ option(GGML_OPENMP "ggml: use OpenMP"
|
||||
option(GGML_RPC "ggml: use RPC" OFF)
|
||||
option(GGML_SYCL "ggml: use SYCL" OFF)
|
||||
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
|
||||
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
|
||||
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
|
||||
"ggml: sycl target device")
|
||||
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
|
||||
|
||||
@@ -454,6 +454,7 @@ extern "C" {
|
||||
GGML_OP_RMS_NORM,
|
||||
GGML_OP_RMS_NORM_BACK,
|
||||
GGML_OP_GROUP_NORM,
|
||||
GGML_OP_L2_NORM,
|
||||
|
||||
GGML_OP_MUL_MAT,
|
||||
GGML_OP_MUL_MAT_ID,
|
||||
@@ -502,6 +503,7 @@ extern "C" {
|
||||
GGML_OP_ADD_REL_POS,
|
||||
GGML_OP_RWKV_WKV6,
|
||||
GGML_OP_GATED_LINEAR_ATTN,
|
||||
GGML_OP_RWKV_WKV7,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
@@ -1095,6 +1097,18 @@ extern "C" {
|
||||
int n_groups,
|
||||
float eps);
|
||||
|
||||
// l2 normalize along rows
|
||||
// used in rwkv v7
|
||||
GGML_API struct ggml_tensor * ggml_l2_norm(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps);
|
||||
|
||||
// a - x
|
||||
// b - dy
|
||||
GGML_API struct ggml_tensor * ggml_rms_norm_back(
|
||||
@@ -1890,6 +1904,16 @@ extern "C" {
|
||||
struct ggml_tensor * state,
|
||||
float scale);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * r,
|
||||
struct ggml_tensor * w,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * state);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
|
||||
|
||||
@@ -287,17 +287,25 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||
elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
|
||||
message(STATUS "PowerPC detected")
|
||||
execute_process(COMMAND bash -c "grep POWER /proc/cpuinfo | head -n 1" OUTPUT_VARIABLE POWER_M)
|
||||
if (${POWER_M} MATCHES "POWER10")
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10)
|
||||
elseif (${POWER_M} MATCHES "POWER9")
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9)
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||
file(READ "/proc/cpuinfo" POWER10_M)
|
||||
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
|
||||
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
|
||||
endif()
|
||||
|
||||
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
|
||||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||
|
||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
|
||||
elseif (EXTRACTED_NUMBER EQUAL 9)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
||||
else()
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64 -mtune=native)
|
||||
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
||||
message(STATUS "loongarch64 detected")
|
||||
|
||||
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#ifdef __ARM_NEON
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||
float sum = 0;
|
||||
svuint8_t m4b = svdup_n_u8(0xf);
|
||||
svint32_t vzero = svdup_n_s32(0);
|
||||
svuint8_t mone = svdup_n_u8(0x30);
|
||||
svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
|
||||
svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d_all = GGML_FP16_TO_FP32(x[i].d);
|
||||
|
||||
const uint8_t * GGML_RESTRICT q6 = x[i].ql;
|
||||
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
const int8_t * GGML_RESTRICT scale = x[i].scales;
|
||||
|
||||
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
||||
const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
|
||||
const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
|
||||
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
|
||||
const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
|
||||
const svint64_t prod = svdup_n_s64(0);
|
||||
int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
|
||||
svdot_s64(prod, q8sums_2, q6scales_2)));
|
||||
int32_t isum = 0;
|
||||
|
||||
switch (vector_length) {
|
||||
case 128:
|
||||
{
|
||||
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
|
||||
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
|
||||
svint32_t isum_tmp = svdup_n_s32(0);
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
|
||||
svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
|
||||
qh += 32;
|
||||
svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
|
||||
svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
|
||||
svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
|
||||
svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
|
||||
q6 += 64;
|
||||
svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
|
||||
svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
||||
svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
||||
svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
||||
q8 += 64;
|
||||
|
||||
q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
|
||||
q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
|
||||
q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
|
||||
q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
|
||||
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
|
||||
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
|
||||
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
||||
|
||||
scale += 4;
|
||||
q8bytes_1 = svld1_s8(pg8_16, q8);
|
||||
q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
||||
q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
||||
q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
||||
q8 += 64;
|
||||
|
||||
q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
|
||||
q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
|
||||
q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
|
||||
q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
|
||||
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
|
||||
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
|
||||
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
||||
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
||||
scale += 4;
|
||||
}
|
||||
isum += svaddv_s32(pg32_4, isum_tmp);
|
||||
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
||||
}
|
||||
break;
|
||||
case 256:
|
||||
case 512:
|
||||
{
|
||||
const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
|
||||
const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
|
||||
const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
|
||||
svint32_t isum_tmp = svdup_n_s32(0);
|
||||
for (int j = 0; j < QK_K/128; j++) {
|
||||
svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
|
||||
qh += 32;
|
||||
svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
|
||||
svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
|
||||
q6 += 64;
|
||||
svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
|
||||
svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
|
||||
svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
|
||||
svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
|
||||
q8 += 128;
|
||||
q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
|
||||
q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
|
||||
q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
|
||||
q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
|
||||
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
|
||||
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
|
||||
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
|
||||
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
|
||||
|
||||
svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
|
||||
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
||||
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
||||
svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
|
||||
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
||||
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
||||
svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
|
||||
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
||||
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
||||
svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
|
||||
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
||||
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
||||
svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
|
||||
svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
|
||||
svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
|
||||
svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
|
||||
|
||||
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
|
||||
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
|
||||
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
|
||||
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
|
||||
scale += 8;
|
||||
}
|
||||
isum += svaddv_s32(pg32_8, isum_tmp);
|
||||
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
assert(false && "Unsupported vector length");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
*s = sum;
|
||||
|
||||
#elif __ARM_NEON
|
||||
float sum = 0;
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
||||
|
||||
@@ -8548,6 +8548,69 @@ static void ggml_compute_forward_group_norm(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_l2_norm
|
||||
|
||||
static void ggml_compute_forward_l2_norm_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
}
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
|
||||
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_l2_norm(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_l2_norm_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_mul_mat
|
||||
|
||||
static void ggml_compute_forward_mul_mat_one_chunk(
|
||||
@@ -13604,6 +13667,184 @@ static void ggml_compute_forward_gla(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_rwkv_wkv7
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
const int64_t T = dst->src[1]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t HEADS = dst->src[1]->ne[1];
|
||||
const int64_t n_seqs = dst->src[6]->ne[1];
|
||||
const int64_t head_size = C / HEADS;
|
||||
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * state = ((float *) dst->data) + C * T;
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int h_start = (HEADS * ith) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
|
||||
float * r = (float *) dst->src[0]->data;
|
||||
float * w = (float *) dst->src[1]->data;
|
||||
float * k = (float *) dst->src[2]->data;
|
||||
float * v = (float *) dst->src[3]->data;
|
||||
float * a = (float *) dst->src[4]->data;
|
||||
float * b = (float *) dst->src[5]->data;
|
||||
|
||||
int64_t t_stride = HEADS * head_size; // Same to C
|
||||
|
||||
int64_t h_stride = C / HEADS;
|
||||
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
|
||||
int64_t h_stride_2d = head_size * head_size;
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
int64_t t_offset = t * t_stride;
|
||||
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
|
||||
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
int64_t h_offset = h * h_stride;
|
||||
int64_t t_h_offset = t_offset + h_offset;
|
||||
int64_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t ii = 0; ii < head_size; ii++) {
|
||||
int64_t t_h_i_offset = t_h_offset + ii;
|
||||
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
|
||||
|
||||
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
|
||||
|
||||
float sa = 0;
|
||||
{
|
||||
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||||
GGML_F32_VEC ax[GGML_F32_ARR];
|
||||
GGML_F32_VEC ay[GGML_F32_ARR];
|
||||
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
|
||||
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
|
||||
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
|
||||
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
|
||||
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
|
||||
}
|
||||
}
|
||||
GGML_F32_VEC_REDUCE(sa, sum);
|
||||
}
|
||||
|
||||
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
|
||||
|
||||
int64_t j = 0;
|
||||
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||||
for (; j < head_size; j += GGML_F32_STEP) {
|
||||
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
|
||||
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
|
||||
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
|
||||
|
||||
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
|
||||
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
|
||||
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
|
||||
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
|
||||
|
||||
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
|
||||
|
||||
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
|
||||
// kv + s * decay + sa * b
|
||||
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
|
||||
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
|
||||
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
|
||||
|
||||
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
|
||||
}
|
||||
}
|
||||
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
|
||||
|
||||
// There shouldn't be left-overs though.
|
||||
for (; j < head_size; j++) {
|
||||
int64_t t_h_j_offset = t_h_offset + j;
|
||||
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
|
||||
float r_val = r[t_h_j_offset];
|
||||
float w_val = w[t_h_j_offset];
|
||||
float k_val = k[t_h_j_offset];
|
||||
float b_val = b[t_h_j_offset];
|
||||
float kv_val = v[t_h_i_offset] * k_val;
|
||||
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
|
||||
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
int64_t t_offset = t * t_stride;
|
||||
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
|
||||
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
int64_t h_offset = h * h_stride;
|
||||
int64_t t_h_offset = t_offset + h_offset;
|
||||
int64_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t i = 0; i < head_size; i++) {
|
||||
int64_t t_h_i_offset = t_h_offset + i;
|
||||
int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||
|
||||
float v_val = v[t_h_i_offset];
|
||||
|
||||
float sa = 0, result = 0;
|
||||
for (int64_t j = 0; j < head_size; j++) {
|
||||
sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < head_size; j++) {
|
||||
int64_t t_h_j_offset = t_h_offset + j;
|
||||
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
|
||||
float r_val = r[t_h_j_offset];
|
||||
float w_val = w[t_h_j_offset];
|
||||
float k_val = k[t_h_j_offset];
|
||||
float b_val = b[t_h_j_offset];
|
||||
float kv_val = v_val * k_val;
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
|
||||
result += state_cur[h_2d_i_j_offset] * r_val;
|
||||
}
|
||||
dst_data[t_h_i_offset] = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_unary
|
||||
|
||||
static void ggml_compute_forward_map_unary_f32(
|
||||
@@ -14170,6 +14411,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_group_norm(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_L2_NORM:
|
||||
{
|
||||
ggml_compute_forward_l2_norm(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
ggml_compute_forward_mul_mat(params, tensor);
|
||||
@@ -14357,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_gla(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_UNARY:
|
||||
{
|
||||
ggml_unary_op_f32_t fun;
|
||||
@@ -14582,6 +14831,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_MUL_MAT:
|
||||
@@ -14648,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_WIN_PART:
|
||||
case GGML_OP_WIN_UNPART:
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
|
||||
@@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
|
||||
};
|
||||
|
||||
|
||||
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
|
||||
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
|
||||
#define USE_CUDA_GRAPH
|
||||
#endif
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv6.cuh"
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
@@ -2196,6 +2196,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_GROUP_NORM:
|
||||
ggml_cuda_op_group_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_cuda_op_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
ggml_cuda_op_concat(ctx, dst);
|
||||
break;
|
||||
@@ -2304,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_cuda_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_cuda_op_rwkv_wkv7(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||
break;
|
||||
@@ -2610,13 +2616,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
|
||||
|
||||
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
hipGraphNode_t errorNode;
|
||||
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
||||
#else
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
#endif
|
||||
#else
|
||||
cudaGraphNode_t errorNode;
|
||||
cudaGraphExecUpdateResult result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
|
||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|
||||
@@ -3159,6 +3167,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
break;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
|
||||
@@ -3213,6 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
|
||||
@@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// template <int block_size>
|
||||
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
// const int tid = threadIdx.x;
|
||||
|
||||
// float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
// for (int col = tid; col < ncols; col += block_size) {
|
||||
// const float xi = x[row*ncols + col];
|
||||
// tmp += xi * xi;
|
||||
// }
|
||||
|
||||
// // sum up partial sums
|
||||
// tmp = warp_reduce_sum(tmp);
|
||||
// if (block_size > WARP_SIZE) {
|
||||
// __shared__ float s_sum[32];
|
||||
// int warp_id = threadIdx.x / WARP_SIZE;
|
||||
// int lane_id = threadIdx.x % WARP_SIZE;
|
||||
// if (lane_id == 0) {
|
||||
// s_sum[warp_id] = tmp;
|
||||
// }
|
||||
// __syncthreads();
|
||||
// tmp = s_sum[lane_id];
|
||||
// tmp = warp_reduce_sum(tmp);
|
||||
// }
|
||||
|
||||
// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
||||
// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
||||
|
||||
// for (int col = tid; col < ncols; col += block_size) {
|
||||
// dst[row*ncols + col] = scale * x[row*ncols + col];
|
||||
// }
|
||||
// }
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void l2_norm_f32(
|
||||
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps) {
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
|
||||
const int row = blockIdx.x;
|
||||
const int channel = blockIdx.y;
|
||||
const int sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if constexpr (block_size > WARP_SIZE) {
|
||||
static_assert(block_size == 1024, "unexpected block_size");
|
||||
__shared__ float s_sum[32];
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
tmp = s_sum[lane_id];
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
||||
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale * x[col];
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
@@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
@@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
|
||||
|
||||
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
|
||||
l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
|
||||
}
|
||||
|
||||
@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
Vendored
+1
-1
@@ -112,7 +112,7 @@
|
||||
#define cudaGraphExecDestroy hipGraphExecDestroy
|
||||
#define cudaGraphLaunch hipGraphLaunch
|
||||
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
||||
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
|
||||
#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
|
||||
#define cudaGraphNodeType hipGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
||||
#define cudaGraphInstantiate hipGraphInstantiate
|
||||
|
||||
Vendored
+2
-1
@@ -119,7 +119,7 @@
|
||||
#define cudaGraphExecDestroy musaGraphExecDestroy
|
||||
#define cudaGraphExec_t musaGraphExec_t
|
||||
#define cudaGraphExecUpdate musaGraphExecUpdate
|
||||
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
||||
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
|
||||
#define cudaGraphGetNodes musaGraphGetNodes
|
||||
#define cudaGraphInstantiate musaGraphInstantiate
|
||||
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
||||
@@ -132,6 +132,7 @@
|
||||
#define cudaGraph_t musaGraph_t
|
||||
#define cudaKernelNodeParams musaKernelNodeParams
|
||||
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
||||
#define cudaStreamBeginCapture musaStreamBeginCapture
|
||||
#define cudaStreamEndCapture musaStreamEndCapture
|
||||
|
||||
typedef mt_bfloat16 nv_bfloat16;
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
#include "common.cuh"
|
||||
#include "wkv.cuh"
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
__syncthreads();
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
__syncthreads();
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& tf = (float4&)(_tf[j]);
|
||||
const float4& td = (float4&)(_td[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
y += r.x * (tf.x * kv.x + s.x);
|
||||
y += r.y * (tf.y * kv.y + s.y);
|
||||
y += r.z * (tf.z * kv.z + s.z);
|
||||
y += r.w * (tf.w * kv.w + s.w);
|
||||
|
||||
s.x = s.x * td.x + kv.x;
|
||||
s.y = s.y * td.y + kv.y;
|
||||
s.z = s.z * td.z + kv.z;
|
||||
s.w = s.w * td.w + kv.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
|
||||
|
||||
#ifndef GGML_USE_MUSA
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
|
||||
}
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
__syncthreads();
|
||||
|
||||
float sa = 0;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4)
|
||||
{
|
||||
const float4& a = (float4&)(_a[j]);
|
||||
const float4& s = (float4&)(state[j]);
|
||||
sa += a.x * s.x;
|
||||
sa += a.y * s.y;
|
||||
sa += a.z * s.z;
|
||||
sa += a.w * s.w;
|
||||
}
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& w = (float4&)(_w[j]);
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& b = (float4&)(_b[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
s.x = s.x * w.x + kv.x + sa * b.x;
|
||||
s.y = s.y * w.y + kv.y + sa * b.y;
|
||||
s.z = s.z * w.z + kv.z + sa * b.z;
|
||||
s.w = s.w * w.w + kv.w + sa * b.w;
|
||||
|
||||
y += s.x * r.x;
|
||||
y += s.y * r.y;
|
||||
y += s.z * r.z;
|
||||
y += s.w * r.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
const float * tf_d = (const float *)dst->src[3]->data;
|
||||
const float * td_d = (const float *)dst->src[4]->data;
|
||||
const float * s_d = (const float *)dst->src[5]->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
|
||||
|
||||
if (C / H == CUDA_WKV_BLOCK_SIZE) {
|
||||
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
} else {
|
||||
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * r_d = (const float *)dst->src[0]->data;
|
||||
const float * w_d = (const float *)dst->src[1]->data;
|
||||
const float * k_d = (const float *)dst->src[2]->data;
|
||||
const float * v_d = (const float *)dst->src[3]->data;
|
||||
const float * a_d = (const float *)dst->src[4]->data;
|
||||
const float * b_d = (const float *)dst->src[5]->data;
|
||||
const float * s_d = (const float *)dst->src[6]->data;
|
||||
|
||||
const int64_t B = dst->src[6]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
|
||||
|
||||
if (C / H == CUDA_WKV_BLOCK_SIZE) {
|
||||
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
|
||||
} else {
|
||||
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
|
||||
}
|
||||
}
|
||||
@@ -3,3 +3,5 @@
|
||||
#define CUDA_WKV_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -1,89 +0,0 @@
|
||||
#include "common.cuh"
|
||||
#include "wkv6.cuh"
|
||||
|
||||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = CUDA_WKV_BLOCK_SIZE;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
__syncthreads();
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
__syncthreads();
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& tf = (float4&)(_tf[j]);
|
||||
const float4& td = (float4&)(_td[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
y += r.x * (tf.x * kv.x + s.x);
|
||||
y += r.y * (tf.y * kv.y + s.y);
|
||||
y += r.z * (tf.z * kv.z + s.z);
|
||||
y += r.w * (tf.w * kv.w + s.w);
|
||||
|
||||
s.x = s.x * td.x + kv.x;
|
||||
s.y = s.y * td.y + kv.y;
|
||||
s.z = s.z * td.z + kv.z;
|
||||
s.w = s.w * td.w + kv.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
const float * tf_d = (const float *)dst->src[3]->data;
|
||||
const float * td_d = (const float *)dst->src[4]->data;
|
||||
const float * s_d = (const float *)dst->src[5]->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
|
||||
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
}
|
||||
@@ -285,6 +285,13 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_rms_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne00_4;
|
||||
uint64_t nb01;
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
||||
@@ -184,10 +184,13 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
||||
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
||||
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
||||
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
||||
@@ -810,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
||||
@@ -1251,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
case GGML_OP_ARGMAX:
|
||||
return true;
|
||||
@@ -1288,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
@@ -2216,6 +2225,83 @@ static void ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == 64);
|
||||
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
|
||||
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
[encoder setBytes:&B length:sizeof(B) atIndex:7];
|
||||
[encoder setBytes:&T length:sizeof(T) atIndex:8];
|
||||
[encoder setBytes:&C length:sizeof(C) atIndex:9];
|
||||
[encoder setBytes:&H length:sizeof(H) atIndex:10];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
const int64_t B = dst->src[6]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == 64);
|
||||
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
size_t offs_src6 = 0;
|
||||
|
||||
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
||||
id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
||||
|
||||
[encoder setBytes:&B length:sizeof(B) atIndex:8];
|
||||
[encoder setBytes:&T length:sizeof(T) atIndex:9];
|
||||
[encoder setBytes:&C length:sizeof(C) atIndex:10];
|
||||
[encoder setBytes:&H length:sizeof(H) atIndex:11];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
@@ -3122,6 +3208,42 @@ static void ggml_metal_encode_node(
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_L2_NORM:
|
||||
{
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
ggml_metal_kargs_l2_norm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne00_4 =*/ ne00/4,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
|
||||
@@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rwkv_wkv6_f32(
|
||||
device const float * k,
|
||||
device const float * v,
|
||||
device const float * r,
|
||||
device const float * tf,
|
||||
device const float * td,
|
||||
device const float * state_in,
|
||||
device float * dst,
|
||||
constant uint & B,
|
||||
constant uint & T,
|
||||
constant uint & C,
|
||||
constant uint & H,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint head_size = 64; // TODO: support head_size = 128
|
||||
const uint batch_id = tgpig.x / H;
|
||||
const uint head_id = tgpig.x % H;
|
||||
const uint tid = tpitg.x;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
threadgroup float _k[head_size];
|
||||
threadgroup float _r[head_size];
|
||||
threadgroup float _tf[head_size];
|
||||
threadgroup float _td[head_size];
|
||||
|
||||
float state[head_size];
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_tf[tid] = tf[head_id * head_size + tid];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float v_val = v[t];
|
||||
float y = 0.0;
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
float4 kv = k_vec * v_val;
|
||||
|
||||
float4 temp = tf_vec * kv + s_vec;
|
||||
y += dot(r_vec, temp);
|
||||
|
||||
s_vec = s_vec * td_vec + kv;
|
||||
state[j] = s_vec[0];
|
||||
state[j+1] = s_vec[1];
|
||||
state[j+2] = s_vec[2];
|
||||
state[j+3] = s_vec[3];
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rwkv_wkv7_f32(
|
||||
device const float * r,
|
||||
device const float * w,
|
||||
device const float * k,
|
||||
device const float * v,
|
||||
device const float * a,
|
||||
device const float * b,
|
||||
device const float * state_in,
|
||||
device float * dst,
|
||||
constant uint & B,
|
||||
constant uint & T,
|
||||
constant uint & C,
|
||||
constant uint & H,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint head_size = 64; // TODO: support head_size = 128
|
||||
const uint batch_id = tgpig.x / H;
|
||||
const uint head_id = tgpig.x % H;
|
||||
const uint tid = tpitg.x;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
threadgroup float _r[head_size];
|
||||
threadgroup float _w[head_size];
|
||||
threadgroup float _k[head_size];
|
||||
threadgroup float _a[head_size];
|
||||
threadgroup float _b[head_size];
|
||||
|
||||
float state[head_size];
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i];
|
||||
}
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float v_val = v[t];
|
||||
float y = 0.0, sa = 0.0;
|
||||
|
||||
float4 sa_vec(0.0);
|
||||
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
sa_vec += a_vec * s_vec;
|
||||
}
|
||||
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
float4 kv = k_vec * v_val;
|
||||
|
||||
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
||||
y += dot(s_vec, r_vec);
|
||||
|
||||
state[j] = s_vec[0];
|
||||
state[j+1] = s_vec[1];
|
||||
state[j+2] = s_vec[2];
|
||||
state[j+3] = s_vec[3];
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_argmax(
|
||||
device const void * x,
|
||||
device int32_t * dst,
|
||||
@@ -1463,6 +1641,49 @@ kernel void kernel_rms_norm(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_l2_norm(
|
||||
constant ggml_metal_kargs_l2_norm & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
||||
|
||||
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
||||
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
|
||||
add_compile_definitions(GGML_USE_MUSA)
|
||||
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
|
||||
|
||||
if (GGML_CUDA_GRAPHS)
|
||||
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_FORCE_MMQ)
|
||||
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
|
||||
endif()
|
||||
|
||||
@@ -66,6 +66,9 @@ if (WIN32)
|
||||
find_package(MKL REQUIRED)
|
||||
target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
|
||||
else()
|
||||
if (GGML_SYCL_GRAPH)
|
||||
add_compile_definitions(GGML_SYCL_GRAPH)
|
||||
endif()
|
||||
if (GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
|
||||
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "wkv6.hpp"
|
||||
#include "wkv.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "cpy.hpp"
|
||||
|
||||
@@ -301,6 +301,7 @@ inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
|
||||
return opt;
|
||||
}
|
||||
|
||||
namespace sycl_ex = sycl::ext::oneapi::experimental;
|
||||
struct ggml_backend_sycl_context {
|
||||
int device;
|
||||
std::string name;
|
||||
@@ -392,6 +393,10 @@ struct ggml_backend_sycl_context {
|
||||
return pool(device);
|
||||
}
|
||||
|
||||
#ifdef GGML_SYCL_GRAPH
|
||||
std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
|
||||
#endif
|
||||
|
||||
ggml_sycl_pool & host_pool(int device) {
|
||||
if (host_pools[device] == nullptr) {
|
||||
host_pools[device] = new_pool_for_host(stream(device, 0), device);
|
||||
|
||||
@@ -138,7 +138,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
|
||||
sycl::range<3>(1, 1, WARP_SIZE),
|
||||
sycl::range<3>(1, 1, WARP_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
|
||||
});
|
||||
|
||||
|
||||
+12
-13
@@ -210,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
|
||||
nrows, item_ct1);
|
||||
});
|
||||
@@ -879,7 +879,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -902,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -923,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -944,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -965,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -986,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -1004,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1020,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1036,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1049,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1065,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1143,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
||||
default:
|
||||
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "common.hpp"
|
||||
#include "element_wise.hpp"
|
||||
|
||||
void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
@@ -20,7 +20,7 @@ void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
}
|
||||
}
|
||||
|
||||
void gelu_f32(const float * x, float * dst, const int k,
|
||||
static void gelu_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
@@ -37,7 +37,7 @@ void gelu_f32(const float * x, float * dst, const int k,
|
||||
sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
|
||||
}
|
||||
|
||||
void silu_f32(const float * x, float * dst, const int k,
|
||||
static void silu_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -48,7 +48,7 @@ void silu_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
|
||||
}
|
||||
|
||||
void gelu_quick_f32(const float *x, float *dst, int k,
|
||||
static void gelu_quick_f32(const float *x, float *dst, int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const float GELU_QUICK_COEF = -1.702f;
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
@@ -59,7 +59,7 @@ void gelu_quick_f32(const float *x, float *dst, int k,
|
||||
dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
|
||||
}
|
||||
|
||||
void tanh_f32(const float *x, float *dst, int k,
|
||||
static void tanh_f32(const float *x, float *dst, int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -69,7 +69,7 @@ void tanh_f32(const float *x, float *dst, int k,
|
||||
dst[i] = sycl::tanh((float)(x[i]));
|
||||
}
|
||||
|
||||
void relu_f32(const float * x, float * dst, const int k,
|
||||
static void relu_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -80,7 +80,7 @@ void relu_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::fmax((float)(x[i]), (float)0);
|
||||
}
|
||||
|
||||
void sigmoid_f32(const float * x, float * dst, const int k,
|
||||
static void sigmoid_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -91,7 +91,7 @@ void sigmoid_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
|
||||
}
|
||||
|
||||
void sqrt_f32(const float * x, float * dst, const int k,
|
||||
static void sqrt_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -102,7 +102,7 @@ void sqrt_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::sqrt(x[i]);
|
||||
}
|
||||
|
||||
void sin_f32(const float * x, float * dst, const int k,
|
||||
static void sin_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -113,7 +113,7 @@ void sin_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::sin(x[i]);
|
||||
}
|
||||
|
||||
void cos_f32(const float * x, float * dst, const int k,
|
||||
static void cos_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -124,7 +124,7 @@ void cos_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::cos(x[i]);
|
||||
}
|
||||
|
||||
void hardsigmoid_f32(const float * x, float * dst, const int k,
|
||||
static void hardsigmoid_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -135,7 +135,7 @@ void hardsigmoid_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
void hardswish_f32(const float * x, float * dst, const int k,
|
||||
static void hardswish_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -146,7 +146,7 @@ void hardswish_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
void exp_f32(const float * x, float * dst, const int k,
|
||||
static void exp_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -157,7 +157,7 @@ void exp_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = sycl::exp(x[i]);
|
||||
}
|
||||
|
||||
void log_f32(const float * x, float * dst, const int k,
|
||||
static void log_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -173,7 +173,7 @@ void log_f32(const float * x, float * dst, const int k,
|
||||
}
|
||||
}
|
||||
|
||||
void neg_f32(const float * x, float * dst, const int k,
|
||||
static void neg_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -184,7 +184,7 @@ void neg_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = -x[i];
|
||||
}
|
||||
|
||||
void step_f32(const float * x, float * dst, const int k,
|
||||
static void step_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -195,7 +195,7 @@ void step_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = x[i] > 0.0f;
|
||||
}
|
||||
|
||||
void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
|
||||
static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -206,7 +206,7 @@ void leaky_relu_f32(const float *x, float *dst, const int k, const float negativ
|
||||
sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
|
||||
}
|
||||
|
||||
void sqr_f32(const float * x, float * dst, const int k,
|
||||
static void sqr_f32(const float * x, float * dst, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -217,7 +217,7 @@ void sqr_f32(const float * x, float * dst, const int k,
|
||||
dst[i] = x[i] * x[i];
|
||||
}
|
||||
|
||||
void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
||||
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
|
||||
@@ -240,7 +240,7 @@ void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
||||
dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
||||
}
|
||||
|
||||
void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
||||
static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
@@ -262,7 +262,7 @@ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const i
|
||||
|
||||
|
||||
|
||||
void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int n_elements, const int ne10, const int ne11,
|
||||
const int ne12, const int nb1, const int nb2,
|
||||
const int offset, queue_ptr stream) {
|
||||
@@ -277,7 +277,7 @@ void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void gelu_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -289,7 +289,7 @@ void gelu_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void silu_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void silu_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -301,7 +301,7 @@ void silu_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -313,7 +313,7 @@ void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void tanh_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void tanh_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -325,7 +325,7 @@ void tanh_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -337,7 +337,7 @@ void relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -349,7 +349,7 @@ void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -361,7 +361,7 @@ void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void exp_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void exp_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -373,7 +373,7 @@ void exp_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void log_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void log_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -385,7 +385,7 @@ void log_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void neg_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void neg_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -397,7 +397,7 @@ void neg_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void step_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void step_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -409,7 +409,7 @@ void step_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -421,7 +421,7 @@ void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -433,7 +433,7 @@ void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void sin_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void sin_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -445,7 +445,7 @@ void sin_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void cos_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void cos_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -457,7 +457,7 @@ void cos_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
const float negative_slope,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
||||
@@ -470,7 +470,7 @@ void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void sqr_f32_sycl(const float *x, float *dst, const int k,
|
||||
static void sqr_f32_sycl(const float *x, float *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -482,7 +482,7 @@ void sqr_f32_sycl(const float *x, float *dst, const int k,
|
||||
});
|
||||
}
|
||||
|
||||
void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
||||
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
const float sf2, const float sf3, queue_ptr stream) {
|
||||
@@ -496,7 +496,7 @@ void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01
|
||||
});
|
||||
}
|
||||
|
||||
void pad_f32_sycl(const float *x, float *dst, const int ne00,
|
||||
static void pad_f32_sycl(const float *x, float *dst, const int ne00,
|
||||
const int ne01, const int ne02, const int ne0,
|
||||
const int ne1, const int ne2, queue_ptr stream) {
|
||||
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
||||
|
||||
@@ -207,7 +207,7 @@ static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_te
|
||||
const size_t nrows = ne01;
|
||||
const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
k_get_rows_reorder<qk, qr, dq_reorder>(
|
||||
src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
|
||||
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
|
||||
@@ -302,7 +302,6 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *s
|
||||
// TODO: k-quants
|
||||
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@
|
||||
static bool g_sycl_loaded = false;
|
||||
int g_ggml_sycl_debug = 0;
|
||||
int g_ggml_sycl_disable_optimize = 0;
|
||||
int g_ggml_sycl_disable_graph = 0;
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
ggml_sycl_device_info info = {};
|
||||
@@ -95,7 +96,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
|
||||
return info;
|
||||
}
|
||||
|
||||
void print_device_detail(int id, sycl::device &device, std::string device_type) {
|
||||
static void print_device_detail(int id, sycl::device &device, std::string device_type) {
|
||||
|
||||
dpct::device_info prop;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(
|
||||
@@ -118,7 +119,7 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
|
||||
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
|
||||
}
|
||||
|
||||
void print_device_opt_feature(int device_count) {
|
||||
static void print_device_opt_feature(int device_count) {
|
||||
GGML_LOG_INFO("SYCL Optimization Feature:\n");
|
||||
GGML_LOG_INFO(
|
||||
"|ID| Device Type|Reorder|\n");
|
||||
@@ -191,10 +192,12 @@ static void ggml_check_sycl() try {
|
||||
if (!initialized) {
|
||||
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
||||
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
|
||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
|
||||
GGML_LOG_INFO("Build with Macros:\n");
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
|
||||
@@ -401,7 +404,7 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
|
||||
static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
|
||||
const void *ptr_src, size_t size) {
|
||||
char *host_buf = (char *)malloc(size);
|
||||
q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
|
||||
@@ -620,7 +623,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
|
||||
return &ggml_backend_sycl_buffer_types[device];
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
|
||||
static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
|
||||
|
||||
int device = ctx->device;
|
||||
@@ -1682,7 +1685,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(num_blocks * block_size, block_size),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1703,7 +1706,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
|
||||
nchannels_y, item_ct1);
|
||||
});
|
||||
@@ -1723,7 +1726,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
|
||||
row_stride_x, channel_stride_x,
|
||||
nchannels_y / nchannels_x, item_ct1);
|
||||
@@ -1764,7 +1767,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
||||
const sycl::range<3> block_nums(1, nrows, 1);
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
k_sum_rows_f32(x, dst, ncols, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -2696,6 +2699,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
|
||||
@@ -2914,7 +2923,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
@@ -3287,7 +3296,7 @@ static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
|
||||
|
||||
void ggml_sycl_set_main_device(const int main_device) try {
|
||||
static void ggml_sycl_set_main_device(const int main_device) try {
|
||||
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
|
||||
return;
|
||||
}
|
||||
@@ -3308,7 +3317,7 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
|
||||
static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
|
||||
if (!g_sycl_loaded) return false;
|
||||
|
||||
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
|
||||
@@ -3410,6 +3419,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||
case GGML_OP_RMS_NORM:
|
||||
ggml_sycl_rms_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_sycl_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
|
||||
return false;
|
||||
@@ -3487,6 +3499,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
ggml_sycl_op_rwkv_wkv6(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_sycl_op_rwkv_wkv7(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
@@ -3626,7 +3641,7 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
static void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
|
||||
SYCL_CHECK(
|
||||
@@ -3640,7 +3655,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
|
||||
stream->parallel_for(
|
||||
size / sizeof(block_q4_0),
|
||||
[=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const block_q4_0* x = (const block_q4_0*)tmp_buf;
|
||||
const int ib = i;
|
||||
|
||||
@@ -3654,7 +3669,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
char*data_device = (char*)src0->data;
|
||||
size_t ncols = src0->ne[0];
|
||||
size_t nrows = src0->ne[1];
|
||||
@@ -3663,7 +3678,7 @@ void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
reorder_qw(data_device, ncols, nrows, size, 0, stream);
|
||||
}
|
||||
|
||||
void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
|
||||
static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
|
||||
ggml_tensor *src0 = dst->src[0];
|
||||
ggml_tensor *src1 = dst->src[1];
|
||||
|
||||
@@ -3676,7 +3691,7 @@ void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
|
||||
}
|
||||
}
|
||||
|
||||
void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
|
||||
static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
|
||||
dpct::queue_ptr stream = ctx->stream();
|
||||
if (ctx->optimized_graph) {
|
||||
return;
|
||||
@@ -3687,10 +3702,9 @@ void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx)
|
||||
if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
|
||||
}
|
||||
}
|
||||
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
|
||||
ggml_sycl_set_main_device(sycl_ctx->device);
|
||||
|
||||
static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
|
||||
ggml_sycl_set_main_device(sycl_ctx->device);
|
||||
if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
@@ -3712,7 +3726,46 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
|
||||
}
|
||||
GGML_ASSERT(ok);
|
||||
}
|
||||
}
|
||||
|
||||
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
|
||||
|
||||
#ifdef GGML_SYCL_GRAPH
|
||||
if (!g_ggml_sycl_disable_graph) {
|
||||
if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) {
|
||||
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
|
||||
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
model_sycl_graph.end_recording();
|
||||
|
||||
if (!sycl_ctx->exec_graph) {
|
||||
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
|
||||
sycl_ctx->exec_graph = std::make_unique<
|
||||
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
|
||||
} else {
|
||||
try {
|
||||
sycl_ctx->exec_graph->update(model_sycl_graph);
|
||||
GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
|
||||
} catch (sycl::exception const & e) {
|
||||
GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
|
||||
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
|
||||
sycl_ctx->exec_graph = std::make_unique<
|
||||
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
|
||||
}
|
||||
}
|
||||
|
||||
sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
}
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -3866,7 +3919,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_NEG:
|
||||
@@ -3884,7 +3937,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
@@ -3915,7 +3967,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_OUT_PROD:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
|
||||
case GGML_OP_GET_ROWS:
|
||||
@@ -3932,7 +3984,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
@@ -3983,12 +4035,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
} break;
|
||||
}
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_NONE:
|
||||
@@ -4012,6 +4064,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return (op->src[0]->type == GGML_TYPE_F32);
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SCALE:
|
||||
@@ -4045,6 +4098,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
return true;
|
||||
default:
|
||||
|
||||
@@ -3017,7 +3017,6 @@ void ggml_sycl_op_mul_mat_q(
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
|
||||
+19
-20
@@ -495,7 +495,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
|
||||
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -519,7 +519,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
|
||||
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -543,7 +543,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
|
||||
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -567,7 +567,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
|
||||
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -591,7 +591,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
|
||||
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -615,7 +615,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
|
||||
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -639,7 +639,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
|
||||
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -663,7 +663,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
|
||||
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -687,7 +687,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
|
||||
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -711,7 +711,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
|
||||
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
@@ -734,7 +734,7 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -755,7 +755,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -777,7 +777,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -799,7 +799,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -821,7 +821,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -843,7 +843,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -864,7 +864,7 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -886,7 +886,7 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -908,7 +908,7 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
@@ -1003,7 +1003,6 @@ void ggml_sycl_op_mul_mat_vec_q(
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
GGML_UNUSED(src1);
|
||||
|
||||
+114
-6
@@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
|
||||
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row * ncols + col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
if (block_size > WARP_SIZE) {
|
||||
|
||||
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
/*
|
||||
DPCT1118:3: SYCL group functions and algorithms must be encountered in
|
||||
converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
size_t nreduce = nwarps / WARP_SIZE;
|
||||
tmp = 0.f;
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row * ncols + col] = scale * x[row * ncols + col];
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
@@ -191,7 +235,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
@@ -214,7 +258,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
@@ -233,7 +277,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(
|
||||
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
@@ -260,7 +304,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(x, dst, group_size, ne_elements,
|
||||
eps_ct4, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
@@ -281,7 +325,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
@@ -303,7 +347,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
@@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||
const sycl::range<3> block_dims(1, 1, work_group_size);
|
||||
/*
|
||||
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
||||
cgh);
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
||||
ggml_tensor* dst, const float* src0_dd,
|
||||
const float* src1_dd, float* dst_dd,
|
||||
@@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
|
||||
(void)dst;
|
||||
(void)src1_dd;
|
||||
}
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
(void)src1_dd;
|
||||
}
|
||||
|
||||
@@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
|
||||
#endif // GGML_SYCL_NORM_HPP
|
||||
|
||||
@@ -132,7 +132,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
||||
nrows_y, scale, max_bias, m0,
|
||||
m1, n_head_log2, item_ct1,
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "wkv.hpp"
|
||||
|
||||
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
||||
|
||||
// Helper function for the main kernel
|
||||
template <int block_size>
|
||||
static void rwkv_wkv6_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* k, const float* v, const float* r,
|
||||
const float* tf, const float* td, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
// Set up shared memory pointers
|
||||
float* _k = shared_mem;
|
||||
float* _r = _k + head_size;
|
||||
float* _tf = _r + head_size;
|
||||
float* _td = _tf + head_size;
|
||||
|
||||
// Local state array
|
||||
float state[block_size];
|
||||
|
||||
// Load initial state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
// Sync threads before shared memory operations
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load time-mixing parameters
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Main sequence processing loop
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load current timestep data to shared memory
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
|
||||
// Process in chunks of 4 for better vectorization
|
||||
sycl::float4 k4, r4, tf4, td4, s4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
// Load data in vec4 chunks
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
// Compute key-value product
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
// Accumulate weighted sum
|
||||
y += sycl::dot(r4, tf4 * kv4 + s4);
|
||||
|
||||
// Update state
|
||||
s4 = s4 * td4 + kv4;
|
||||
|
||||
// Store updated state
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
// Save final state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static void rwkv_wkv7_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* r, const float* w, const float* k, const float* v,
|
||||
const float* a, const float* b, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float* _r = shared_mem;
|
||||
float* _w = _r + head_size;
|
||||
float* _k = _w + head_size;
|
||||
float* _a = _k + head_size;
|
||||
float* _b = _a + head_size;
|
||||
|
||||
float state[block_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
|
||||
}
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0, sa = 0;
|
||||
sycl::float4 a4, s4;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
sa += sycl::dot(a4, s4);
|
||||
}
|
||||
|
||||
sycl::float4 r4, w4, k4, b4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
s4 = s4 * w4 + kv4 + sa * b4;
|
||||
y += sycl::dot(r4, s4);
|
||||
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
|
||||
const float* k_d = (const float*)dst->src[0]->data;
|
||||
const float* v_d = (const float*)dst->src[1]->data;
|
||||
const float* r_d = (const float*)dst->src[2]->data;
|
||||
const float* tf_d = (const float*)dst->src[3]->data;
|
||||
const float* td_d = (const float*)dst->src[4]->data;
|
||||
const float* s_d = (const float*)dst->src[5]->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// Calculate execution configuration
|
||||
const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
|
||||
sycl::range<3> block_dims(1, 1, C / H);
|
||||
sycl::range<3> grid_dims(1, 1, B * H);
|
||||
|
||||
// Submit kernel
|
||||
if (C / H == WKV_BLOCK_SIZE) {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
|
||||
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
||||
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
|
||||
const float* r_d = (const float*)dst->src[0]->data;
|
||||
const float* w_d = (const float*)dst->src[1]->data;
|
||||
const float* k_d = (const float*)dst->src[2]->data;
|
||||
const float* v_d = (const float*)dst->src[3]->data;
|
||||
const float* a_d = (const float*)dst->src[4]->data;
|
||||
const float* b_d = (const float*)dst->src[5]->data;
|
||||
const float* s_d = (const float*)dst->src[6]->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
const int64_t B = dst->src[6]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// Calculate execution configuration
|
||||
const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
|
||||
sycl::range<3> block_dims(1, 1, C / H);
|
||||
sycl::range<3> grid_dims(1, 1, B * H);
|
||||
|
||||
// Submit kernel
|
||||
if (C / H == WKV_BLOCK_SIZE) {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
|
||||
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
||||
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
#ifndef GGML_SYCL_WKV_HPP
|
||||
#define GGML_SYCL_WKV_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_WKV_HPP
|
||||
@@ -1,143 +0,0 @@
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "wkv6.hpp"
|
||||
|
||||
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
||||
|
||||
// Helper function for the main kernel
|
||||
static void rwkv_wkv_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* k, const float* v, const float* r,
|
||||
const float* tf, const float* td, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = WKV_BLOCK_SIZE;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
// Set up shared memory pointers
|
||||
float* _k = shared_mem;
|
||||
float* _r = _k + head_size;
|
||||
float* _tf = _r + head_size;
|
||||
float* _td = _tf + head_size;
|
||||
|
||||
// Local state array
|
||||
float state[WKV_BLOCK_SIZE];
|
||||
|
||||
// Load initial state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
// Sync threads before shared memory operations
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load time-mixing parameters
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Main sequence processing loop
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load current timestep data to shared memory
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
|
||||
// Process in chunks of 4 for better vectorization
|
||||
sycl::float4 k4, r4, tf4, td4, s4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
// Load data in vec4 chunks
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
// Compute key-value product
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
// Accumulate weighted sum
|
||||
y += sycl::dot(r4, tf4 * kv4 + s4);
|
||||
|
||||
// Update state
|
||||
s4 = s4 * td4 + kv4;
|
||||
|
||||
// Store updated state
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
// Save final state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
|
||||
const float* k_d = (const float*)dst->src[0]->data;
|
||||
const float* v_d = (const float*)dst->src[1]->data;
|
||||
const float* r_d = (const float*)dst->src[2]->data;
|
||||
const float* tf_d = (const float*)dst->src[3]->data;
|
||||
const float* td_d = (const float*)dst->src[4]->data;
|
||||
const float* s_d = (const float*)dst->src[5]->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// Calculate execution configuration
|
||||
const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
|
||||
sycl::range<3> block_dims(1, 1, C / H);
|
||||
sycl::range<3> grid_dims(1, 1, B * H);
|
||||
|
||||
// Submit kernel
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv_f32_kernel(
|
||||
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
#ifndef GGML_SYCL_WKV6_HPP
|
||||
#define GGML_SYCL_WKV6_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
|
||||
#endif // GGML_SYCL_WKV6_HPP
|
||||
@@ -304,6 +304,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_group_norm_f32;
|
||||
vk_pipeline pipeline_rms_norm_f32;
|
||||
vk_pipeline pipeline_rms_norm_back_f32;
|
||||
vk_pipeline pipeline_l2_norm_f32;
|
||||
vk_pipeline pipeline_gelu_f32;
|
||||
vk_pipeline pipeline_gelu_quick_f32;
|
||||
vk_pipeline pipeline_silu_f32;
|
||||
@@ -328,6 +329,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_timestep_embedding_f32;
|
||||
vk_pipeline pipeline_pool2d_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv7_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
|
||||
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
||||
@@ -629,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
|
||||
uint32_t H;
|
||||
};
|
||||
|
||||
struct vk_op_rwkv_wkv7_push_constants {
|
||||
uint32_t B;
|
||||
uint32_t T;
|
||||
uint32_t C;
|
||||
uint32_t H;
|
||||
};
|
||||
|
||||
// Allow pre-recording command buffers
|
||||
struct vk_staging_memcpy {
|
||||
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
||||
@@ -2263,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -2374,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
for (auto &c : compiles) {
|
||||
@@ -2512,13 +2524,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
|
||||
device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
|
||||
#if defined(_WIN32)
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
|
||||
} else {
|
||||
// Limit batching of allocations to 1GB by default to avoid fragmentation issues
|
||||
device->suballocation_block_size = 1024*1024*1024;
|
||||
#endif
|
||||
} else {
|
||||
device->suballocation_block_size = device->max_memory_allocation_size;
|
||||
}
|
||||
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
|
||||
|
||||
@@ -5473,6 +5481,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_rms_norm_back_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_L2_NORM:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_l2_norm_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(dst)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
@@ -5612,6 +5625,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_rwkv_wkv6_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_rwkv_wkv7_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_opt_step_adamw_f32;
|
||||
@@ -5859,6 +5877,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
@@ -6108,23 +6127,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
|
||||
const ggml_tensor * k = dst->src[0];
|
||||
const ggml_tensor * v = dst->src[1];
|
||||
const ggml_tensor * r = dst->src[2];
|
||||
const ggml_tensor * tf = dst->src[3];
|
||||
const ggml_tensor * td = dst->src[4];
|
||||
const ggml_tensor * state = dst->src[5];
|
||||
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
|
||||
GGML_ASSERT(version == 6 || version == 7);
|
||||
int num_srcs = version == 6 ? 6 : 7;
|
||||
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
|
||||
}
|
||||
|
||||
GGML_ASSERT(!ggml_is_quantized(k->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(v->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(r->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(tf->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(td->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(state->type));
|
||||
GGML_ASSERT(dst->buffer != nullptr);
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
|
||||
GGML_ASSERT(pipeline != nullptr);
|
||||
|
||||
if (dryrun) {
|
||||
@@ -6133,89 +6146,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
|
||||
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
|
||||
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
|
||||
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
|
||||
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
|
||||
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
|
||||
}
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
|
||||
vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
|
||||
size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
|
||||
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
|
||||
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
|
||||
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
|
||||
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
|
||||
|
||||
if (ctx->device->uma) {
|
||||
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
|
||||
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
|
||||
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
|
||||
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
|
||||
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
|
||||
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
|
||||
srcs_uma[i] = d_srcs[i] != nullptr;
|
||||
}
|
||||
|
||||
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
|
||||
|
||||
K_uma = d_K != nullptr;
|
||||
V_uma = d_V != nullptr;
|
||||
R_uma = d_R != nullptr;
|
||||
TF_uma = d_TF != nullptr;
|
||||
TD_uma = d_TD != nullptr;
|
||||
STATE_uma = d_State != nullptr;
|
||||
DST_uma = d_D != nullptr;
|
||||
dst_uma = d_D != nullptr;
|
||||
}
|
||||
|
||||
if (!K_uma) {
|
||||
d_K = k_buf_ctx->dev_buffer;
|
||||
k_offset = vk_tensor_offset(k) + k->view_offs;
|
||||
uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
|
||||
for (int i = 0; i < num_srcs; i++) {
|
||||
src_sizes[i] = ggml_nbytes(dst->src[i]);
|
||||
if (!srcs_uma[i]) {
|
||||
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
|
||||
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
|
||||
}
|
||||
}
|
||||
if (!V_uma) {
|
||||
d_V = v_buf_ctx->dev_buffer;
|
||||
v_offset = vk_tensor_offset(v) + v->view_offs;
|
||||
}
|
||||
if (!R_uma) {
|
||||
d_R = r_buf_ctx->dev_buffer;
|
||||
r_offset = vk_tensor_offset(r) + r->view_offs;
|
||||
}
|
||||
if (!TF_uma) {
|
||||
d_TF = tf_buf_ctx->dev_buffer;
|
||||
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
|
||||
}
|
||||
if (!TD_uma) {
|
||||
d_TD = td_buf_ctx->dev_buffer;
|
||||
td_offset = vk_tensor_offset(td) + td->view_offs;
|
||||
}
|
||||
if (!STATE_uma) {
|
||||
d_State = state_buf_ctx->dev_buffer;
|
||||
state_offset = vk_tensor_offset(state) + state->view_offs;
|
||||
}
|
||||
if (!DST_uma) {
|
||||
|
||||
const uint64_t dst_size = ggml_nbytes(dst);
|
||||
if (!dst_uma) {
|
||||
d_D = dst_buf_ctx->dev_buffer;
|
||||
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
}
|
||||
|
||||
const uint64_t k_size = ggml_nbytes(k);
|
||||
const uint64_t v_size = ggml_nbytes(v);
|
||||
const uint64_t r_size = ggml_nbytes(r);
|
||||
const uint64_t tf_size = ggml_nbytes(tf);
|
||||
const uint64_t td_size = ggml_nbytes(td);
|
||||
const uint64_t state_size = ggml_nbytes(state);
|
||||
const uint64_t dst_size = ggml_nbytes(dst);
|
||||
|
||||
std::array<uint32_t, 3> elements = {
|
||||
(uint32_t)(pc.B * pc.H),
|
||||
1,
|
||||
1
|
||||
};
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_K, k_offset, k_size },
|
||||
vk_subbuffer{ d_V, v_offset, v_size },
|
||||
vk_subbuffer{ d_R, r_offset, r_size },
|
||||
vk_subbuffer{ d_TF, tf_offset, tf_size },
|
||||
vk_subbuffer{ d_TD, td_offset, td_size },
|
||||
vk_subbuffer{ d_State, state_offset, state_size },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
||||
if (version == 6) {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
||||
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
|
||||
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
|
||||
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
|
||||
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
||||
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
||||
} else if (version == 7) {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
||||
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
|
||||
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
|
||||
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
|
||||
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
||||
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
||||
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
|
||||
} else {
|
||||
// shouldn't happen
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||
@@ -6224,7 +6221,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
const size_t n_heads = dst->src[0]->ne[1];
|
||||
const size_t n_seqs = dst->src[5]->ne[1];
|
||||
|
||||
ggml_vk_op_f32_rwkv6(
|
||||
ggml_vk_op_f32_wkv(
|
||||
ctx, subctx, dst,
|
||||
{
|
||||
(uint32_t)n_seqs,
|
||||
@@ -6232,6 +6229,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
(uint32_t)n_embed,
|
||||
(uint32_t)n_heads,
|
||||
},
|
||||
6,
|
||||
dryrun
|
||||
);
|
||||
}
|
||||
|
||||
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||
const size_t seq_length = dst->src[0]->ne[2];
|
||||
const size_t n_embed = dst->ne[0];
|
||||
const size_t n_heads = dst->src[0]->ne[1];
|
||||
const size_t n_seqs = dst->src[6]->ne[1];
|
||||
|
||||
ggml_vk_op_f32_wkv(
|
||||
ctx, subctx, dst,
|
||||
{
|
||||
(uint32_t)n_seqs,
|
||||
(uint32_t)seq_length,
|
||||
(uint32_t)n_embed,
|
||||
(uint32_t)n_heads,
|
||||
},
|
||||
7,
|
||||
dryrun
|
||||
);
|
||||
}
|
||||
@@ -6533,6 +6550,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
||||
}
|
||||
@@ -7528,6 +7550,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
@@ -7544,6 +7567,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
@@ -7590,6 +7614,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
@@ -7707,6 +7732,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
@@ -7797,6 +7826,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
@@ -7870,6 +7904,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
@@ -7889,6 +7924,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
@@ -8806,6 +8842,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
@@ -8835,6 +8872,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
return true;
|
||||
@@ -9219,6 +9257,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
|
||||
} else if (tensor->op == GGML_OP_SILU_BACK) {
|
||||
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_L2_NORM) {
|
||||
const float eps = ((float *) tensor->op_params)[0];
|
||||
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
|
||||
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
||||
if (src1 != nullptr) {
|
||||
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
||||
@@ -9338,6 +9379,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
||||
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
|
||||
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
|
||||
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
|
||||
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
|
||||
src_clone[4], src_clone[5], src_clone[6]);
|
||||
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
||||
src_clone[0]->flags = src0->flags;
|
||||
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
find_package (Threads REQUIRED)
|
||||
find_program(GLSLC_EXECUTABLE glslc)
|
||||
if(NOT GLSLC_EXECUTABLE)
|
||||
message(FATAL_ERROR "glslc not found.")
|
||||
endif()
|
||||
|
||||
set(TARGET vulkan-shaders-gen)
|
||||
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
|
||||
sum[tid] += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
sum[tid] += sum[tid + s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
|
||||
}
|
||||
}
|
||||
@@ -434,6 +434,7 @@ void process_shaders() {
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||
@@ -528,6 +529,8 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
for (auto &c : compiles) {
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#define BLOCK_SIZE 64
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
uint B;
|
||||
uint T;
|
||||
uint C;
|
||||
uint H;
|
||||
};
|
||||
|
||||
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
|
||||
layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
|
||||
layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
|
||||
layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
|
||||
layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
|
||||
layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
|
||||
layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
|
||||
layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
|
||||
|
||||
shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint head_size = BLOCK_SIZE;
|
||||
const uint batch_id = gl_WorkGroupID.x / H;
|
||||
const uint head_id = gl_WorkGroupID.x % H;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
A_TYPE state[BLOCK_SIZE];
|
||||
[[unroll]] for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i];
|
||||
}
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
barrier();
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
barrier();
|
||||
|
||||
A_TYPE sa = 0.0;
|
||||
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
|
||||
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
sa += dot(s_vec, a_vec);
|
||||
}
|
||||
|
||||
const A_TYPE v_val = v[t];
|
||||
A_TYPE y = 0.0;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
|
||||
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
vec4 kv = k_vec * v_val;
|
||||
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
||||
y += dot(r_vec, s_vec);
|
||||
|
||||
state[j] = s_vec.x;
|
||||
state[j+1] = s_vec.y;
|
||||
state[j+2] = s_vec.z;
|
||||
state[j+3] = s_vec.w;
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
+85
-2
@@ -929,6 +929,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"RMS_NORM",
|
||||
"RMS_NORM_BACK",
|
||||
"GROUP_NORM",
|
||||
"L2_NORM",
|
||||
|
||||
"MUL_MAT",
|
||||
"MUL_MAT_ID",
|
||||
@@ -977,6 +978,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"ADD_REL_POS",
|
||||
"RWKV_WKV6",
|
||||
"GATED_LINEAR_ATTN",
|
||||
"RWKV_WKV7",
|
||||
|
||||
"UNARY",
|
||||
|
||||
@@ -996,7 +998,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1026,6 +1028,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"rms_norm(x)",
|
||||
"rms_norm_back(x)",
|
||||
"group_norm(x)",
|
||||
"l2_norm(x)",
|
||||
|
||||
"X*Y",
|
||||
"X[i]*Y",
|
||||
@@ -1074,6 +1077,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"add_rel_pos(x)",
|
||||
"rwkv_wkv6(k, v, r, tf, td, s)",
|
||||
"gated_linear_attn(k, v, q, gate, s)",
|
||||
"rwkv_wkv7(r, w, k, v, a, b, s)",
|
||||
|
||||
"unary(x)",
|
||||
|
||||
@@ -1093,7 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -2686,6 +2690,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
|
||||
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
|
||||
}
|
||||
|
||||
// ggml_l2_norm
|
||||
|
||||
static struct ggml_tensor * ggml_l2_norm_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps,
|
||||
bool inplace) {
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_f32(result, 0, eps);
|
||||
|
||||
result->op = GGML_OP_L2_NORM;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_l2_norm(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps) {
|
||||
return ggml_l2_norm_impl(ctx, a, eps, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_l2_norm_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float eps) {
|
||||
return ggml_l2_norm_impl(ctx, a, eps, true);
|
||||
}
|
||||
|
||||
// ggml_mul_mat
|
||||
|
||||
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||
@@ -4720,6 +4755,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_rwkv_wkv7
|
||||
|
||||
struct ggml_tensor * ggml_rwkv_wkv7(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * r,
|
||||
struct ggml_tensor * w,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * state) {
|
||||
GGML_ASSERT(ggml_is_contiguous(r));
|
||||
GGML_ASSERT(ggml_is_contiguous(w));
|
||||
GGML_ASSERT(ggml_is_contiguous(k));
|
||||
GGML_ASSERT(ggml_is_contiguous(v));
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
GGML_ASSERT(ggml_is_contiguous(b));
|
||||
GGML_ASSERT(ggml_is_contiguous(state));
|
||||
|
||||
const int64_t S = k->ne[0];
|
||||
const int64_t H = k->ne[1];
|
||||
const int64_t n_tokens = k->ne[2];
|
||||
const int64_t n_seqs = state->ne[1];
|
||||
{
|
||||
GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
|
||||
GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
|
||||
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
|
||||
GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
|
||||
GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
|
||||
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
|
||||
}
|
||||
|
||||
// concat output and new_state
|
||||
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
result->op = GGML_OP_RWKV_WKV7;
|
||||
result->src[0] = r;
|
||||
result->src[1] = w;
|
||||
result->src[2] = k;
|
||||
result->src[3] = v;
|
||||
result->src[4] = a;
|
||||
result->src[5] = b;
|
||||
result->src[6] = state;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_unary
|
||||
|
||||
static struct ggml_tensor * ggml_unary_impl(
|
||||
|
||||
+110
-16
@@ -118,22 +118,26 @@ class Keys:
|
||||
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
|
||||
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
|
||||
CLAMP_KQV = "{arch}.attention.clamp_kqv"
|
||||
KEY_LENGTH = "{arch}.attention.key_length"
|
||||
VALUE_LENGTH = "{arch}.attention.value_length"
|
||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
|
||||
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||
SCALE = "{arch}.attention.scale"
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
|
||||
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
|
||||
CLAMP_KQV = "{arch}.attention.clamp_kqv"
|
||||
KEY_LENGTH = "{arch}.attention.key_length"
|
||||
VALUE_LENGTH = "{arch}.attention.value_length"
|
||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
|
||||
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank"
|
||||
ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank"
|
||||
VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
|
||||
GATE_LORA_RANK = "{arch}.attention.gate_lora_rank"
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||
SCALE = "{arch}.attention.scale"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
@@ -257,6 +261,8 @@ class MODEL_ARCH(IntEnum):
|
||||
STARCODER2 = auto()
|
||||
RWKV6 = auto()
|
||||
RWKV6QWEN2 = auto()
|
||||
RWKV7 = auto()
|
||||
ARWKV7 = auto()
|
||||
MAMBA = auto()
|
||||
XVERSE = auto()
|
||||
COMMAND_R = auto()
|
||||
@@ -329,8 +335,20 @@ class MODEL_TENSOR(IntEnum):
|
||||
SSM_A = auto()
|
||||
SSM_D = auto()
|
||||
SSM_OUT = auto()
|
||||
TIME_MIX_W0 = auto()
|
||||
TIME_MIX_W1 = auto()
|
||||
TIME_MIX_W2 = auto()
|
||||
TIME_MIX_A0 = auto()
|
||||
TIME_MIX_A1 = auto()
|
||||
TIME_MIX_A2 = auto()
|
||||
TIME_MIX_V0 = auto()
|
||||
TIME_MIX_V1 = auto()
|
||||
TIME_MIX_V2 = auto()
|
||||
TIME_MIX_G1 = auto()
|
||||
TIME_MIX_G2 = auto()
|
||||
TIME_MIX_K_K = auto()
|
||||
TIME_MIX_K_A = auto()
|
||||
TIME_MIX_R_K = auto()
|
||||
TIME_MIX_LERP_X = auto()
|
||||
TIME_MIX_LERP_K = auto()
|
||||
TIME_MIX_LERP_V = auto()
|
||||
@@ -445,6 +463,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.RWKV6: "rwkv6",
|
||||
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
|
||||
MODEL_ARCH.RWKV7: "rwkv7",
|
||||
MODEL_ARCH.ARWKV7: "arwkv7",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
MODEL_ARCH.COMMAND_R: "command-r",
|
||||
@@ -517,8 +537,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
|
||||
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
|
||||
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
|
||||
MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0",
|
||||
MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1",
|
||||
MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2",
|
||||
MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0",
|
||||
MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1",
|
||||
MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2",
|
||||
MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1",
|
||||
MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2",
|
||||
MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k",
|
||||
MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a",
|
||||
MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k",
|
||||
MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
|
||||
MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
|
||||
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
|
||||
@@ -1172,6 +1204,68 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.RWKV7: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM_2,
|
||||
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
|
||||
MODEL_TENSOR.TIME_MIX_W0,
|
||||
MODEL_TENSOR.TIME_MIX_W1,
|
||||
MODEL_TENSOR.TIME_MIX_W2,
|
||||
MODEL_TENSOR.TIME_MIX_A0,
|
||||
MODEL_TENSOR.TIME_MIX_A1,
|
||||
MODEL_TENSOR.TIME_MIX_A2,
|
||||
MODEL_TENSOR.TIME_MIX_V0,
|
||||
MODEL_TENSOR.TIME_MIX_V1,
|
||||
MODEL_TENSOR.TIME_MIX_V2,
|
||||
MODEL_TENSOR.TIME_MIX_G1,
|
||||
MODEL_TENSOR.TIME_MIX_G2,
|
||||
MODEL_TENSOR.TIME_MIX_K_K,
|
||||
MODEL_TENSOR.TIME_MIX_K_A,
|
||||
MODEL_TENSOR.TIME_MIX_R_K,
|
||||
MODEL_TENSOR.TIME_MIX_KEY,
|
||||
MODEL_TENSOR.TIME_MIX_VALUE,
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
|
||||
MODEL_TENSOR.TIME_MIX_LN,
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT,
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_K,
|
||||
MODEL_TENSOR.CHANNEL_MIX_KEY,
|
||||
MODEL_TENSOR.CHANNEL_MIX_VALUE,
|
||||
],
|
||||
MODEL_ARCH.ARWKV7: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
|
||||
MODEL_TENSOR.TIME_MIX_W0,
|
||||
MODEL_TENSOR.TIME_MIX_W1,
|
||||
MODEL_TENSOR.TIME_MIX_W2,
|
||||
MODEL_TENSOR.TIME_MIX_A0,
|
||||
MODEL_TENSOR.TIME_MIX_A1,
|
||||
MODEL_TENSOR.TIME_MIX_A2,
|
||||
MODEL_TENSOR.TIME_MIX_V0,
|
||||
MODEL_TENSOR.TIME_MIX_V1,
|
||||
MODEL_TENSOR.TIME_MIX_V2,
|
||||
MODEL_TENSOR.TIME_MIX_G1,
|
||||
MODEL_TENSOR.TIME_MIX_G2,
|
||||
MODEL_TENSOR.TIME_MIX_K_K,
|
||||
MODEL_TENSOR.TIME_MIX_K_A,
|
||||
MODEL_TENSOR.TIME_MIX_R_K,
|
||||
MODEL_TENSOR.TIME_MIX_KEY,
|
||||
MODEL_TENSOR.TIME_MIX_VALUE,
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
|
||||
MODEL_TENSOR.TIME_MIX_LN,
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.MAMBA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
||||
@@ -767,6 +767,18 @@ class GGUFWriter:
|
||||
def add_kv_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_decay_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_iclr_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_value_residual_mix_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_gate_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_relative_attn_buckets_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
|
||||
|
||||
|
||||
+100
-31
@@ -27,7 +27,8 @@ class TensorNameMap:
|
||||
"embedding.word_embeddings", # chatglm
|
||||
"transformer.token_embeddings", # openelm
|
||||
"shared", # t5
|
||||
"rwkv.embeddings", # rwkv
|
||||
"rwkv.embeddings", # rwkv6
|
||||
"model.embeddings", # rwkv7
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@@ -42,6 +43,9 @@ class TensorNameMap:
|
||||
"emb_ln", # nomic-bert
|
||||
"transformer.norm", # openelm
|
||||
"rwkv.blocks.0.pre_ln", # rwkv
|
||||
"rwkv.blocks.0.pre_ln", # rwkv6
|
||||
"model.pre_ln", # rwkv7
|
||||
"model.layers.0.pre_norm", # rwkv7
|
||||
"backbone.norm", # wavtokenizer
|
||||
),
|
||||
|
||||
@@ -81,7 +85,8 @@ class TensorNameMap:
|
||||
"encoder.final_layernorm", # chatglm
|
||||
"transformer.norm", # openelm
|
||||
"model.norm", # nemotron
|
||||
"rwkv.ln_out", # rwkv
|
||||
"rwkv.ln_out", # rwkv6
|
||||
"model.ln_out", # rwkv7
|
||||
"backbone.final_layer_norm", # wavtokenizer
|
||||
),
|
||||
|
||||
@@ -122,14 +127,16 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
|
||||
"encoder.layers.{bid}.input_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv6
|
||||
"model.layers.{bid}.ln1", # rwkv7
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
MODEL_TENSOR.ATTN_NORM_2: (
|
||||
"transformer.h.{bid}.ln_attn", # falcon40b
|
||||
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
|
||||
"rwkv.blocks.{bid}.ln2", # rwkv
|
||||
"rwkv.blocks.{bid}.ln2", # rwkv6
|
||||
"model.layers.{bid}.ln2", # rwkv7
|
||||
),
|
||||
|
||||
# Attention query-key-value
|
||||
@@ -462,112 +469,174 @@ class TensorNameMap:
|
||||
"backbone.layers.{bid}.mixer.out_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W0: (
|
||||
"model.layers.{bid}.attention.w0", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W1: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.w1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W2: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.w2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_A0: (
|
||||
"model.layers.{bid}.attention.a0", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_A1: (
|
||||
"model.layers.{bid}.attention.a1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_A2: (
|
||||
"model.layers.{bid}.attention.a2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_V0: (
|
||||
"model.layers.{bid}.attention.v0", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_V1: (
|
||||
"model.layers.{bid}.attention.v1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_V2: (
|
||||
"model.layers.{bid}.attention.v2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_G1: (
|
||||
"model.layers.{bid}.attention.g1", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_G2: (
|
||||
"model.layers.{bid}.attention.g2", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_K_K: (
|
||||
"model.layers.{bid}.attention.k_k", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_K_A: (
|
||||
"model.layers.{bid}.attention.k_a", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_R_K: (
|
||||
"model.layers.{bid}.attention.r_k", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_X: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_K: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_V: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_R: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_G: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_W: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_FIRST: (
|
||||
"rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_decay", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY_W1: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY_W2: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6
|
||||
"model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_KEY: (
|
||||
"rwkv.blocks.{bid}.attention.key", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.key", # rwkv6
|
||||
"model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.key", # rwkv7
|
||||
"model.layers.{bid}.attention.k_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_VALUE: (
|
||||
"rwkv.blocks.{bid}.attention.value", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.value", # rwkv6
|
||||
"model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.value", # rwkv7
|
||||
"model.layers.{bid}.attention.v_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
|
||||
"rwkv.blocks.{bid}.attention.receptance", # rwkv
|
||||
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.receptance", # rwkv6
|
||||
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.receptance", # rwkv7
|
||||
"model.layers.{bid}.attention.r_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_GATE: (
|
||||
"rwkv.blocks.{bid}.attention.gate", # rwkv
|
||||
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
|
||||
"rwkv.blocks.{bid}.attention.gate", # rwkv6
|
||||
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LN: (
|
||||
"rwkv.blocks.{bid}.attention.ln_x", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.ln_x", # rwkv6
|
||||
"model.layers.{bid}.attention.ln_x" # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT: (
|
||||
"rwkv.blocks.{bid}.attention.output", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.output", # rwkv6
|
||||
"model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
|
||||
"model.layers.{bid}.attention.output", # rwkv7
|
||||
"model.layers.{bid}.attention.o_proj", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6
|
||||
"model.layers.{bid}.feed_forward.x_k", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6
|
||||
"rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_KEY: (
|
||||
"rwkv.blocks.{bid}.feed_forward.key", # rwkv
|
||||
"rwkv.blocks.{bid}.feed_forward.key", # rwkv6
|
||||
"model.layers.{bid}.feed_forward.key", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
|
||||
"rwkv.blocks.{bid}.feed_forward.receptance", # rwkv
|
||||
"rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_VALUE: (
|
||||
"rwkv.blocks.{bid}.feed_forward.value", # rwkv
|
||||
"rwkv.blocks.{bid}.feed_forward.value", # rwkv6
|
||||
"model.layers.{bid}.feed_forward.value", # rwkv7
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A: (
|
||||
|
||||
@@ -24,7 +24,34 @@ struct llama_adapter_lora_deleter {
|
||||
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
|
||||
};
|
||||
|
||||
struct llama_batch_ext_deleter {
|
||||
void operator()(llama_batch_ext * batch) { llama_batch_ext_free(batch); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
|
||||
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
|
||||
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
|
||||
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
|
||||
|
||||
struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
|
||||
llama_batch_ext_ptr() : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>() {}
|
||||
llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>(batch) {}
|
||||
|
||||
// convenience function to create a batch from text tokens, without worrying about manually freeing it
|
||||
static llama_batch_ext_ptr init_from_text(llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
int32_t pos0,
|
||||
int32_t seq_id,
|
||||
bool output_last) {
|
||||
return llama_batch_ext_ptr(llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id, output_last));
|
||||
}
|
||||
|
||||
// convenience function to create a batch from text embeddings, without worrying about manually freeing it
|
||||
static llama_batch_ext_ptr init_from_embd(float * embd,
|
||||
size_t n_tokens,
|
||||
size_t n_embd,
|
||||
int32_t pos0,
|
||||
int32_t seq_id) {
|
||||
return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(embd, n_tokens, n_embd, pos0, seq_id));
|
||||
}
|
||||
};
|
||||
|
||||
+108
-8
@@ -234,6 +234,9 @@ extern "C" {
|
||||
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
||||
|
||||
// Input data for llama_decode
|
||||
//
|
||||
// WARN: This struct is DEPRECATED and will be removed in the future, use llama_batch_ext instead
|
||||
//
|
||||
// A llama_batch object can contain input about one or many sequences
|
||||
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
|
||||
//
|
||||
@@ -257,6 +260,10 @@ extern "C" {
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
} llama_batch;
|
||||
|
||||
// Input data for llama_decode / llama_encode
|
||||
// It can contain text tokens and embeddings for one or many sequences
|
||||
struct llama_batch_ext;
|
||||
|
||||
enum llama_model_kv_override_type {
|
||||
LLAMA_KV_OVERRIDE_TYPE_INT,
|
||||
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
||||
@@ -891,9 +898,9 @@ extern "C" {
|
||||
//
|
||||
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
||||
//
|
||||
LLAMA_API struct llama_batch llama_batch_get_one(
|
||||
DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens);
|
||||
int32_t n_tokens), "use llama_batch_ext_init_from_text instead");
|
||||
|
||||
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||
// Each token can be assigned up to n_seq_max sequence ids
|
||||
@@ -902,13 +909,98 @@ extern "C" {
|
||||
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
|
||||
// The rest of the llama_batch members are allocated with size n_tokens
|
||||
// All members are left uninitialized
|
||||
LLAMA_API struct llama_batch llama_batch_init(
|
||||
int32_t n_tokens,
|
||||
int32_t embd,
|
||||
int32_t n_seq_max);
|
||||
DEPRECATED(LLAMA_API struct llama_batch llama_batch_init(
|
||||
int32_t n_tokens,
|
||||
int32_t embd,
|
||||
int32_t n_seq_max), "use llama_batch_ext_init instead");
|
||||
|
||||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||
DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch),
|
||||
"use llama_batch_ext API instead");
|
||||
|
||||
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||
// Each token can be assigned up to n_seq_max sequence ids
|
||||
// The batch has to be freed with llama_batch_ext_free()
|
||||
LLAMA_API struct llama_batch_ext * llama_batch_ext_init(
|
||||
int32_t n_tokens,
|
||||
int32_t n_seq_max);
|
||||
|
||||
// Same with llama_batch_init, but initializes the batch with the provided text tokens
|
||||
// First token will be at position pos0
|
||||
// The sequence ID will be fixed to seq_id
|
||||
// If output_last is true, the last token will have output set
|
||||
// The batch has to be freed with llama_batch_ext_free()
|
||||
LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
int32_t pos0,
|
||||
int32_t seq_id,
|
||||
bool output_last);
|
||||
|
||||
// Same with llama_batch_init, but initializes the batch with the provided raw embeddings
|
||||
// Size of embd should be n_tokens * n_embd
|
||||
// n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd()
|
||||
// First token will be at position pos0
|
||||
// The sequence ID will be fixed to seq_id
|
||||
// The batch has to be freed with llama_batch_ext_free()
|
||||
LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
||||
float * embd,
|
||||
size_t n_tokens,
|
||||
size_t n_embd,
|
||||
int32_t pos0,
|
||||
int32_t seq_id);
|
||||
|
||||
// Set arbitrary token to the embeddings batch
|
||||
// Note: this is only to be used in conjunction with llama_batch_ext_init_from_embd()
|
||||
// n_pos must match the n_tokens of the batch
|
||||
// Returns -1 if n_pos does not match the n_tokens of the batch
|
||||
LLAMA_API int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos);
|
||||
|
||||
// Get the number of tokens in the batch
|
||||
LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
|
||||
|
||||
// Add text tokens to the batch
|
||||
// Return values:
|
||||
// -1 : not enough space in the batch
|
||||
// -2 : embd is already set, cannot add text tokens
|
||||
// otherwise, returns the output ID
|
||||
LLAMA_API int32_t llama_batch_ext_add_text(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_token token,
|
||||
llama_pos pos,
|
||||
const llama_seq_id * seq_ids,
|
||||
size_t n_seq_ids,
|
||||
bool output);
|
||||
|
||||
// Set output (logits/embeddings) for the token in the ith sequence
|
||||
// If pos == -1, output will be set for the all tokens
|
||||
// Return values:
|
||||
// -1 : the token is not in the batch
|
||||
// otherwise, returns the output ID
|
||||
LLAMA_API int32_t llama_batch_ext_set_output(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_pos pos,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Set output (logits/embeddings) for the last added token
|
||||
// Return values:
|
||||
// -1 : the batch is empty
|
||||
// otherwise, returns the output ID
|
||||
LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch);
|
||||
|
||||
// Get a "view" from a number of tokens offset
|
||||
// Return returned batch must be freed with llama_batch_ext_free()
|
||||
LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view(
|
||||
struct llama_batch_ext * batch,
|
||||
int32_t offset,
|
||||
int32_t n_tokens);
|
||||
|
||||
// Remove everything from the batch
|
||||
LLAMA_API void llama_batch_ext_clear(struct llama_batch_ext * batch);
|
||||
|
||||
// Frees a batch of tokens allocated with llama_batch_ext_init()
|
||||
// If this is a view, the original batch is not freed
|
||||
LLAMA_API void llama_batch_ext_free(struct llama_batch_ext * batch);
|
||||
|
||||
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
|
||||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||
@@ -918,13 +1010,21 @@ extern "C" {
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
LLAMA_API int32_t llama_encode_ext(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch_ext * batch);
|
||||
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
struct llama_batch batch);
|
||||
|
||||
LLAMA_API int32_t llama_decode_ext(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch_ext * batch);
|
||||
|
||||
// Set the number of threads used for decoding
|
||||
// n_threads is the number of threads used for generation (single token)
|
||||
|
||||
+102
-16
@@ -59,6 +59,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
|
||||
{ LLM_ARCH_RWKV7, "rwkv7" },
|
||||
{ LLM_ARCH_ARWKV7, "arwkv7" },
|
||||
{ LLM_ARCH_GRANITE, "granite" },
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
@@ -110,22 +112,26 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
|
||||
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
|
||||
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
|
||||
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
@@ -1238,6 +1244,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_RWKV7,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
|
||||
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
|
||||
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
|
||||
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
|
||||
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
|
||||
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
|
||||
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
|
||||
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
|
||||
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
|
||||
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
|
||||
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
||||
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
||||
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
||||
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
|
||||
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ARWKV7,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
|
||||
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
|
||||
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
|
||||
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
|
||||
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
|
||||
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
|
||||
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
|
||||
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
|
||||
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
|
||||
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
||||
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
||||
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
||||
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
|
||||
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GRANITE,
|
||||
{
|
||||
@@ -1397,6 +1471,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
@@ -1415,6 +1495,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
@@ -1422,6 +1505,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
|
||||
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
||||
@@ -63,6 +63,8 @@ enum llm_arch {
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_RWKV6,
|
||||
LLM_ARCH_RWKV6QWEN2,
|
||||
LLM_ARCH_RWKV7,
|
||||
LLM_ARCH_ARWKV7,
|
||||
LLM_ARCH_GRANITE,
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_CHAMELEON,
|
||||
@@ -127,6 +129,10 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_CAUSAL,
|
||||
LLM_KV_ATTENTION_Q_LORA_RANK,
|
||||
LLM_KV_ATTENTION_KV_LORA_RANK,
|
||||
LLM_KV_ATTENTION_DECAY_LORA_RANK,
|
||||
LLM_KV_ATTENTION_ICLR_LORA_RANK,
|
||||
LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
|
||||
LLM_KV_ATTENTION_GATE_LORA_RANK,
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
@@ -250,8 +256,20 @@ enum llm_tensor {
|
||||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_TIME_MIX_W0,
|
||||
LLM_TENSOR_TIME_MIX_W1,
|
||||
LLM_TENSOR_TIME_MIX_W2,
|
||||
LLM_TENSOR_TIME_MIX_A0,
|
||||
LLM_TENSOR_TIME_MIX_A1,
|
||||
LLM_TENSOR_TIME_MIX_A2,
|
||||
LLM_TENSOR_TIME_MIX_V0,
|
||||
LLM_TENSOR_TIME_MIX_V1,
|
||||
LLM_TENSOR_TIME_MIX_V2,
|
||||
LLM_TENSOR_TIME_MIX_G1,
|
||||
LLM_TENSOR_TIME_MIX_G2,
|
||||
LLM_TENSOR_TIME_MIX_K_K,
|
||||
LLM_TENSOR_TIME_MIX_K_A,
|
||||
LLM_TENSOR_TIME_MIX_R_K,
|
||||
LLM_TENSOR_TIME_MIX_LERP_X,
|
||||
LLM_TENSOR_TIME_MIX_LERP_W,
|
||||
LLM_TENSOR_TIME_MIX_LERP_K,
|
||||
|
||||
+215
-23
@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
GGML_ASSERT(batch.n_tokens >= 0);
|
||||
this->batch = &batch;
|
||||
this->n_embd = n_embd;
|
||||
@@ -273,46 +273,60 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
);
|
||||
}
|
||||
|
||||
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
|
||||
batch = in_batch;
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
if (!batch.pos) {
|
||||
pos.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) {
|
||||
batch = new llama_batch_ext{
|
||||
/*n_tokens =*/ in_batch.n_tokens,
|
||||
/*max_tokens =*/ in_batch.n_tokens,
|
||||
/*is_view =*/ false,
|
||||
/*tokens =*/ in_batch.token,
|
||||
/*embd =*/ in_batch.embd,
|
||||
/*pos =*/ in_batch.pos,
|
||||
/*n_seq_id =*/ in_batch.n_seq_id,
|
||||
/*seq_id =*/ in_batch.seq_id,
|
||||
/*logits =*/ in_batch.logits,
|
||||
};
|
||||
GGML_ASSERT(batch->n_tokens > 0);
|
||||
if (!in_batch.pos) {
|
||||
pos.resize(batch->n_tokens);
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
pos[i] = i + p0;
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
batch->pos = pos.data();
|
||||
}
|
||||
if (!batch.n_seq_id) {
|
||||
n_seq_id.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch->n_seq_id) {
|
||||
n_seq_id.resize(batch->n_tokens);
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
n_seq_id[i] = seq_id_0.size();
|
||||
}
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
batch->n_seq_id = n_seq_id.data();
|
||||
}
|
||||
if (!batch.seq_id) {
|
||||
seq_id.resize(batch.n_tokens + 1);
|
||||
seq_id[batch.n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch->seq_id) {
|
||||
seq_id.resize(batch->n_tokens + 1);
|
||||
seq_id[batch->n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
seq_id[i] = seq_id_0.data();
|
||||
}
|
||||
batch.seq_id = seq_id.data();
|
||||
batch->seq_id = seq_id.data();
|
||||
}
|
||||
if (!batch.logits) {
|
||||
logits.resize(batch.n_tokens);
|
||||
if (!batch->logits) {
|
||||
logits.resize(batch->n_tokens);
|
||||
logits[logits.size() - 1] = true;
|
||||
batch.logits = logits.data();
|
||||
batch->logits = logits.data();
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_allocr::~llama_batch_allocr() {
|
||||
delete batch;
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
|
||||
struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens) {
|
||||
return {
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens) {
|
||||
return llama_batch{
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ tokens,
|
||||
/*embd =*/ nullptr,
|
||||
@@ -323,6 +337,183 @@ struct llama_batch llama_batch_get_one(
|
||||
};
|
||||
}
|
||||
|
||||
struct llama_batch_ext * llama_batch_ext_init_from_text(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
int32_t pos0,
|
||||
int32_t seq_id,
|
||||
bool output_last) {
|
||||
llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
|
||||
for (int32_t i = 0; i < n_tokens; i++) {
|
||||
llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
|
||||
}
|
||||
if (output_last) {
|
||||
llama_batch_ext_set_output_last(batch);
|
||||
}
|
||||
return batch;
|
||||
}
|
||||
|
||||
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
|
||||
llama_batch_ext * batch = new llama_batch_ext{
|
||||
/*n_tokens =*/ 0,
|
||||
/*max_tokens =*/ n_tokens_alloc,
|
||||
/*is_view =*/ false,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
};
|
||||
|
||||
if (n_embd) {
|
||||
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd);
|
||||
} else {
|
||||
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
||||
}
|
||||
|
||||
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
|
||||
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
|
||||
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
|
||||
for (int i = 0; i < n_tokens_alloc; ++i) {
|
||||
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
||||
}
|
||||
batch->seq_id[n_tokens_alloc] = nullptr;
|
||||
|
||||
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
|
||||
|
||||
return batch;
|
||||
}
|
||||
|
||||
struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
|
||||
return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max);
|
||||
}
|
||||
|
||||
struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
||||
float * embd,
|
||||
size_t n_tokens,
|
||||
size_t n_embd,
|
||||
int32_t pos0,
|
||||
int32_t seq_id) {
|
||||
struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
|
||||
memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
|
||||
for (size_t i = 0; i < n_tokens; i++) {
|
||||
batch->pos [i] = pos0 + i;
|
||||
batch->n_seq_id[i] = 1;
|
||||
batch->seq_id [i][0] = seq_id;
|
||||
}
|
||||
return batch;
|
||||
}
|
||||
|
||||
int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) {
|
||||
if ((size_t) batch->n_tokens != n_pos) {
|
||||
return -1;
|
||||
}
|
||||
memcpy(batch->pos, pos, n_pos * sizeof(llama_pos));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
|
||||
return batch->n_tokens;
|
||||
}
|
||||
|
||||
int32_t llama_batch_ext_add_text(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_token token,
|
||||
llama_pos pos,
|
||||
const llama_seq_id * seq_ids,
|
||||
size_t n_seq_ids,
|
||||
bool output) {
|
||||
if (batch->n_tokens + 1 > batch->max_tokens) {
|
||||
return -1; // llama_batch size exceeded
|
||||
}
|
||||
if (batch->embd) {
|
||||
return -2; // embd is already set, cannot add text tokens
|
||||
}
|
||||
const int32_t output_id = batch->n_tokens;
|
||||
batch->token [output_id] = token;
|
||||
batch->pos [output_id] = pos;
|
||||
batch->n_seq_id[output_id] = n_seq_ids;
|
||||
for (size_t j = 0; j < n_seq_ids; j++) {
|
||||
batch->seq_id[batch->n_tokens][j] = seq_ids[j];
|
||||
}
|
||||
batch->logits [output_id] = output;
|
||||
batch->n_tokens++;
|
||||
return output_id;
|
||||
}
|
||||
|
||||
int32_t llama_batch_ext_set_output(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_pos pos,
|
||||
llama_seq_id seq_id) {
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
// find the token having seq_id
|
||||
for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
|
||||
if (batch->seq_id[i][j] == seq_id) {
|
||||
// found the sequence
|
||||
if (pos == -1 || pos == batch->pos[i]) {
|
||||
batch->logits[i] = true;
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1; // not found
|
||||
}
|
||||
|
||||
int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) {
|
||||
if (batch->n_tokens == 0) {
|
||||
return -1;
|
||||
}
|
||||
const int32_t output_id = batch->n_tokens - 1;
|
||||
batch->logits[output_id] = true;
|
||||
return output_id;
|
||||
}
|
||||
|
||||
void llama_batch_ext_clear(struct llama_batch_ext * batch) {
|
||||
batch->n_tokens = 0;
|
||||
}
|
||||
|
||||
struct llama_batch_ext * llama_batch_ext_get_view(
|
||||
struct llama_batch_ext * batch,
|
||||
int32_t offset,
|
||||
int32_t n_tokens) {
|
||||
if (batch->embd) {
|
||||
return nullptr; // not yet supported
|
||||
}
|
||||
llama_batch_ext * batch_view = new llama_batch_ext{
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*max_tokens =*/ n_tokens,
|
||||
/*is_view =*/ true,
|
||||
/*tokens =*/ batch->token + offset,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ batch->pos + offset,
|
||||
/*n_seq_id =*/ batch->n_seq_id + offset,
|
||||
/*seq_id =*/ batch->seq_id + offset,
|
||||
/*logits =*/ batch->logits + offset,
|
||||
};
|
||||
return batch_view;
|
||||
}
|
||||
|
||||
void llama_batch_ext_free(struct llama_batch_ext * batch) {
|
||||
// do not free the members if it's a view
|
||||
if (!batch->is_view) {
|
||||
if (batch->token) free(batch->token);
|
||||
if (batch->embd) free(batch->embd);
|
||||
if (batch->pos) free(batch->pos);
|
||||
if (batch->n_seq_id) free(batch->n_seq_id);
|
||||
if (batch->seq_id) {
|
||||
for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
|
||||
free(batch->seq_id[i]);
|
||||
}
|
||||
free(batch->seq_id);
|
||||
}
|
||||
if (batch->logits) free(batch->logits);
|
||||
}
|
||||
delete batch;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||
llama_batch batch = {
|
||||
/*n_tokens =*/ 0,
|
||||
@@ -353,6 +544,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||
return batch;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_batch_free(struct llama_batch batch) {
|
||||
if (batch.token) free(batch.token);
|
||||
if (batch.embd) free(batch.embd);
|
||||
|
||||
+32
-4
@@ -5,6 +5,32 @@
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
// Input data for llama_decode / llama_encode
|
||||
// A llama_batch_ext object can contain input about one or many sequences
|
||||
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
|
||||
//
|
||||
// - token : the token ids of the input (used when embd is NULL)
|
||||
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||
// - pos : the positions of the respective token in the sequence
|
||||
// (if set to NULL, the token position will be tracked automatically by llama_decode)
|
||||
// - seq_id : the sequence to which the respective token belongs
|
||||
// (if set to NULL, the sequence ID will be assumed to be 0)
|
||||
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
||||
// (if set to NULL, only the logits for last token will be returned)
|
||||
//
|
||||
struct llama_batch_ext {
|
||||
int32_t n_tokens;
|
||||
int32_t max_tokens;
|
||||
bool is_view;
|
||||
|
||||
llama_token * token;
|
||||
float * embd;
|
||||
llama_pos * pos;
|
||||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
};
|
||||
|
||||
// very similar to llama_batch,
|
||||
// but has more metadata about sequences
|
||||
struct llama_ubatch {
|
||||
@@ -47,7 +73,7 @@ struct llama_sbatch {
|
||||
std::vector<int64_t> out_ids;
|
||||
std::vector<llama_sbatch_seq> seq;
|
||||
|
||||
const llama_batch * batch = nullptr;
|
||||
const llama_batch_ext * batch = nullptr;
|
||||
|
||||
// buffers for the ubatch
|
||||
std::vector<llama_token> ubatch_token;
|
||||
@@ -70,12 +96,12 @@ struct llama_sbatch {
|
||||
// sequence-wise split
|
||||
llama_ubatch split_seq(size_t n_ubatch);
|
||||
|
||||
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
void from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
};
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
struct llama_batch_allocr {
|
||||
struct llama_batch batch;
|
||||
struct llama_batch_ext * batch;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
std::vector<llama_pos> pos;
|
||||
@@ -84,5 +110,7 @@ struct llama_batch_allocr {
|
||||
std::vector<int8_t> logits;
|
||||
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
|
||||
llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0);
|
||||
|
||||
~llama_batch_allocr();
|
||||
};
|
||||
|
||||
+46
-26
@@ -4,6 +4,7 @@
|
||||
#include "llama-io.h"
|
||||
#include "llama-mmap.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include <cassert>
|
||||
@@ -1000,16 +1001,26 @@ bool llama_context::apply_adapter_cvec(
|
||||
}
|
||||
|
||||
int llama_context::encode(llama_batch & inp_batch) {
|
||||
// temporary allocate memory and convert llama_batch to llama_batch_ext
|
||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
||||
return encode(*batch_allocr.batch);
|
||||
}
|
||||
|
||||
int llama_context::decode(llama_batch & inp_batch) {
|
||||
// temporary allocate memory and convert llama_batch to llama_batch_ext
|
||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
||||
return decode(*batch_allocr.batch);
|
||||
}
|
||||
|
||||
int llama_context::encode(llama_batch_ext & inp_batch) {
|
||||
if (inp_batch.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
||||
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
llama_batch_ext & batch = inp_batch;
|
||||
const int32_t n_tokens = batch.n_tokens;
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
@@ -1057,6 +1068,13 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
const auto causal_attn_org = cparams.causal_attn;
|
||||
|
||||
// always use non-causal attention for encoder graphs
|
||||
// TODO: this is a tmp solution until we have a proper way to support enc-dec models
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
||||
cparams.causal_attn = false;
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
|
||||
|
||||
@@ -1064,6 +1082,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
|
||||
cparams.causal_attn = causal_attn_org;
|
||||
|
||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_SUCCESS:
|
||||
@@ -1152,17 +1172,13 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int llama_context::decode(llama_batch & inp_batch) {
|
||||
int llama_context::decode(llama_batch_ext & inp_batch) {
|
||||
if (inp_batch.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
||||
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
llama_batch_ext & batch = inp_batch;
|
||||
|
||||
const auto & vocab = model.vocab;
|
||||
const auto & hparams = model.hparams;
|
||||
@@ -2738,26 +2754,30 @@ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, lla
|
||||
|
||||
///
|
||||
|
||||
// deprecated
|
||||
int32_t llama_encode(
|
||||
llama_context * ctx,
|
||||
llama_batch batch) {
|
||||
const int ret = ctx->encode(batch);
|
||||
if (ret != 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
|
||||
return ret;
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch inp_batch) {
|
||||
return ctx->encode(inp_batch);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int32_t llama_decode(
|
||||
llama_context * ctx,
|
||||
llama_batch batch) {
|
||||
const int ret = ctx->decode(batch);
|
||||
if (ret != 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch inp_batch) {
|
||||
return ctx->decode(inp_batch);
|
||||
}
|
||||
|
||||
return ret;
|
||||
int32_t llama_encode_ext(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch_ext * inp_batch) {
|
||||
return ctx->encode(*inp_batch);
|
||||
}
|
||||
|
||||
int32_t llama_decode_ext(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch_ext * inp_batch) {
|
||||
return ctx->decode(*inp_batch);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -82,9 +82,13 @@ struct llama_context {
|
||||
int32_t il_start,
|
||||
int32_t il_end);
|
||||
|
||||
// deprecated
|
||||
int encode(llama_batch & inp_batch);
|
||||
int decode(llama_batch & inp_batch);
|
||||
|
||||
int encode(llama_batch_ext & inp_batch);
|
||||
int decode(llama_batch_ext & inp_batch);
|
||||
|
||||
//
|
||||
// state save/load
|
||||
//
|
||||
|
||||
@@ -76,6 +76,10 @@ struct llama_hparams {
|
||||
uint32_t time_decay_extra_dim = 0;
|
||||
uint32_t wkv_head_size = 0;
|
||||
uint32_t token_shift_count = 2;
|
||||
uint32_t n_lora_decay = 0;
|
||||
uint32_t n_lora_iclr = 0;
|
||||
uint32_t n_lora_value_res_mix = 0;
|
||||
uint32_t n_lora_gate = 0;
|
||||
|
||||
float rope_attn_factor = 1.0f;
|
||||
float rope_freq_base_train;
|
||||
|
||||
+565
-16
@@ -32,6 +32,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_109M: return "109M";
|
||||
case LLM_TYPE_137M: return "137M";
|
||||
case LLM_TYPE_160M: return "160M";
|
||||
case LLM_TYPE_190M: return "190M";
|
||||
case LLM_TYPE_220M: return "220M";
|
||||
case LLM_TYPE_250M: return "250M";
|
||||
case LLM_TYPE_270M: return "270M";
|
||||
@@ -48,6 +49,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_1_6B: return "1.6B";
|
||||
case LLM_TYPE_2B: return "2B";
|
||||
case LLM_TYPE_2_8B: return "2.8B";
|
||||
case LLM_TYPE_2_9B: return "2.9B";
|
||||
case LLM_TYPE_3B: return "3B";
|
||||
case LLM_TYPE_4B: return "4B";
|
||||
case LLM_TYPE_6B: return "6B";
|
||||
@@ -1250,6 +1252,36 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false);
|
||||
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
|
||||
ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay);
|
||||
ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix);
|
||||
ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false);
|
||||
ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 12: type = LLM_TYPE_190M; break;
|
||||
case 24:
|
||||
switch (hparams.n_embd) {
|
||||
case 1024: type = LLM_TYPE_450M; break;
|
||||
case 2048: type = LLM_TYPE_1_5B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
case 28:
|
||||
switch (hparams.n_embd) {
|
||||
case 1536: type = LLM_TYPE_1_5B; break;
|
||||
case 3584: type = LLM_TYPE_7B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
{
|
||||
@@ -3366,6 +3398,146 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_RWKV7:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// Block 0, LN0
|
||||
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
||||
tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
const int n_lora_decay = hparams.n_lora_decay;
|
||||
const int n_lora_iclr = hparams.n_lora_iclr;
|
||||
const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
|
||||
const int n_lora_gate = hparams.n_lora_gate;
|
||||
const int attn_hidden_size = n_embd;
|
||||
const int ffn_size = hparams.n_ff_arr[0];
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
|
||||
|
||||
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
|
||||
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0);
|
||||
|
||||
layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
|
||||
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
|
||||
|
||||
layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
|
||||
if (i == 0) {
|
||||
// actually not used
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
} else {
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
|
||||
}
|
||||
|
||||
layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0);
|
||||
layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0);
|
||||
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
|
||||
|
||||
layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
|
||||
|
||||
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
|
||||
layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
|
||||
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
|
||||
|
||||
layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
|
||||
|
||||
layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
|
||||
layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
|
||||
}
|
||||
|
||||
} break;
|
||||
case LLM_ARCH_ARWKV7:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
const int n_lora_decay = hparams.n_lora_decay;
|
||||
const int n_lora_iclr = hparams.n_lora_iclr;
|
||||
const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
|
||||
const int n_lora_gate = hparams.n_lora_gate;
|
||||
const int attn_hidden_size = n_embd;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
|
||||
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
|
||||
|
||||
layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
|
||||
if (i == 0) {
|
||||
// actually not used
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
} else {
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
|
||||
}
|
||||
|
||||
layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
try {
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
|
||||
} catch(std::runtime_error & e) {
|
||||
// ARWKV models may not have gate tensors
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
|
||||
}
|
||||
|
||||
layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
|
||||
|
||||
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
|
||||
layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
|
||||
} break;
|
||||
case LLM_ARCH_CHAMELEON:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -10212,6 +10384,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto head_size = hparams.wkv_head_size;
|
||||
const auto n_head = n_embd / head_size;
|
||||
@@ -10224,6 +10397,10 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
bool is_qrwkv = layer.time_mix_first == nullptr;
|
||||
|
||||
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
|
||||
|
||||
sx = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens);
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
|
||||
ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur);
|
||||
|
||||
xxx = ggml_reshape_4d(
|
||||
@@ -10366,7 +10543,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
cur = ggml_mul(ctx0, cur, g);
|
||||
cur = build_lora_mm(layer.time_mix_output, cur);
|
||||
|
||||
return cur;
|
||||
return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -10389,6 +10566,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, state_mask, ubatch, il
|
||||
@@ -10422,9 +10600,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
token_shift = ggml_concat(ctx0,
|
||||
ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
|
||||
ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
|
||||
@@ -10432,6 +10607,18 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||
);
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
|
||||
ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
|
||||
x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
|
||||
cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
|
||||
}
|
||||
|
||||
cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
|
||||
cur = ggml_scale(ctx0, cur, 0.5F);
|
||||
}
|
||||
@@ -10444,12 +10631,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
||||
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
@@ -10481,10 +10662,9 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, state_mask, ubatch, il
|
||||
@@ -10508,6 +10688,13 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
|
||||
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
|
||||
}
|
||||
|
||||
// feed-forward network
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
@@ -10532,10 +10719,358 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
const llama_model & model;
|
||||
|
||||
llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
|
||||
}
|
||||
|
||||
ggml_tensor * build_rwkv7_channel_mix(
|
||||
const llama_layer * layer,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * x_prev,
|
||||
llm_arch arch) const {
|
||||
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
|
||||
switch (arch) {
|
||||
case LLM_ARCH_RWKV7:
|
||||
{
|
||||
ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
|
||||
|
||||
ggml_tensor * k = ggml_sqr(
|
||||
ctx0,
|
||||
ggml_relu(
|
||||
ctx0,
|
||||
build_lora_mm(layer->channel_mix_key, xk)
|
||||
)
|
||||
);
|
||||
|
||||
cur = build_lora_mm(layer->channel_mix_value, k);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * build_rwkv7_time_mix(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * x_prev,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
ggml_tensor *& first_layer_value,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto head_size = hparams.wkv_head_size;
|
||||
const auto head_count = n_embd / head_size;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
const auto kv_head = kv_self->head;
|
||||
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;
|
||||
|
||||
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
|
||||
ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
|
||||
sx = ggml_repeat(ctx0, sx, dummy);
|
||||
|
||||
ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);
|
||||
|
||||
ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
|
||||
ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
|
||||
ggml_tensor * xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
|
||||
ggml_tensor * xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
|
||||
ggml_tensor * xa = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
|
||||
ggml_tensor * xg = has_gating ? ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) : nullptr;
|
||||
|
||||
ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
|
||||
ggml_tensor * w = ggml_add(
|
||||
ctx0,
|
||||
ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
|
||||
layer.time_mix_w0
|
||||
);
|
||||
w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));
|
||||
|
||||
ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
|
||||
ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
|
||||
if (first_layer_value == nullptr) {
|
||||
first_layer_value = v;
|
||||
} else {
|
||||
// Add the first layer value as a residual connection.
|
||||
v = ggml_add(ctx0, v,
|
||||
ggml_mul(ctx0,
|
||||
ggml_sub(ctx0, first_layer_value, v),
|
||||
ggml_sigmoid(ctx0, ggml_add(ctx0,
|
||||
ggml_mul_mat(ctx0, layer.time_mix_v2, ggml_mul_mat(ctx0, layer.time_mix_v1, xv)),
|
||||
layer.time_mix_v0
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
ggml_tensor * g = nullptr;
|
||||
if (layer.time_mix_g1 && layer.time_mix_g2) {
|
||||
g = ggml_mul_mat(ctx0, layer.time_mix_g2, ggml_sigmoid(ctx0, ggml_mul_mat(ctx0, layer.time_mix_g1, xg)));
|
||||
}
|
||||
|
||||
ggml_tensor * a = ggml_sigmoid(ctx0,
|
||||
ggml_add(
|
||||
ctx0,
|
||||
ggml_mul_mat(ctx0, layer.time_mix_a2, ggml_mul_mat(ctx0, layer.time_mix_a1, xa)),
|
||||
layer.time_mix_a0
|
||||
)
|
||||
);
|
||||
|
||||
ggml_tensor * kk = ggml_reshape_3d(ctx0, ggml_mul(ctx0, k, layer.time_mix_k_k), head_size, head_count, n_tokens);
|
||||
kk = ggml_l2_norm(ctx0, kk, 1e-12);
|
||||
|
||||
ggml_tensor * ka = ggml_mul(ctx0, k, layer.time_mix_k_a);
|
||||
k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));
|
||||
|
||||
r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
|
||||
w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
|
||||
k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
|
||||
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
|
||||
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
||||
|
||||
ggml_tensor * wkv_state = build_copy_mask_state(
|
||||
gf, kv_self->v_l[il], state_copy, state_mask,
|
||||
hparams.n_embd_v_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
||||
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
|
||||
wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
|
||||
|
||||
ggml_build_forward_expand(
|
||||
gf,
|
||||
ggml_cpy(
|
||||
ctx0,
|
||||
wkv_state,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_self->v_l[il],
|
||||
hparams.n_embd_v_s() * n_seqs,
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
if (layer.time_mix_ln && layer.time_mix_ln_b) {
|
||||
// group norm with head_count groups
|
||||
cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens);
|
||||
cur = ggml_norm(ctx0, cur, 64e-5f);
|
||||
|
||||
// Convert back to regular vectors.
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
|
||||
} else {
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
|
||||
}
|
||||
|
||||
ggml_tensor * rk = ggml_sum_rows(ctx0,
|
||||
ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
|
||||
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));
|
||||
|
||||
if (has_gating) {
|
||||
cur = ggml_mul(ctx0, cur, g);
|
||||
}
|
||||
cur = build_lora_mm(layer.time_mix_output, cur);
|
||||
|
||||
return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
||||
llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
|
||||
GGML_ASSERT(hparams.token_shift_count == 2);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
ggml_tensor * v_first = nullptr;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
ggml_tensor * state_mask = build_inp_s_mask();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, state_mask, ubatch, il
|
||||
);
|
||||
|
||||
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
||||
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
||||
|
||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
|
||||
cb(att_norm, "attn_norm", il);
|
||||
|
||||
ggml_tensor * x_prev = ggml_concat(
|
||||
ctx0,
|
||||
att_shift,
|
||||
ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
|
||||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
|
||||
cb(ffn_norm, "ffn_norm", il);
|
||||
|
||||
x_prev = ggml_concat(
|
||||
ctx0,
|
||||
ffn_shift,
|
||||
ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0),
|
||||
1
|
||||
);
|
||||
|
||||
token_shift = ggml_concat(ctx0,
|
||||
ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
|
||||
ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
|
||||
1
|
||||
);
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
|
||||
ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
|
||||
x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
|
||||
}
|
||||
|
||||
cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
||||
llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
|
||||
GGML_ASSERT(n_embd == hparams.n_embd_k_s());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
ggml_tensor * v_first = nullptr;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
ggml_tensor * state_mask = build_inp_s_mask();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, state_mask, ubatch, il
|
||||
);
|
||||
|
||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
||||
cb(att_norm, "attn_norm", il);
|
||||
|
||||
ggml_tensor * x_prev = ggml_concat(
|
||||
ctx0,
|
||||
token_shift,
|
||||
ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
|
||||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
|
||||
|
||||
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
|
||||
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
|
||||
}
|
||||
|
||||
// feed-forward network
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
@@ -10883,9 +11418,11 @@ llama_memory_i * llama_model::create_memory() const {
|
||||
llama_memory_i * res;
|
||||
|
||||
switch (arch) {
|
||||
case LLM_ARCH_MAMBA:
|
||||
case LLM_ARCH_RWKV6:
|
||||
case LLM_ARCH_RWKV6QWEN2:
|
||||
case LLM_ARCH_MAMBA:
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
{
|
||||
res = new llama_kv_cache_unified(hparams, {
|
||||
/*.get_rope_factors =*/ nullptr
|
||||
@@ -11132,6 +11669,14 @@ llm_graph_result_ptr llama_model::build_graph(
|
||||
{
|
||||
llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_RWKV7:
|
||||
{
|
||||
llm = std::make_unique<llm_build_rwkv7>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_ARWKV7:
|
||||
{
|
||||
llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_CHAMELEON:
|
||||
{
|
||||
llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
|
||||
@@ -11245,6 +11790,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_JAIS:
|
||||
case LLM_ARCH_RWKV6:
|
||||
case LLM_ARCH_RWKV6QWEN2:
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
return LLAMA_ROPE_TYPE_NONE;
|
||||
|
||||
@@ -11399,6 +11946,8 @@ bool llama_model_is_recurrent(const llama_model * model) {
|
||||
case LLM_ARCH_MAMBA: return true;
|
||||
case LLM_ARCH_RWKV6: return true;
|
||||
case LLM_ARCH_RWKV6QWEN2: return true;
|
||||
case LLM_ARCH_RWKV7: return true;
|
||||
case LLM_ARCH_ARWKV7: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ enum llm_type {
|
||||
LLM_TYPE_109M,
|
||||
LLM_TYPE_137M,
|
||||
LLM_TYPE_160M,
|
||||
LLM_TYPE_190M,
|
||||
LLM_TYPE_220M,
|
||||
LLM_TYPE_250M,
|
||||
LLM_TYPE_270M,
|
||||
@@ -45,6 +46,7 @@ enum llm_type {
|
||||
LLM_TYPE_1_6B,
|
||||
LLM_TYPE_2B,
|
||||
LLM_TYPE_2_8B,
|
||||
LLM_TYPE_2_9B,
|
||||
LLM_TYPE_3B,
|
||||
LLM_TYPE_4B,
|
||||
LLM_TYPE_6B,
|
||||
@@ -260,6 +262,20 @@ struct llama_layer {
|
||||
struct ggml_tensor * time_mix_receptance_b = nullptr;
|
||||
struct ggml_tensor * time_mix_gate = nullptr;
|
||||
|
||||
// rwkv7
|
||||
struct ggml_tensor * time_mix_w0 = nullptr;
|
||||
struct ggml_tensor * time_mix_a0 = nullptr;
|
||||
struct ggml_tensor * time_mix_a1 = nullptr;
|
||||
struct ggml_tensor * time_mix_a2 = nullptr;
|
||||
struct ggml_tensor * time_mix_v0 = nullptr;
|
||||
struct ggml_tensor * time_mix_v1 = nullptr;
|
||||
struct ggml_tensor * time_mix_v2 = nullptr;
|
||||
struct ggml_tensor * time_mix_g1 = nullptr;
|
||||
struct ggml_tensor * time_mix_g2 = nullptr;
|
||||
struct ggml_tensor * time_mix_k_k = nullptr;
|
||||
struct ggml_tensor * time_mix_k_a = nullptr;
|
||||
struct ggml_tensor * time_mix_r_k = nullptr;
|
||||
|
||||
struct ggml_tensor * time_mix_ln = nullptr;
|
||||
struct ggml_tensor * time_mix_ln_b = nullptr;
|
||||
struct ggml_tensor * time_mix_output = nullptr;
|
||||
|
||||
+10
-1
@@ -756,10 +756,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// NOTE: can't use LLM_TN here because the layer number is not known
|
||||
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
|
||||
|
||||
// do not quantize RWKV's time_mix_first tensors
|
||||
// do not quantize RWKV's small yet 2D weights
|
||||
quantize &= name.find("time_mix_first.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_w0.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_w1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_w2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_v0.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_v1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_v2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_a0.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_a1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_a2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_g1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_g2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
|
||||
|
||||
@@ -1916,6 +1916,40 @@ struct test_gla : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_RWKV_WKV7
|
||||
struct test_rwkv_wkv7 : public test_case {
|
||||
const ggml_type type;
|
||||
|
||||
const int64_t head_count;
|
||||
const int64_t head_size;
|
||||
const int64_t n_seq_tokens;
|
||||
const int64_t n_seqs;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
|
||||
}
|
||||
|
||||
test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
|
||||
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const int64_t n_tokens = n_seq_tokens * n_seqs;
|
||||
ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
// Outputs may become NaN with long seqlen without these normalization
|
||||
a = ggml_l2_norm(ctx, a, 1e-7F);
|
||||
b = ggml_l2_norm(ctx, b, 1e-7F);
|
||||
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
|
||||
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_MUL_MAT
|
||||
struct test_mul_mat : public test_case {
|
||||
const ggml_type type_a;
|
||||
@@ -2972,6 +3006,32 @@ struct test_group_norm : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_L2_NORM
|
||||
struct test_l2_norm : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float eps;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
}
|
||||
|
||||
test_l2_norm(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 64, 320, 1},
|
||||
float eps = 1e-12f)
|
||||
: type(type), ne(ne), eps(eps) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_ACC
|
||||
struct test_acc : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -4006,8 +4066,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
|
||||
}
|
||||
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
||||
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
|
||||
@@ -4019,6 +4082,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
|
||||
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
|
||||
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
|
||||
Reference in New Issue
Block a user