mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
model : add support for talkie-1930-13b (#22596)
* initial talkie support, coherent * reorder to follow convention * absorb inverse rope * stop folding scalars to improve quantization * use broadcasting instead of duplication * style cleanup * add scaling support to LoraTorchTensor; use that path in conversion * use layer_out_scale instead of embd_skip_scale
This commit is contained in:
@@ -215,6 +215,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"T5EncoderModel": "t5",
|
||||
"T5ForConditionalGeneration": "t5",
|
||||
"T5WithLMHeadModel": "t5",
|
||||
"TalkieForCausalLM": "talkie",
|
||||
"UMT5ForConditionalGeneration": "t5",
|
||||
"UMT5Model": "t5",
|
||||
"UltravoxModel": "ultravox",
|
||||
|
||||
@@ -1622,6 +1622,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "62f6fb0a6fd5098caeabb19b07a5c1099cafc8b9c40eab6ea89ece4ec02fbc57":
|
||||
# ref: https://huggingface.co/sarvamai/sarvam-30b
|
||||
res = "sarvam-moe"
|
||||
if chkhsh == "f728162c1315c26e40249849799b4ba3fe584c32084b4795b03eb295e63cb5af":
|
||||
# ref: https://huggingface.co/lewtun/talkie-1930-13b-it-hf
|
||||
res = "talkie"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import Tensor
|
||||
|
||||
from .base import LazyTorchTensor, ModelBase, TextModel, gguf
|
||||
|
||||
|
||||
@ModelBase.register("TalkieForCausalLM")
|
||||
class TalkieModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.TALKIE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
# Talkie used F.rms_norm without an explicit eps
|
||||
self.gguf_writer.add_layer_norm_rms_eps(torch.finfo(torch.float32).eps)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
prefix = f"model.blocks.{bid}." if bid is not None else ""
|
||||
suffix = name.removeprefix(prefix)
|
||||
|
||||
if suffix == "attn_gain.a_g":
|
||||
yield self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, ".scale"), data_torch
|
||||
return
|
||||
elif suffix == "mlp_gain.a_g":
|
||||
yield self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, ".scale"), data_torch
|
||||
return
|
||||
elif suffix == "lm_head_gain.w_g":
|
||||
self.gguf_writer.add_logit_scale(LazyTorchTensor.to_eager(data_torch).item())
|
||||
return
|
||||
elif suffix in ("attn.attn_query.weight", "attn.attn_key.weight"):
|
||||
# absorb inverse rope
|
||||
head_dim = self.hparams["head_dim"]
|
||||
shape = data_torch.shape
|
||||
data_torch = torch.reshape(data_torch, (-1, head_dim, shape[-1]))
|
||||
signs = torch.ones((1, head_dim, 1), dtype=data_torch.dtype)
|
||||
signs[:, head_dim // 2 :, :] = -1
|
||||
if self.lazy:
|
||||
signs = LazyTorchTensor.from_eager(signs)
|
||||
# (n_head, head_dim, n_in) -> (n_out, n_in)
|
||||
data_torch = torch.reshape(data_torch * signs, shape)
|
||||
elif suffix == "attn.head_gain.head_g":
|
||||
# allow head gain to broadcast
|
||||
data_torch = data_torch.unsqueeze(-1)
|
||||
|
||||
if not name.endswith(".weight"):
|
||||
name += ".weight"
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
@@ -156,6 +156,7 @@ models = [
|
||||
{"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", },
|
||||
{"name": "f2llmv2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/codefuse-ai/F2LLM-v2-4B", },
|
||||
{"name": "sarvam-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sarvamai/sarvam-30b", },
|
||||
{"name": "talkie", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/lewtun/talkie-1930-13b-it-hf", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
|
||||
@@ -208,6 +208,16 @@ class LoraTorchTensor:
|
||||
def to(self, *args, **kwargs):
|
||||
return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
|
||||
|
||||
def __mul__(self, other) -> LoraTorchTensor:
|
||||
# Only output-side multiplication for now
|
||||
# W = B @ A, so M_out * W == (M_out * B) @ A
|
||||
if not isinstance(other, (int, float)) and other.shape and other.shape[-1] != 1:
|
||||
raise NotImplementedError
|
||||
return LoraTorchTensor(self._lora_A, self._lora_B * other)
|
||||
|
||||
def __rmul__(self, other) -> LoraTorchTensor:
|
||||
return self * other
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
|
||||
del types # unused
|
||||
|
||||
@@ -505,6 +505,7 @@ class MODEL_ARCH(IntEnum):
|
||||
LLAMA_EMBED = auto()
|
||||
MAINCODER = auto()
|
||||
KIMI_LINEAR = auto()
|
||||
TALKIE = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
@@ -1021,6 +1022,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
|
||||
MODEL_ARCH.MAINCODER: "maincoder",
|
||||
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
|
||||
MODEL_ARCH.TALKIE: "talkie",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
@@ -4013,6 +4015,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.TALKIE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ class TensorNameMap:
|
||||
"encoder", # neobert
|
||||
"model.transformer.wte", # llada
|
||||
"embed_tokens", # qwen3-embedding
|
||||
"model.embed", # talkie
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@@ -259,6 +260,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.q_proj", # llada
|
||||
"layers.{bid}.self_attn.q_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.q_proj", # nemotron-h
|
||||
"model.blocks.{bid}.attn.attn_query", # talkie
|
||||
),
|
||||
|
||||
# Attention key
|
||||
@@ -279,6 +281,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.k_proj", # llada
|
||||
"layers.{bid}.self_attn.k_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.k_proj", # nemotron-h
|
||||
"model.blocks.{bid}.attn.attn_key", # talkie
|
||||
),
|
||||
|
||||
# Attention value
|
||||
@@ -298,6 +301,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.v_proj", # llada
|
||||
"layers.{bid}.self_attn.v_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.v_proj", # nemotron-h
|
||||
"model.blocks.{bid}.attn.attn_value", # talkie
|
||||
),
|
||||
|
||||
# Attention output
|
||||
@@ -336,6 +340,7 @@ class TensorNameMap:
|
||||
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.o_proj", # nemotron-h
|
||||
"model.layers.{bid}.self_attn.language_expert_dense", # cogvlm
|
||||
"model.blocks.{bid}.attn.attn_resid", # talkie
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
@@ -508,6 +513,7 @@ class TensorNameMap:
|
||||
"layers.{bid}.mlp.up_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.up_proj", # nemotron-h
|
||||
"model.layers.{bid}.mlp.language_mlp.up_proj", # cogvlm
|
||||
"model.blocks.{bid}.mlp.mlp_linear", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@@ -561,6 +567,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.ff_proj", # llada
|
||||
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
|
||||
"model.layers.{bid}.mlp.language_mlp.gate_proj", # cogvlm
|
||||
"model.blocks.{bid}.mlp.mlp_gate", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
@@ -636,6 +643,7 @@ class TensorNameMap:
|
||||
"layers.{bid}.mlp.down_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.down_proj", # nemotron-h
|
||||
"model.layers.{bid}.mlp.language_mlp.down_proj", # cogvlm
|
||||
"model.blocks.{bid}.mlp.mlp_resid", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
@@ -682,6 +690,7 @@ class TensorNameMap:
|
||||
"model.layers.layers.{bid}.mixer.q_norm", # plamo3
|
||||
"layers.{bid}.self_attn.q_norm", # qwen3-embedding
|
||||
"model.layers.{bid}.attention.query_layernorm", # apertus
|
||||
"model.blocks.{bid}.attn.head_gain.head_g", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_K_NORM: (
|
||||
@@ -716,6 +725,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.LAYER_OUT_SCALE: (
|
||||
"model.layers.{bid}.layer_scalar", # gemma4
|
||||
"model.blocks.{bid}.embed_skip.a_g", # talkie
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
|
||||
|
||||
@@ -133,6 +133,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
|
||||
{ LLM_ARCH_MAINCODER, "maincoder" },
|
||||
{ LLM_ARCH_KIMI_LINEAR, "kimi-linear" },
|
||||
{ LLM_ARCH_TALKIE, "talkie" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ enum llm_arch {
|
||||
LLM_ARCH_LLAMA_EMBED,
|
||||
LLM_ARCH_MAINCODER,
|
||||
LLM_ARCH_KIMI_LINEAR,
|
||||
LLM_ARCH_TALKIE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
|
||||
return new llama_model_llama_embed(params);
|
||||
case LLM_ARCH_MAINCODER:
|
||||
return new llama_model_maincoder(params);
|
||||
case LLM_ARCH_TALKIE:
|
||||
return new llama_model_talkie(params);
|
||||
case LLM_ARCH_DECI:
|
||||
return new llama_model_deci(params);
|
||||
case LLM_ARCH_BAICHUAN:
|
||||
@@ -2353,6 +2355,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
case LLM_ARCH_MIMO2:
|
||||
case LLM_ARCH_STEP35:
|
||||
case LLM_ARCH_TALKIE:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
||||
+1
-1
@@ -488,7 +488,7 @@ struct llama_layer {
|
||||
struct ggml_tensor * indexer_attn_k = nullptr;
|
||||
struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias
|
||||
|
||||
// gemma4 layer output scale
|
||||
// gemma4 layer output scale, reused for talkie embedding skip scale
|
||||
struct ggml_tensor * out_scale = nullptr;
|
||||
|
||||
struct llama_layer_posnet posnet;
|
||||
|
||||
+2
-1
@@ -2196,7 +2196,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
} else if (
|
||||
tokenizer_pre == "gpt-4o" ||
|
||||
tokenizer_pre == "llama4" ||
|
||||
tokenizer_pre == "kanana2") {
|
||||
tokenizer_pre == "kanana2" ||
|
||||
tokenizer_pre == "talkie") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
|
||||
@@ -186,6 +186,19 @@ struct llama_model_maincoder : public llama_model_base {
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_talkie : public llama_model_base {
|
||||
llama_model_talkie(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
void load_arch_tensors(llama_model_loader & ml) override;
|
||||
|
||||
struct graph : public llm_graph_context {
|
||||
graph(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
|
||||
};
|
||||
|
||||
|
||||
struct llama_model_deci : public llama_model_base {
|
||||
llama_model_deci(const struct llama_model_params & params) : llama_model_base(params) {}
|
||||
void load_arch_hparams(llama_model_loader & ml) override;
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
#include "models.h"
|
||||
|
||||
void llama_model_talkie::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 40: type = LLM_TYPE_13B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_talkie::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
// no k gain
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {1, n_head}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
|
||||
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_talkie::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_talkie::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
inpL = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(inpL, "inp_norm", -1);
|
||||
|
||||
ggml_tensor * embd_skip = inpL;
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
ggml_tensor * inp_skip = embd_skip;
|
||||
|
||||
cur = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur,
|
||||
n_embd_head, n_head, n_head_kv, il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
// reference applies qknorm after rope
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_norm", il);
|
||||
|
||||
Kcur = build_norm(Kcur, nullptr, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_norm", il);
|
||||
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, nullptr, model.layers[il].wo_s,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
inp_skip = ggml_get_rows(ctx0, inp_skip, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = build_norm(ffn_inp, nullptr, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, nullptr, nullptr,
|
||||
model.layers[il].ffn_gate, nullptr, nullptr,
|
||||
model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s,
|
||||
nullptr,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
ggml_tensor * skip = ggml_mul(ctx0, inp_skip, model.layers[il].out_scale);
|
||||
cb(skip, "embd_skip", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, skip);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur, nullptr, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
res->t_embd = cur;
|
||||
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
Reference in New Issue
Block a user