mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 18:17:42 +02:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 02e409a5be | |||
| 6b82eb7883 | |||
| 86a3f0fad8 | |||
| 63908b631a | |||
| 42b12b5608 | |||
| 4e842d5120 | |||
| ca709e427b | |||
| 0cdce38a97 | |||
| e39502e74b |
+583
-30
@@ -1,6 +1,11 @@
|
||||
#include "console.h"
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <cctype>
|
||||
#include <cwctype>
|
||||
#include <cstdint>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
@@ -35,9 +40,26 @@
|
||||
|
||||
namespace console {
|
||||
|
||||
#if defined (_WIN32)
|
||||
namespace {
|
||||
// Use private-use unicode values to represent special keys that are not reported
|
||||
// as characters (e.g. arrows on Windows). These values should never clash with
|
||||
// real input and let the rest of the code handle navigation uniformly.
|
||||
static constexpr char32_t KEY_ARROW_LEFT = 0xE000;
|
||||
static constexpr char32_t KEY_ARROW_RIGHT = 0xE001;
|
||||
static constexpr char32_t KEY_ARROW_UP = 0xE002;
|
||||
static constexpr char32_t KEY_ARROW_DOWN = 0xE003;
|
||||
static constexpr char32_t KEY_HOME = 0xE004;
|
||||
static constexpr char32_t KEY_END = 0xE005;
|
||||
static constexpr char32_t KEY_CTRL_ARROW_LEFT = 0xE006;
|
||||
static constexpr char32_t KEY_CTRL_ARROW_RIGHT = 0xE007;
|
||||
static constexpr char32_t KEY_DELETE = 0xE008;
|
||||
}
|
||||
|
||||
//
|
||||
// Console state
|
||||
//
|
||||
#endif
|
||||
|
||||
static bool advanced_display = false;
|
||||
static bool simple_io = true;
|
||||
@@ -176,7 +198,18 @@ namespace console {
|
||||
if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) {
|
||||
wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar;
|
||||
if (wc == 0) {
|
||||
continue;
|
||||
const DWORD ctrl_mask = LEFT_CTRL_PRESSED | RIGHT_CTRL_PRESSED;
|
||||
const bool ctrl_pressed = (record.Event.KeyEvent.dwControlKeyState & ctrl_mask) != 0;
|
||||
switch (record.Event.KeyEvent.wVirtualKeyCode) {
|
||||
case VK_LEFT: return ctrl_pressed ? KEY_CTRL_ARROW_LEFT : KEY_ARROW_LEFT;
|
||||
case VK_RIGHT: return ctrl_pressed ? KEY_CTRL_ARROW_RIGHT : KEY_ARROW_RIGHT;
|
||||
case VK_UP: return KEY_ARROW_UP;
|
||||
case VK_DOWN: return KEY_ARROW_DOWN;
|
||||
case VK_HOME: return KEY_HOME;
|
||||
case VK_END: return KEY_END;
|
||||
case VK_DELETE: return KEY_DELETE;
|
||||
default: continue;
|
||||
}
|
||||
}
|
||||
|
||||
if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
|
||||
@@ -315,6 +348,52 @@ namespace console {
|
||||
#endif
|
||||
}
|
||||
|
||||
static char32_t decode_utf8(const std::string & input, size_t pos, size_t & advance) {
|
||||
unsigned char c = static_cast<unsigned char>(input[pos]);
|
||||
if ((c & 0x80u) == 0u) {
|
||||
advance = 1;
|
||||
return c;
|
||||
}
|
||||
if ((c & 0xE0u) == 0xC0u && pos + 1 < input.size()) {
|
||||
unsigned char c1 = static_cast<unsigned char>(input[pos + 1]);
|
||||
if ((c1 & 0xC0u) != 0x80u) {
|
||||
advance = 1;
|
||||
return 0xFFFD;
|
||||
}
|
||||
advance = 2;
|
||||
return ((c & 0x1Fu) << 6) | (static_cast<unsigned char>(input[pos + 1]) & 0x3Fu);
|
||||
}
|
||||
if ((c & 0xF0u) == 0xE0u && pos + 2 < input.size()) {
|
||||
unsigned char c1 = static_cast<unsigned char>(input[pos + 1]);
|
||||
unsigned char c2 = static_cast<unsigned char>(input[pos + 2]);
|
||||
if ((c1 & 0xC0u) != 0x80u || (c2 & 0xC0u) != 0x80u) {
|
||||
advance = 1;
|
||||
return 0xFFFD;
|
||||
}
|
||||
advance = 3;
|
||||
return ((c & 0x0Fu) << 12) |
|
||||
((static_cast<unsigned char>(input[pos + 1]) & 0x3Fu) << 6) |
|
||||
(static_cast<unsigned char>(input[pos + 2]) & 0x3Fu);
|
||||
}
|
||||
if ((c & 0xF8u) == 0xF0u && pos + 3 < input.size()) {
|
||||
unsigned char c1 = static_cast<unsigned char>(input[pos + 1]);
|
||||
unsigned char c2 = static_cast<unsigned char>(input[pos + 2]);
|
||||
unsigned char c3 = static_cast<unsigned char>(input[pos + 3]);
|
||||
if ((c1 & 0xC0u) != 0x80u || (c2 & 0xC0u) != 0x80u || (c3 & 0xC0u) != 0x80u) {
|
||||
advance = 1;
|
||||
return 0xFFFD;
|
||||
}
|
||||
advance = 4;
|
||||
return ((c & 0x07u) << 18) |
|
||||
((static_cast<unsigned char>(input[pos + 1]) & 0x3Fu) << 12) |
|
||||
((static_cast<unsigned char>(input[pos + 2]) & 0x3Fu) << 6) |
|
||||
(static_cast<unsigned char>(input[pos + 3]) & 0x3Fu);
|
||||
}
|
||||
|
||||
advance = 1;
|
||||
return 0xFFFD; // replacement character for invalid input
|
||||
}
|
||||
|
||||
static void append_utf8(char32_t ch, std::string & out) {
|
||||
if (ch <= 0x7F) {
|
||||
out.push_back(static_cast<unsigned char>(ch));
|
||||
@@ -336,22 +415,319 @@ namespace console {
|
||||
}
|
||||
|
||||
// Helper function to remove the last UTF-8 character from a string
|
||||
static void pop_back_utf8_char(std::string & line) {
|
||||
if (line.empty()) {
|
||||
static size_t prev_utf8_char_pos(const std::string & line, size_t pos) {
|
||||
if (pos == 0) return 0;
|
||||
pos--;
|
||||
while (pos > 0 && (line[pos] & 0xC0) == 0x80) {
|
||||
pos--;
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
static size_t next_utf8_char_pos(const std::string & line, size_t pos) {
|
||||
if (pos >= line.length()) return line.length();
|
||||
pos++;
|
||||
while (pos < line.length() && (line[pos] & 0xC0) == 0x80) {
|
||||
pos++;
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
static void move_cursor(int delta);
|
||||
static void move_word_left(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line);
|
||||
static void move_word_right(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line);
|
||||
static void move_to_line_start(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths);
|
||||
static void move_to_line_end(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line);
|
||||
|
||||
static void delete_at_cursor(std::string & line, std::vector<int> & widths, size_t & char_pos, size_t & byte_pos) {
|
||||
if (char_pos >= widths.size()) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t pos = line.length() - 1;
|
||||
size_t next_pos = next_utf8_char_pos(line, byte_pos);
|
||||
int w = widths[char_pos];
|
||||
size_t char_len = next_pos - byte_pos;
|
||||
|
||||
// Find the start of the last UTF-8 character (checking up to 4 bytes back)
|
||||
for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) {
|
||||
if ((line[pos] & 0xC0) != 0x80) {
|
||||
break; // Found the start of the character
|
||||
}
|
||||
line.erase(byte_pos, char_len);
|
||||
widths.erase(widths.begin() + char_pos);
|
||||
|
||||
size_t p = byte_pos;
|
||||
int tail_width = 0;
|
||||
for (size_t i = char_pos; i < widths.size(); ++i) {
|
||||
size_t following = next_utf8_char_pos(line, p);
|
||||
put_codepoint(line.c_str() + p, following - p, widths[i]);
|
||||
tail_width += widths[i];
|
||||
p = following;
|
||||
}
|
||||
line.erase(pos);
|
||||
|
||||
for (int i = 0; i < w; ++i) {
|
||||
fputc(' ', out);
|
||||
}
|
||||
|
||||
move_cursor(-(tail_width + w));
|
||||
}
|
||||
|
||||
static void clear_current_line(const std::vector<int> & widths) {
|
||||
int total_width = 0;
|
||||
for (int w : widths) {
|
||||
total_width += (w > 0 ? w : 1);
|
||||
}
|
||||
|
||||
if (total_width > 0) {
|
||||
std::string spaces(total_width, ' ');
|
||||
fwrite(spaces.c_str(), 1, total_width, out);
|
||||
move_cursor(-total_width);
|
||||
}
|
||||
}
|
||||
|
||||
static void set_line_contents(std::string new_line, std::string & line, std::vector<int> & widths, size_t & char_pos,
|
||||
size_t & byte_pos) {
|
||||
move_to_line_start(char_pos, byte_pos, widths);
|
||||
clear_current_line(widths);
|
||||
|
||||
line = std::move(new_line);
|
||||
widths.clear();
|
||||
byte_pos = 0;
|
||||
char_pos = 0;
|
||||
|
||||
size_t idx = 0;
|
||||
while (idx < line.size()) {
|
||||
size_t advance = 0;
|
||||
char32_t cp = decode_utf8(line, idx, advance);
|
||||
int expected_width = estimateWidth(cp);
|
||||
int real_width = put_codepoint(line.c_str() + idx, advance, expected_width);
|
||||
if (real_width < 0) real_width = 0;
|
||||
widths.push_back(real_width);
|
||||
idx += advance;
|
||||
++char_pos;
|
||||
byte_pos = idx;
|
||||
}
|
||||
}
|
||||
|
||||
static void move_to_line_start(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths) {
|
||||
int back_width = 0;
|
||||
for (size_t i = 0; i < char_pos; ++i) {
|
||||
back_width += widths[i];
|
||||
}
|
||||
move_cursor(-back_width);
|
||||
char_pos = 0;
|
||||
byte_pos = 0;
|
||||
}
|
||||
|
||||
static void move_to_line_end(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line) {
|
||||
int forward_width = 0;
|
||||
for (size_t i = char_pos; i < widths.size(); ++i) {
|
||||
forward_width += widths[i];
|
||||
}
|
||||
move_cursor(forward_width);
|
||||
char_pos = widths.size();
|
||||
byte_pos = line.length();
|
||||
}
|
||||
|
||||
static bool has_ctrl_modifier(const std::string & params) {
|
||||
size_t start = 0;
|
||||
while (start < params.size()) {
|
||||
size_t end = params.find(';', start);
|
||||
size_t len = (end == std::string::npos) ? params.size() - start : end - start;
|
||||
if (len > 0) {
|
||||
int value = 0;
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
char ch = params[start + i];
|
||||
if (!std::isdigit(static_cast<unsigned char>(ch))) {
|
||||
value = -1;
|
||||
break;
|
||||
}
|
||||
value = value * 10 + (ch - '0');
|
||||
}
|
||||
if (value == 5) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (end == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
start = end + 1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool is_space_codepoint(char32_t cp) {
|
||||
return std::iswspace(static_cast<wint_t>(cp)) != 0;
|
||||
}
|
||||
|
||||
static void move_word_left(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line) {
|
||||
if (char_pos == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t new_char_pos = char_pos;
|
||||
size_t new_byte_pos = byte_pos;
|
||||
int move_width = 0;
|
||||
|
||||
while (new_char_pos > 0) {
|
||||
size_t prev_byte = prev_utf8_char_pos(line, new_byte_pos);
|
||||
size_t advance = 0;
|
||||
char32_t cp = decode_utf8(line, prev_byte, advance);
|
||||
if (!is_space_codepoint(cp)) {
|
||||
break;
|
||||
}
|
||||
move_width += widths[new_char_pos - 1];
|
||||
new_char_pos--;
|
||||
new_byte_pos = prev_byte;
|
||||
}
|
||||
|
||||
while (new_char_pos > 0) {
|
||||
size_t prev_byte = prev_utf8_char_pos(line, new_byte_pos);
|
||||
size_t advance = 0;
|
||||
char32_t cp = decode_utf8(line, prev_byte, advance);
|
||||
if (is_space_codepoint(cp)) {
|
||||
break;
|
||||
}
|
||||
move_width += widths[new_char_pos - 1];
|
||||
new_char_pos--;
|
||||
new_byte_pos = prev_byte;
|
||||
}
|
||||
|
||||
move_cursor(-move_width);
|
||||
char_pos = new_char_pos;
|
||||
byte_pos = new_byte_pos;
|
||||
}
|
||||
|
||||
static void move_word_right(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line) {
|
||||
if (char_pos >= widths.size()) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t new_char_pos = char_pos;
|
||||
size_t new_byte_pos = byte_pos;
|
||||
int move_width = 0;
|
||||
|
||||
while (new_char_pos < widths.size()) {
|
||||
size_t advance = 0;
|
||||
char32_t cp = decode_utf8(line, new_byte_pos, advance);
|
||||
if (!is_space_codepoint(cp)) {
|
||||
break;
|
||||
}
|
||||
move_width += widths[new_char_pos];
|
||||
new_char_pos++;
|
||||
new_byte_pos += advance;
|
||||
}
|
||||
|
||||
while (new_char_pos < widths.size()) {
|
||||
size_t advance = 0;
|
||||
char32_t cp = decode_utf8(line, new_byte_pos, advance);
|
||||
if (is_space_codepoint(cp)) {
|
||||
break;
|
||||
}
|
||||
move_width += widths[new_char_pos];
|
||||
new_char_pos++;
|
||||
new_byte_pos += advance;
|
||||
}
|
||||
|
||||
while (new_char_pos < widths.size()) {
|
||||
size_t advance = 0;
|
||||
char32_t cp = decode_utf8(line, new_byte_pos, advance);
|
||||
if (!is_space_codepoint(cp)) {
|
||||
break;
|
||||
}
|
||||
move_width += widths[new_char_pos];
|
||||
new_char_pos++;
|
||||
new_byte_pos += advance;
|
||||
}
|
||||
|
||||
move_cursor(move_width);
|
||||
char_pos = new_char_pos;
|
||||
byte_pos = new_byte_pos;
|
||||
}
|
||||
|
||||
static void move_cursor(int delta) {
|
||||
if (delta == 0) return;
|
||||
#if defined(_WIN32)
|
||||
if (hConsole != NULL) {
|
||||
CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
|
||||
GetConsoleScreenBufferInfo(hConsole, &bufferInfo);
|
||||
COORD newCursorPosition = bufferInfo.dwCursorPosition;
|
||||
int width = bufferInfo.dwSize.X;
|
||||
int newX = newCursorPosition.X + delta;
|
||||
int newY = newCursorPosition.Y;
|
||||
|
||||
while (newX >= width) {
|
||||
newX -= width;
|
||||
newY++;
|
||||
}
|
||||
while (newX < 0) {
|
||||
newX += width;
|
||||
newY--;
|
||||
}
|
||||
|
||||
newCursorPosition.X = newX;
|
||||
newCursorPosition.Y = newY;
|
||||
SetConsoleCursorPosition(hConsole, newCursorPosition);
|
||||
}
|
||||
#else
|
||||
if (delta < 0) {
|
||||
for (int i = 0; i < -delta; i++) fprintf(out, "\b");
|
||||
} else {
|
||||
for (int i = 0; i < delta; i++) fprintf(out, "\033[C");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
struct history_t {
|
||||
std::vector<std::string> entries;
|
||||
size_t viewing_idx = SIZE_MAX;
|
||||
std::string backup_line; // current line before viewing history
|
||||
void add(const std::string & line) {
|
||||
if (line.empty()) {
|
||||
return;
|
||||
}
|
||||
// avoid duplicates with the last entry
|
||||
if (entries.empty() || entries.back() != line) {
|
||||
entries.push_back(line);
|
||||
}
|
||||
// also clear viewing state
|
||||
end_viewing();
|
||||
}
|
||||
bool prev(std::string & cur_line) {
|
||||
if (entries.empty()) {
|
||||
return false;
|
||||
}
|
||||
if (viewing_idx == SIZE_MAX) {
|
||||
return false;
|
||||
}
|
||||
if (viewing_idx > 0) {
|
||||
viewing_idx--;
|
||||
}
|
||||
cur_line = entries[viewing_idx];
|
||||
return true;
|
||||
}
|
||||
bool next(std::string & cur_line) {
|
||||
if (entries.empty() || viewing_idx == SIZE_MAX) {
|
||||
return false;
|
||||
}
|
||||
viewing_idx++;
|
||||
if (viewing_idx >= entries.size()) {
|
||||
cur_line = backup_line;
|
||||
end_viewing();
|
||||
} else {
|
||||
cur_line = entries[viewing_idx];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
void begin_viewing(const std::string & line) {
|
||||
backup_line = line;
|
||||
viewing_idx = entries.size();
|
||||
}
|
||||
void end_viewing() {
|
||||
viewing_idx = SIZE_MAX;
|
||||
backup_line.clear();
|
||||
}
|
||||
bool is_viewing() const {
|
||||
return viewing_idx != SIZE_MAX;
|
||||
}
|
||||
} history;
|
||||
|
||||
static bool readline_advanced(std::string & line, bool multiline_input) {
|
||||
if (out != stdout) {
|
||||
fflush(stdout);
|
||||
@@ -362,8 +738,33 @@ namespace console {
|
||||
bool is_special_char = false;
|
||||
bool end_of_stream = false;
|
||||
|
||||
size_t byte_pos = 0; // current byte index
|
||||
size_t char_pos = 0; // current character index (one char can be multiple bytes)
|
||||
|
||||
char32_t input_char;
|
||||
while (true) {
|
||||
assert(char_pos <= byte_pos);
|
||||
assert(char_pos <= widths.size());
|
||||
auto history_prev = [&]() {
|
||||
if (!history.is_viewing()) {
|
||||
history.begin_viewing(line);
|
||||
}
|
||||
std::string new_line;
|
||||
if (!history.prev(new_line)) {
|
||||
return;
|
||||
}
|
||||
set_line_contents(new_line, line, widths, char_pos, byte_pos);
|
||||
};
|
||||
auto history_next = [&]() {
|
||||
if (history.is_viewing()) {
|
||||
std::string new_line;
|
||||
if (!history.next(new_line)) {
|
||||
return;
|
||||
}
|
||||
set_line_contents(new_line, line, widths, char_pos, byte_pos);
|
||||
}
|
||||
};
|
||||
|
||||
fflush(out); // Ensure all output is displayed before waiting for input
|
||||
input_char = getchar32();
|
||||
|
||||
@@ -371,7 +772,7 @@ namespace console {
|
||||
break;
|
||||
}
|
||||
|
||||
if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) {
|
||||
if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D */) {
|
||||
end_of_stream = true;
|
||||
break;
|
||||
}
|
||||
@@ -384,7 +785,71 @@ namespace console {
|
||||
|
||||
if (input_char == '\033') { // Escape sequence
|
||||
char32_t code = getchar32();
|
||||
if (code == '[' || code == 0x1B) {
|
||||
if (code == '[') {
|
||||
std::string params;
|
||||
while (true) {
|
||||
code = getchar32();
|
||||
if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~' || code == (char32_t) WEOF) {
|
||||
break;
|
||||
}
|
||||
params.push_back(static_cast<char>(code));
|
||||
}
|
||||
|
||||
const bool ctrl_modifier = has_ctrl_modifier(params);
|
||||
|
||||
if (code == 'D') { // left
|
||||
if (ctrl_modifier) {
|
||||
move_word_left(char_pos, byte_pos, widths, line);
|
||||
} else if (char_pos > 0) {
|
||||
int w = widths[char_pos - 1];
|
||||
move_cursor(-w);
|
||||
char_pos--;
|
||||
byte_pos = prev_utf8_char_pos(line, byte_pos);
|
||||
}
|
||||
} else if (code == 'C') { // right
|
||||
if (ctrl_modifier) {
|
||||
move_word_right(char_pos, byte_pos, widths, line);
|
||||
} else if (char_pos < widths.size()) {
|
||||
int w = widths[char_pos];
|
||||
move_cursor(w);
|
||||
char_pos++;
|
||||
byte_pos = next_utf8_char_pos(line, byte_pos);
|
||||
}
|
||||
} else if (code == 'H') { // home
|
||||
move_to_line_start(char_pos, byte_pos, widths);
|
||||
} else if (code == 'F') { // end
|
||||
move_to_line_end(char_pos, byte_pos, widths, line);
|
||||
} else if (code == 'A' || code == 'B') {
|
||||
// up/down
|
||||
if (code == 'A') {
|
||||
history_prev();
|
||||
is_special_char = false;
|
||||
} else if (code == 'B') {
|
||||
history_next();
|
||||
is_special_char = false;
|
||||
}
|
||||
} else if ((code == '~' || (code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z')) && !params.empty()) {
|
||||
std::string digits;
|
||||
for (char ch : params) {
|
||||
if (ch == ';') {
|
||||
break;
|
||||
}
|
||||
if (std::isdigit(static_cast<unsigned char>(ch))) {
|
||||
digits.push_back(ch);
|
||||
}
|
||||
}
|
||||
|
||||
if (code == '~') {
|
||||
if (digits == "1" || digits == "7") { // home
|
||||
move_to_line_start(char_pos, byte_pos, widths);
|
||||
} else if (digits == "4" || digits == "8") { // end
|
||||
move_to_line_end(char_pos, byte_pos, widths, line);
|
||||
} else if (digits == "3") { // delete
|
||||
delete_at_cursor(line, widths, char_pos, byte_pos);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (code == 0x1B) {
|
||||
// Discard the rest of the escape sequence
|
||||
while ((code = getchar32()) != (char32_t) WEOF) {
|
||||
if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
|
||||
@@ -392,28 +857,107 @@ namespace console {
|
||||
}
|
||||
}
|
||||
}
|
||||
#if defined(_WIN32)
|
||||
} else if (input_char == KEY_ARROW_LEFT) {
|
||||
if (char_pos > 0) {
|
||||
int w = widths[char_pos - 1];
|
||||
move_cursor(-w);
|
||||
char_pos--;
|
||||
byte_pos = prev_utf8_char_pos(line, byte_pos);
|
||||
}
|
||||
} else if (input_char == KEY_ARROW_RIGHT) {
|
||||
if (char_pos < widths.size()) {
|
||||
int w = widths[char_pos];
|
||||
move_cursor(w);
|
||||
char_pos++;
|
||||
byte_pos = next_utf8_char_pos(line, byte_pos);
|
||||
}
|
||||
} else if (input_char == KEY_CTRL_ARROW_LEFT) {
|
||||
move_word_left(char_pos, byte_pos, widths, line);
|
||||
} else if (input_char == KEY_CTRL_ARROW_RIGHT) {
|
||||
move_word_right(char_pos, byte_pos, widths, line);
|
||||
} else if (input_char == KEY_HOME) {
|
||||
move_to_line_start(char_pos, byte_pos, widths);
|
||||
} else if (input_char == KEY_END) {
|
||||
move_to_line_end(char_pos, byte_pos, widths, line);
|
||||
} else if (input_char == KEY_DELETE) {
|
||||
delete_at_cursor(line, widths, char_pos, byte_pos);
|
||||
} else if (input_char == KEY_ARROW_UP || input_char == KEY_ARROW_DOWN) {
|
||||
if (input_char == KEY_ARROW_UP) {
|
||||
history_prev();
|
||||
is_special_char = false;
|
||||
} else if (input_char == KEY_ARROW_DOWN) {
|
||||
history_next();
|
||||
is_special_char = false;
|
||||
}
|
||||
#endif
|
||||
} else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
|
||||
if (!widths.empty()) {
|
||||
int count;
|
||||
do {
|
||||
count = widths.back();
|
||||
widths.pop_back();
|
||||
// Move cursor back, print space, and move cursor back again
|
||||
for (int i = 0; i < count; i++) {
|
||||
replace_last(' ');
|
||||
pop_cursor();
|
||||
}
|
||||
pop_back_utf8_char(line);
|
||||
} while (count == 0 && !widths.empty());
|
||||
if (char_pos > 0) {
|
||||
int w = widths[char_pos - 1];
|
||||
move_cursor(-w);
|
||||
char_pos--;
|
||||
size_t prev_pos = prev_utf8_char_pos(line, byte_pos);
|
||||
size_t char_len = byte_pos - prev_pos;
|
||||
byte_pos = prev_pos;
|
||||
|
||||
// remove the character
|
||||
line.erase(byte_pos, char_len);
|
||||
widths.erase(widths.begin() + char_pos);
|
||||
|
||||
// redraw tail
|
||||
size_t p = byte_pos;
|
||||
int tail_width = 0;
|
||||
for (size_t i = char_pos; i < widths.size(); ++i) {
|
||||
size_t next_p = next_utf8_char_pos(line, p);
|
||||
put_codepoint(line.c_str() + p, next_p - p, widths[i]);
|
||||
tail_width += widths[i];
|
||||
p = next_p;
|
||||
}
|
||||
|
||||
// clear display
|
||||
for (int i = 0; i < w; ++i) {
|
||||
fputc(' ', out);
|
||||
}
|
||||
move_cursor(-(tail_width + w));
|
||||
}
|
||||
} else {
|
||||
int offset = line.length();
|
||||
append_utf8(input_char, line);
|
||||
int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char));
|
||||
if (width < 0) {
|
||||
width = 0;
|
||||
// insert character
|
||||
std::string new_char_str;
|
||||
append_utf8(input_char, new_char_str);
|
||||
int w = estimateWidth(input_char);
|
||||
|
||||
if (char_pos == widths.size()) {
|
||||
// insert at the end
|
||||
line += new_char_str;
|
||||
int real_w = put_codepoint(new_char_str.c_str(), new_char_str.length(), w);
|
||||
if (real_w < 0) real_w = 0;
|
||||
widths.push_back(real_w);
|
||||
byte_pos += new_char_str.length();
|
||||
char_pos++;
|
||||
} else {
|
||||
// insert in middle
|
||||
line.insert(byte_pos, new_char_str);
|
||||
|
||||
int real_w = put_codepoint(new_char_str.c_str(), new_char_str.length(), w);
|
||||
if (real_w < 0) real_w = 0;
|
||||
|
||||
widths.insert(widths.begin() + char_pos, real_w);
|
||||
|
||||
// print the tail
|
||||
size_t p = byte_pos + new_char_str.length();
|
||||
int tail_width = 0;
|
||||
for (size_t i = char_pos + 1; i < widths.size(); ++i) {
|
||||
size_t next_p = next_utf8_char_pos(line, p);
|
||||
put_codepoint(line.c_str() + p, next_p - p, widths[i]);
|
||||
tail_width += widths[i];
|
||||
p = next_p;
|
||||
}
|
||||
|
||||
move_cursor(-tail_width);
|
||||
|
||||
byte_pos += new_char_str.length();
|
||||
char_pos++;
|
||||
}
|
||||
widths.push_back(width);
|
||||
}
|
||||
|
||||
if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
|
||||
@@ -451,6 +995,15 @@ namespace console {
|
||||
}
|
||||
}
|
||||
|
||||
if (!end_of_stream && !line.empty()) {
|
||||
// remove the trailing newline for history storage
|
||||
if (!line.empty() && line.back() == '\n') {
|
||||
line.pop_back();
|
||||
}
|
||||
// TODO: maybe support multiline history entries?
|
||||
history.add(line);
|
||||
}
|
||||
|
||||
fflush(out);
|
||||
return has_more;
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ static bool ggml_is_view(const struct ggml_tensor * t) {
|
||||
// ops that return true for this function must not use restrict pointers for their backend implementations
|
||||
bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
switch (op) {
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_DIAG_MASK_ZERO:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
|
||||
@@ -2251,12 +2251,12 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
|
||||
int sections[4],
|
||||
bool mrope_used,
|
||||
bool is_imrope,
|
||||
bool indep_sects) {
|
||||
ggml_tensor * src0 = dst->src[0]; // input
|
||||
bool indep_sects,
|
||||
int64_t rope_dims) {
|
||||
ggml_tensor * src1 = dst->src[1]; // position
|
||||
ggml_tensor * src2 = dst->src[2]; // freq_factors
|
||||
|
||||
int64_t theta_scale_length = src0->ne[0] / 2;
|
||||
int64_t theta_scale_length = rope_dims / 2;
|
||||
int64_t position_length = dst->ne[2];
|
||||
|
||||
// TODO: check theta_scale_length and position_length.
|
||||
@@ -2331,18 +2331,17 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
|
||||
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
|
||||
ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
|
||||
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
|
||||
|
||||
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
||||
theta_scale_ne, theta_scale_nb, 1);
|
||||
}
|
||||
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
||||
theta_scale_ne, theta_scale_nb, 1);
|
||||
|
||||
// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
|
||||
// TODO: acl_yarn_ramp_tensor use rope cache.
|
||||
bool yarn_ramp_tensor_updated = false;
|
||||
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
|
||||
acl_tensor_ptr acl_yarn_ramp_tensor;
|
||||
if (ext_factor != 0 &&
|
||||
// TODO: check more parameter.
|
||||
(ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
|
||||
if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
|
||||
ctx.rope_cache.freq_scale != freq_scale)) {
|
||||
yarn_ramp_tensor_updated = true;
|
||||
|
||||
// -rope_yarn_ramp
|
||||
@@ -2590,7 +2589,7 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
|
||||
aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
|
||||
}
|
||||
|
||||
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
|
||||
int64_t sin_reshape_ne[4] = { rope_dims, 1, dst->ne[2], 1 };
|
||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||
sin_reshape_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
@@ -2645,7 +2644,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
// param
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
int sections[4];
|
||||
int sections[4];
|
||||
// const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
@@ -2654,44 +2653,60 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
|
||||
|
||||
// TODO: n_dims <= ne0
|
||||
GGML_ASSERT(n_dims == ne0);
|
||||
GGML_ASSERT(n_dims % 2 == 0);
|
||||
GGML_ASSERT(n_dims <= ne00);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
||||
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
||||
// mrope_used means the GGML_ROPE_TYPE_MROPE bit is set.
|
||||
// Note: this bit is also set for imrope and some vision modes,
|
||||
// so mrope_used does NOT exclusively indicate pure mrope.
|
||||
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (mrope_used) {
|
||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == ne0/2);
|
||||
GGML_ASSERT(n_dims == ne0 / 2);
|
||||
}
|
||||
|
||||
if (is_imrope || mrope_used) {
|
||||
is_neox = true;
|
||||
}
|
||||
|
||||
// init ctx.rope_cos/rope_sin cache
|
||||
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
|
||||
int64_t rope_dims = n_dims;
|
||||
|
||||
int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
|
||||
//Our current RotaryPositionEmbedding does not support the VISION mode,
|
||||
//but essentially it only modifies theta_base in mrope,
|
||||
//then repeats it at the end in the same way as is_neox.
|
||||
//In fact, RoPE is still applied across all dimensions.
|
||||
if (is_vision) {
|
||||
rope_dims = src0->ne[0];
|
||||
}
|
||||
int64_t tail_dims = ne00 - rope_dims;
|
||||
bool has_tail = tail_dims > 0;
|
||||
|
||||
// init ctx.rope_cos/rope_sin cache
|
||||
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
|
||||
mrope_used, is_imrope, is_vision, rope_dims);
|
||||
|
||||
// Cache is generated with ne00 dimensions, so we use ne00 for reshape
|
||||
int64_t sin_reshape_ne[4] = { rope_dims, 1, ne02, 1 };
|
||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||
sin_reshape_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
@@ -2704,7 +2719,6 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
#ifdef ASCEND_310P
|
||||
// Special ROPE operation for 310P
|
||||
|
||||
@@ -2844,46 +2858,124 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
|
||||
int64_t acl_mode = is_neox ? 0 : 1;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
|
||||
acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get());
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_cann_pool_alloc src_trans_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(float));
|
||||
void * src_trans_buffer = src_trans_allocator.get();
|
||||
ggml_cann_pool_alloc dst_trans_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(float));
|
||||
void * dst_trans_buffer = dst_trans_allocator.get();
|
||||
// Pre-define head and tail dimensions for reuse
|
||||
int64_t head_ne[GGML_MAX_DIMS] = { rope_dims, ne01, ne02, ne03 };
|
||||
int64_t tail_ne[GGML_MAX_DIMS] = { tail_dims, ne01, ne02, ne03 };
|
||||
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
// Step 1: Prepare trans tensors for F16 type conversion to F32 if needed
|
||||
bool src_dst_need_trans = false;
|
||||
ggml_cann_pool_alloc src_trans_allocator(ctx.pool());
|
||||
ggml_cann_pool_alloc dst_trans_allocator(ctx.pool());
|
||||
acl_tensor_ptr acl_src_trans_tensor;
|
||||
acl_tensor_ptr acl_dst_trans_tensor;
|
||||
void * src_trans_buffer = nullptr;
|
||||
void * dst_trans_buffer = nullptr;
|
||||
size_t src_dst_trans_nb[GGML_MAX_DIMS];
|
||||
if (src0->type == GGML_TYPE_F16) {
|
||||
src_dst_need_trans = true;
|
||||
src_trans_buffer = src_trans_allocator.alloc(ggml_nelements(src0) * sizeof(float));
|
||||
dst_trans_buffer = dst_trans_allocator.alloc(ggml_nelements(dst) * sizeof(float));
|
||||
|
||||
acl_tensor_ptr acl_src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
acl_tensor_ptr acl_dst_trans_tensor = ggml_cann_create_tensor(
|
||||
dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
src_dst_trans_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_dst_trans_nb[i] = src_dst_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
acl_src_trans_tensor = ggml_cann_create_tensor(src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne,
|
||||
src_dst_trans_nb, GGML_MAX_DIMS);
|
||||
acl_dst_trans_tensor = ggml_cann_create_tensor(dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne,
|
||||
src_dst_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT);
|
||||
}
|
||||
|
||||
aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT);
|
||||
// Step 2: Prepare head tensors for tail splitting if needed
|
||||
acl_tensor_ptr acl_src_head;
|
||||
acl_tensor_ptr acl_dst_head;
|
||||
if (has_tail) {
|
||||
// Create head views for RotaryPositionEmbedding (only first rope_dims dimensions)
|
||||
// RotaryPositionEmbedding requires contiguous dst tensor, so we use a temporary buffer
|
||||
if (src_dst_need_trans) {
|
||||
// Use F32 trans tensor strides
|
||||
acl_src_head = ggml_cann_create_tensor((char *) src_trans_buffer, ACL_FLOAT, sizeof(float), head_ne,
|
||||
src_dst_trans_nb, GGML_MAX_DIMS);
|
||||
} else {
|
||||
// Use original F32 tensor strides
|
||||
acl_src_head = ggml_cann_create_tensor((char *) src0->data, ACL_FLOAT, sizeof(float), head_ne, src0->nb,
|
||||
GGML_MAX_DIMS);
|
||||
}
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(),
|
||||
acl_cos_reshape_tensor.get(), acl_sin_reshape_tensor.get(), acl_mode,
|
||||
acl_dst_trans_tensor.get());
|
||||
int64_t head_elements = rope_dims * ne01 * ne02 * ne03;
|
||||
ggml_cann_pool_alloc dst_head_contiguous_allocator(ctx.pool(), head_elements * sizeof(float));
|
||||
void * dst_head_contiguous_buffer = dst_head_contiguous_allocator.get();
|
||||
|
||||
aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE");
|
||||
break;
|
||||
size_t head_contiguous_nb[GGML_MAX_DIMS];
|
||||
head_contiguous_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
head_contiguous_nb[i] = head_contiguous_nb[i - 1] * head_ne[i - 1];
|
||||
}
|
||||
acl_dst_head = ggml_cann_create_tensor(dst_head_contiguous_buffer, ACL_FLOAT, sizeof(float), head_ne,
|
||||
head_contiguous_nb, GGML_MAX_DIMS);
|
||||
}
|
||||
|
||||
// Step 3: Execute RotaryPositionEmbedding
|
||||
if (has_tail) {
|
||||
// Rotate only the head portion (first rope_dims dimensions)
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_head.get(), acl_cos_reshape_tensor.get(),
|
||||
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_head.get());
|
||||
|
||||
// Copy head result from contiguous buffer back to destination tensor
|
||||
if (src_dst_need_trans) {
|
||||
acl_tensor_ptr acl_dst_head_target = ggml_cann_create_tensor(
|
||||
(char *) dst_trans_buffer, ACL_FLOAT, sizeof(float), head_ne, src_dst_trans_nb, GGML_MAX_DIMS);
|
||||
cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get());
|
||||
} else {
|
||||
acl_tensor_ptr acl_dst_head_target =
|
||||
ggml_cann_create_tensor((char *) dst->data, ACL_FLOAT, sizeof(float), head_ne, dst->nb, GGML_MAX_DIMS);
|
||||
cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get());
|
||||
}
|
||||
} else if (src_dst_need_trans) {
|
||||
// Rotate full tensor (no tail), using trans tensors
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
|
||||
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
|
||||
} else {
|
||||
// Rotate full tensor (no tail), using original tensors
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
|
||||
acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get());
|
||||
}
|
||||
|
||||
// Step 4: Copy unrotated tail portion from source to destination
|
||||
if (has_tail) {
|
||||
size_t src_tail_offset;
|
||||
size_t dst_tail_offset;
|
||||
|
||||
auto copy_tail_device = [&](void * src_ptr, void * dst_ptr, aclDataType dtype, size_t elem_size,
|
||||
size_t * nb_src_arr, size_t * nb_dst_arr) {
|
||||
acl_tensor_ptr acl_src_tail =
|
||||
ggml_cann_create_tensor(src_ptr, dtype, elem_size, tail_ne, nb_src_arr, GGML_MAX_DIMS);
|
||||
acl_tensor_ptr acl_dst_tail =
|
||||
ggml_cann_create_tensor(dst_ptr, dtype, elem_size, tail_ne, nb_dst_arr, GGML_MAX_DIMS);
|
||||
cann_copy(ctx, acl_src_tail.get(), acl_dst_tail.get());
|
||||
};
|
||||
|
||||
if (src_dst_need_trans) {
|
||||
// Use F32 trans tensor strides and offsets
|
||||
src_tail_offset = rope_dims * src_dst_trans_nb[0];
|
||||
dst_tail_offset = rope_dims * src_dst_trans_nb[0];
|
||||
copy_tail_device((char *) src_trans_buffer + src_tail_offset, (char *) dst_trans_buffer + dst_tail_offset,
|
||||
ACL_FLOAT, sizeof(float), src_dst_trans_nb, src_dst_trans_nb);
|
||||
} else {
|
||||
// Use original tensor strides and offsets
|
||||
src_tail_offset = rope_dims * nb00;
|
||||
dst_tail_offset = rope_dims * nb0;
|
||||
copy_tail_device((char *) src0->data + src_tail_offset, (char *) dst->data + dst_tail_offset,
|
||||
ggml_cann_type_mapping(dst->type), ggml_element_size(dst), src0->nb, dst->nb);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Cast back to F16 if needed
|
||||
if (src_dst_need_trans) {
|
||||
aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -315,7 +315,7 @@ struct ggml_cann_rope_cache {
|
||||
if (theta_scale_exp_host) {
|
||||
free(theta_scale_exp_host);
|
||||
}
|
||||
if(position_select_index_host) {
|
||||
if (position_select_index_host) {
|
||||
free(position_select_index_host);
|
||||
}
|
||||
}
|
||||
@@ -340,7 +340,7 @@ struct ggml_cann_rope_cache {
|
||||
|
||||
void set(int64_t theta_scale_length,
|
||||
int64_t position_length,
|
||||
float ext_factor,
|
||||
float ext_factor,
|
||||
float theta_scale,
|
||||
float freq_scale,
|
||||
float attn_factor,
|
||||
|
||||
@@ -2308,7 +2308,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
||||
|
||||
bool cann_graph_update_required = false;
|
||||
#ifdef USE_ACL_GRAPH
|
||||
bool use_cann_graph = true;
|
||||
bool use_cann_graph = true;
|
||||
|
||||
static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
|
||||
if (!prefill_use_graph) {
|
||||
@@ -2338,7 +2338,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
||||
}
|
||||
}
|
||||
#else
|
||||
bool use_cann_graph = false;
|
||||
bool use_cann_graph = false;
|
||||
#endif // USE_ACL_GRAPH
|
||||
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
|
||||
|
||||
@@ -2474,16 +2474,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
}
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
// TODO: with ops-test v == 1
|
||||
// TODO: n_dims <= ne0
|
||||
if (op->src[0]->ne[0] != op->op_params[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (op->src[0]->ne[0] > 896) {
|
||||
return false;
|
||||
}
|
||||
#ifdef ASCEND_310P
|
||||
// TODO: Support rope_dim < ne00(dim)
|
||||
if (op->src[0]->ne[0] != op->op_params[1]) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_is_contiguous(op->src[0])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -564,6 +564,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
|
||||
const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;
|
||||
|
||||
#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
|
||||
// Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
|
||||
// Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.
|
||||
KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;
|
||||
#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
|
||||
|
||||
if (use_logit_softcap) {
|
||||
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
|
||||
}
|
||||
@@ -858,6 +864,11 @@ static __global__ void flash_attn_tile(
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
|
||||
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
|
||||
#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
|
||||
// Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
|
||||
// Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.
|
||||
tmp_h2[i1/2] *= make_half2(0.25f, 0.25f);
|
||||
#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
&Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#define CUDA_FILL_BLOCK_SIZE 256
|
||||
|
||||
template <typename T>
|
||||
static __global__ void fill_kernel(T * __restrict__ dst, const int64_t k, const T value) {
|
||||
static __global__ void fill_kernel(T * dst, const int64_t k, const T value) {
|
||||
const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i >= k) {
|
||||
return;
|
||||
|
||||
@@ -221,7 +221,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
if (ctx->debug_graph > 0) {
|
||||
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), is_concurrent ? "(concurrent)" : "");
|
||||
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
|
||||
}
|
||||
if (ctx->debug_graph > 1) {
|
||||
GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
|
||||
|
||||
@@ -124,6 +124,13 @@ static void ggml_print_backtrace_symbols(void) {
|
||||
int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
|
||||
backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
|
||||
}
|
||||
#elif defined(__APPLE__)
|
||||
#include <execinfo.h>
|
||||
static void ggml_print_backtrace_symbols(void) {
|
||||
void * trace[100];
|
||||
int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
|
||||
backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
|
||||
}
|
||||
#else
|
||||
static void ggml_print_backtrace_symbols(void) {
|
||||
// platform not supported
|
||||
@@ -135,6 +142,20 @@ void ggml_print_backtrace(void) {
|
||||
if (GGML_NO_BACKTRACE) {
|
||||
return;
|
||||
}
|
||||
#if defined(__APPLE__)
|
||||
// On macOS, fork+debugger attachment is problematic due to:
|
||||
// 1. libdispatch "poisons" forked child processes
|
||||
// 2. lldb has issues attaching to parent from forked child
|
||||
// Use simple backtrace() instead to avoid Terminal.app crashes
|
||||
const char * GGML_BACKTRACE_LLDB = getenv("GGML_BACKTRACE_LLDB");
|
||||
if (!GGML_BACKTRACE_LLDB) {
|
||||
fprintf(stderr, "WARNING: Using native backtrace. Set GGML_BACKTRACE_LLDB for more info.\n");
|
||||
fprintf(stderr, "WARNING: GGML_BACKTRACE_LLDB may cause native MacOS Terminal.app to crash.\n");
|
||||
fprintf(stderr, "See: https://github.com/ggml-org/llama.cpp/pull/17869\n");
|
||||
ggml_print_backtrace_symbols();
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
#if defined(__linux__)
|
||||
FILE * f = fopen("/proc/self/status", "r");
|
||||
size_t size = 0;
|
||||
|
||||
@@ -67,6 +67,30 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte
|
||||
- `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included)
|
||||
- `{0,n}` repeats the precedent symbol or sequence at most `n` times (included)
|
||||
|
||||
## Tokens
|
||||
|
||||
Tokens allow grammars to match specific tokenizer tokens rather than character sequences. This is useful for constraining outputs based on special tokens (like `<think>` or `</think>`).
|
||||
|
||||
Tokens can be specified in two ways:
|
||||
|
||||
1. **Token ID**: Use angle brackets with the token ID in square brackets: `<[token-id]>`. For example, `<[1000]>` matches the token with ID 1000.
|
||||
|
||||
2. **Token string**: Use angle brackets with the token text directly: `<token>`. For example, `<think>` will match the token whose text is exactly `<think>`. This only works if the string tokenizes to exactly one token in the vocabulary, otherwise the grammar will fail to parse.
|
||||
|
||||
You can negate token matches using the `!` prefix: `!<[1000]>` or `!<think>` matches any token *except* the specified one.
|
||||
|
||||
```
|
||||
# Match a thinking block: <think>...</think>
|
||||
# Using token strings (requires these to be single tokens in the vocab)
|
||||
root ::= <think> thinking </think> .*
|
||||
thinking ::= !</think>*
|
||||
|
||||
# Equivalent grammar using explicit token IDs
|
||||
# Assumes token 1000 = <think>, token 1001 = </think>
|
||||
root ::= <[1000]> thinking <[1001]> .*
|
||||
thinking ::= !<[1001]>*
|
||||
```
|
||||
|
||||
## Comments and newlines
|
||||
|
||||
Comments can be specified with `#`:
|
||||
|
||||
@@ -139,6 +139,7 @@ add_library(llama
|
||||
set_target_properties(llama PROPERTIES
|
||||
VERSION ${LLAMA_INSTALL_VERSION}
|
||||
SOVERSION 0
|
||||
MACHO_CURRENT_VERSION 0 # keep macOS linker from seeing oversized version number
|
||||
)
|
||||
|
||||
target_include_directories(llama PRIVATE .)
|
||||
|
||||
+233
-33
@@ -181,6 +181,52 @@ static std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
|
||||
static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
|
||||
const char * pos = src;
|
||||
if (*pos != '<') {
|
||||
throw std::runtime_error(std::string("expecting '<' at ") + pos);
|
||||
}
|
||||
pos++;
|
||||
|
||||
// Parse <[id]>
|
||||
if (*pos == '[') {
|
||||
pos++;
|
||||
const char * int_end = parse_int(pos);
|
||||
uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
|
||||
pos = int_end;
|
||||
if (*pos != ']') {
|
||||
throw std::runtime_error(std::string("expecting ']' at ") + pos);
|
||||
}
|
||||
pos++;
|
||||
if (*pos != '>') {
|
||||
throw std::runtime_error(std::string("expecting '>' at ") + pos);
|
||||
}
|
||||
pos++;
|
||||
return std::make_pair(token_id, pos);
|
||||
}
|
||||
|
||||
if (vocab == nullptr) {
|
||||
throw std::runtime_error(std::string("no vocab to parse token at ") + src);
|
||||
}
|
||||
|
||||
// Parse <token> and tokenize to obtain the token id
|
||||
while (*pos != 0 && *pos != '>') {
|
||||
pos++;
|
||||
}
|
||||
if (*pos != '>') {
|
||||
throw std::runtime_error(std::string("expecting '>' at ") + pos);
|
||||
}
|
||||
pos++;
|
||||
|
||||
llama_token tokens[2];
|
||||
int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
|
||||
if (n_tokens != 1) {
|
||||
// must tokenize to exactly 1 token
|
||||
throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
|
||||
}
|
||||
return std::make_pair(tokens[0], pos);
|
||||
}
|
||||
|
||||
static void print_grammar_char(FILE * file, uint32_t c) {
|
||||
if (0x20 <= c && c <= 0x7f) {
|
||||
fprintf(file, "%c", static_cast<char>(c));
|
||||
@@ -212,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
|
||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
|
||||
case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break;
|
||||
case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break;
|
||||
}
|
||||
switch (elem.type) {
|
||||
case LLAMA_GRETYPE_END:
|
||||
@@ -228,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
|
||||
print_grammar_char(file, elem.value);
|
||||
fprintf(file, "\") ");
|
||||
break;
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
fprintf(file, "<[");
|
||||
fprintf(file, "%u", elem.value);
|
||||
fprintf(file, "]> ");
|
||||
break;
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
fprintf(file, "!");
|
||||
fprintf(file, "<[");
|
||||
fprintf(file, "%u", elem.value);
|
||||
fprintf(file, "]> ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
@@ -284,6 +343,17 @@ static void print_rule(
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
fprintf(file, ".");
|
||||
break;
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
fprintf(file, "<[");
|
||||
fprintf(file, "%u", elem.value);
|
||||
fprintf(file, "]> ");
|
||||
break;
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
fprintf(file, "!");
|
||||
fprintf(file, "<[");
|
||||
fprintf(file, "%u", elem.value);
|
||||
fprintf(file, "]> ");
|
||||
break;
|
||||
}
|
||||
if (is_char_element(elem)) {
|
||||
switch (rule[i + 1].type) {
|
||||
@@ -444,6 +514,17 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
}
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '<' || *pos == '!') { // token
|
||||
auto type = LLAMA_GRETYPE_TOKEN;
|
||||
if (*pos == '!') { // token inverse
|
||||
type = LLAMA_GRETYPE_TOKEN_NOT;
|
||||
pos++;
|
||||
}
|
||||
auto token_pair = parse_token(vocab, pos);
|
||||
const char * token_end = token_pair.second;
|
||||
last_sym_start = rule.size();
|
||||
rule.push_back({type, token_pair.first});
|
||||
pos = parse_space(token_end, is_nested);
|
||||
} else if (is_word_char(*pos)) { // rule reference
|
||||
const char * name_end = parse_name(pos);
|
||||
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
||||
@@ -691,6 +772,21 @@ static bool llama_grammar_match_partial_char(
|
||||
return !is_positive_char;
|
||||
}
|
||||
|
||||
// returns true iff token matches the rule at pos (regular or inverse)
|
||||
// asserts that pos is pointing to a token element
|
||||
static bool llama_grammar_match_token(
|
||||
const llama_grammar_element * pos,
|
||||
const llama_token token) {
|
||||
GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
|
||||
if (pos->type == LLAMA_GRETYPE_TOKEN) {
|
||||
return pos->value == static_cast<uint32_t>(token);
|
||||
}
|
||||
if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||
return pos->value != static_cast<uint32_t>(token);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// transforms a grammar pushdown stack into N possible stacks, all ending
|
||||
// at a character range (terminal element)
|
||||
static void llama_grammar_advance_stack(
|
||||
@@ -738,6 +834,8 @@ static void llama_grammar_advance_stack(
|
||||
case LLAMA_GRETYPE_CHAR:
|
||||
case LLAMA_GRETYPE_CHAR_NOT:
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||
// only add the stack if it's not a duplicate of one we already have
|
||||
new_stacks.emplace_back(stack);
|
||||
@@ -831,26 +929,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
|
||||
return grammar->stacks;
|
||||
}
|
||||
|
||||
static void llama_grammar_accept_chr(
|
||||
struct llama_grammar & grammar,
|
||||
const llama_grammar_stack & stack,
|
||||
uint32_t chr,
|
||||
llama_grammar_stacks & new_stacks) {
|
||||
if (stack.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_grammar_element * pos = stack.back();
|
||||
|
||||
// ignore if this turns into a token
|
||||
if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto match = llama_grammar_match_char(pos, chr);
|
||||
if (match.first) {
|
||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(match.second)) {
|
||||
new_stack.push_back(match.second);
|
||||
}
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
|
||||
llama_grammar_stacks stacks_new;
|
||||
stacks_new.reserve(grammar->stacks.size());
|
||||
|
||||
for (const auto & stack : grammar->stacks) {
|
||||
if (stack.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto match = llama_grammar_match_char(stack.back(), chr);
|
||||
if (match.first) {
|
||||
const llama_grammar_element * pos = match.second;
|
||||
|
||||
// update top of stack to next element, if any
|
||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||
new_stack.push_back(pos);
|
||||
}
|
||||
llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
|
||||
}
|
||||
llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
|
||||
}
|
||||
|
||||
grammar->stacks = std::move(stacks_new);
|
||||
@@ -875,6 +985,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||
|
||||
const llama_grammar_element * stack_pos = stack.back();
|
||||
|
||||
// if the top of the stack is a token rule, then we only need to check the token id
|
||||
if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||
for (const auto & tok : candidates) {
|
||||
if (*tok.code_points == 0) {
|
||||
// reached the end of a token consumed by char rules, reject iff it ended
|
||||
// in a partial response
|
||||
if (tok.partial_utf8.n_remain != 0) {
|
||||
rejects.push_back(tok);
|
||||
}
|
||||
} else if (!llama_grammar_match_token(stack_pos, tok.id)) {
|
||||
rejects.push_back(tok);
|
||||
}
|
||||
}
|
||||
return rejects;
|
||||
}
|
||||
|
||||
llama_grammar_candidates next_candidates;
|
||||
next_candidates.reserve(candidates.size());
|
||||
|
||||
@@ -887,7 +1013,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||
rejects.push_back(tok);
|
||||
}
|
||||
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
|
||||
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
|
||||
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
|
||||
} else {
|
||||
rejects.push_back(tok);
|
||||
}
|
||||
@@ -905,7 +1031,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||
|
||||
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
||||
for (const auto & tok : next_rejects) {
|
||||
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
|
||||
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
|
||||
}
|
||||
|
||||
return rejects;
|
||||
@@ -972,12 +1098,13 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy =*/ false,
|
||||
/* .awaiting_trigger = */ false,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_tokens = */ {},
|
||||
/* .trigger_patterns = */ {},
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy = */ false,
|
||||
/* .awaiting_trigger = */ false,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_buffer_positions = */ {},
|
||||
/* .trigger_tokens = */ {},
|
||||
/* .trigger_patterns = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -990,7 +1117,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
size_t num_trigger_patterns,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens) {
|
||||
llama_grammar_parser parser;
|
||||
llama_grammar_parser parser(vocab);
|
||||
|
||||
// if there is a grammar, parse it
|
||||
// rules will be empty (default) if there are parse errors
|
||||
@@ -1077,10 +1204,11 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy = */ lazy,
|
||||
/* .awaiting_trigger = */ lazy,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy = */ lazy,
|
||||
/* .awaiting_trigger = */ lazy,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_buffer_positions = */ {},
|
||||
std::move(vec_trigger_tokens),
|
||||
std::move(vec_trigger_patterns),
|
||||
};
|
||||
@@ -1103,6 +1231,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
||||
grammar.lazy,
|
||||
grammar.awaiting_trigger,
|
||||
grammar.trigger_buffer,
|
||||
grammar.trigger_buffer_positions,
|
||||
grammar.trigger_tokens,
|
||||
grammar.trigger_patterns,
|
||||
};
|
||||
@@ -1156,7 +1285,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
} else {
|
||||
candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
|
||||
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
||||
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1175,10 +1304,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||
grammar.awaiting_trigger = false;
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
llama_grammar_accept_token(grammar, token, piece);
|
||||
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
|
||||
return;
|
||||
} else {
|
||||
auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
|
||||
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
||||
grammar.trigger_buffer += piece;
|
||||
|
||||
std::smatch match;
|
||||
@@ -1196,10 +1327,23 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
if (start == std::string::npos) {
|
||||
start = match.position(0);
|
||||
}
|
||||
|
||||
// replay tokens that overlap with [start, end)
|
||||
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
|
||||
auto [tok_start, tok_end] = tok_pos;
|
||||
if (tok_end <= start) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
|
||||
size_t piece_len = tok_end - piece_start;
|
||||
auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
|
||||
llama_grammar_accept_token(grammar, tok, tok_piece);
|
||||
}
|
||||
|
||||
auto constrained_str = grammar.trigger_buffer.substr(start);
|
||||
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, constrained_str);
|
||||
grammar.trigger_buffer_positions.clear();
|
||||
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
|
||||
return;
|
||||
}
|
||||
@@ -1218,7 +1362,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
llama_grammar_accept_token(grammar, token, piece);
|
||||
}
|
||||
|
||||
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
|
||||
@@ -1235,3 +1379,59 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
|
||||
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
|
||||
// Note terminating 0 in decoded string
|
||||
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||
const auto & code_points = decoded.first;
|
||||
|
||||
llama_grammar_stacks stacks_new;
|
||||
stacks_new.reserve(grammar.stacks.size());
|
||||
|
||||
for (const auto & stack : grammar.stacks) {
|
||||
if (stack.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_grammar_element * pos = stack.back();
|
||||
|
||||
if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||
if (llama_grammar_match_token(pos, token)) {
|
||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||
new_stack.push_back(pos + 1);
|
||||
}
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
|
||||
}
|
||||
} else {
|
||||
llama_grammar_stacks current_stacks = {stack};
|
||||
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
llama_grammar_stacks next_stacks;
|
||||
|
||||
for (const auto & cur_stack : current_stacks) {
|
||||
llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
|
||||
}
|
||||
|
||||
current_stacks = std::move(next_stacks);
|
||||
if (current_stacks.empty()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto & surviving_stack : current_stacks) {
|
||||
if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
|
||||
stacks_new.emplace_back(surviving_stack);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
grammar.stacks = std::move(stacks_new);
|
||||
grammar.partial_utf8 = decoded.second;
|
||||
|
||||
if (grammar.stacks.empty()) {
|
||||
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+20
-1
@@ -36,11 +36,17 @@ enum llama_gretype {
|
||||
|
||||
// any character (.)
|
||||
LLAMA_GRETYPE_CHAR_ANY = 7,
|
||||
|
||||
// terminal element: token (<[token-id]>)
|
||||
LLAMA_GRETYPE_TOKEN = 8,
|
||||
|
||||
// inverse token (!<[token-id]>)
|
||||
LLAMA_GRETYPE_TOKEN_NOT = 9,
|
||||
};
|
||||
|
||||
typedef struct llama_grammar_element {
|
||||
enum llama_gretype type;
|
||||
uint32_t value; // Unicode code point or rule ID
|
||||
uint32_t value; // Unicode code point, rule ID, or token ID
|
||||
} llama_grammar_element;
|
||||
|
||||
struct llama_partial_utf8 {
|
||||
@@ -52,6 +58,7 @@ struct llama_grammar_candidate {
|
||||
size_t index;
|
||||
const uint32_t * code_points;
|
||||
llama_partial_utf8 partial_utf8;
|
||||
llama_token id;
|
||||
};
|
||||
|
||||
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
||||
@@ -77,10 +84,13 @@ std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||
const llama_grammar_candidates & candidates);
|
||||
|
||||
struct llama_grammar_parser {
|
||||
const llama_vocab * vocab;
|
||||
std::map<std::string, uint32_t> symbol_ids;
|
||||
|
||||
llama_grammar_rules rules;
|
||||
|
||||
llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}
|
||||
|
||||
llama_grammar_stack c_rules() const;
|
||||
|
||||
uint32_t get_symbol_id(const char * src, size_t len);
|
||||
@@ -112,6 +122,9 @@ struct llama_grammar_trigger_pattern {
|
||||
};
|
||||
|
||||
struct llama_grammar {
|
||||
// maintain a list of llama_tokens and their positions in the trigger_buffer
|
||||
using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;
|
||||
|
||||
// note: allow null vocab for testing (not great)
|
||||
const llama_vocab * vocab;
|
||||
|
||||
@@ -127,6 +140,7 @@ struct llama_grammar {
|
||||
bool lazy = false;
|
||||
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
||||
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
|
||||
std::vector<token_pos> trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
|
||||
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
|
||||
std::vector<llama_grammar_trigger_pattern>
|
||||
trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
|
||||
@@ -171,3 +185,8 @@ void llama_grammar_accept_impl(
|
||||
void llama_grammar_accept_str(
|
||||
struct llama_grammar & grammar,
|
||||
const std::string & piece);
|
||||
|
||||
void llama_grammar_accept_token(
|
||||
struct llama_grammar & grammar,
|
||||
llama_token token,
|
||||
const std::string & piece);
|
||||
|
||||
+3
-2
@@ -1606,8 +1606,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 28: type = LLM_TYPE_20B; break;
|
||||
switch (hparams.n_ff_exp) {
|
||||
case 1408: type = LLM_TYPE_16B; break;
|
||||
case 1792: type = LLM_TYPE_20B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
|
||||
@@ -32,13 +32,66 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
|
||||
return grammar_fails;
|
||||
}
|
||||
|
||||
struct token_and_piece {
|
||||
llama_token token;
|
||||
std::string piece;
|
||||
};
|
||||
|
||||
// token() encodes a 32-bit ID as 5 bytes: a 0xff marker followed by the ID in big-endian order.
|
||||
static std::string token(llama_token id) {
|
||||
return std::string{
|
||||
static_cast<char>(0xff),
|
||||
static_cast<char>((id >> 24) & 0xff),
|
||||
static_cast<char>((id >> 16) & 0xff),
|
||||
static_cast<char>((id >> 8) & 0xff),
|
||||
static_cast<char>(id & 0xff)
|
||||
};
|
||||
}
|
||||
|
||||
// parse_tokens() parses the token encodes above and UTF-8 text.
|
||||
static std::vector<token_and_piece> parse_tokens(const std::string & input) {
|
||||
std::vector<token_and_piece> result;
|
||||
result.reserve(input.size());
|
||||
size_t offset = 0;
|
||||
while (offset < input.size()) {
|
||||
try {
|
||||
if (static_cast<unsigned char>(input[offset]) == 0xff) {
|
||||
if (offset + 5 > input.size()) {
|
||||
throw std::runtime_error("not enough bytes for token id");
|
||||
}
|
||||
uint32_t val =
|
||||
(static_cast<unsigned char>(input[offset + 1]) << 24) |
|
||||
(static_cast<unsigned char>(input[offset + 2]) << 16) |
|
||||
(static_cast<unsigned char>(input[offset + 3]) << 8) |
|
||||
(static_cast<unsigned char>(input[offset + 4]));
|
||||
auto piece = "<[" + std::to_string(val) + "]>";
|
||||
result.push_back({static_cast<llama_token>(val), piece});
|
||||
offset += 5;
|
||||
} else {
|
||||
uint32_t cpt = unicode_cpt_from_utf8(input, offset);
|
||||
result.push_back({0, unicode_cpt_to_utf8(cpt)});
|
||||
}
|
||||
} catch (const std::invalid_argument & /*ex*/) {
|
||||
// Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
|
||||
++offset;
|
||||
result.push_back({0, unicode_cpt_to_utf8(0xFFFD)}); // replacement character
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool match_string(const std::string & input, llama_grammar * grammar) {
|
||||
const auto cpts = unicode_cpts_from_utf8(input);
|
||||
const auto parsed = parse_tokens(input);
|
||||
|
||||
auto & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||
|
||||
for (const auto & cpt : cpts) {
|
||||
llama_grammar_accept(grammar, cpt);
|
||||
for (const auto & in : parsed) {
|
||||
try {
|
||||
llama_grammar_accept_token(*grammar, in.token, in.piece);
|
||||
} catch (const std::runtime_error & /*e*/) {
|
||||
// normally this shouldn't get hit because of llama_grammar_apply
|
||||
return false;
|
||||
}
|
||||
|
||||
if (stacks_cur.empty()) {
|
||||
// no stacks means that the grammar failed to match at this point
|
||||
@@ -426,6 +479,30 @@ static void test_simple_grammar() {
|
||||
"12a45",
|
||||
}
|
||||
);
|
||||
|
||||
// Test case for a simple grammar with tokens
|
||||
test_grammar(
|
||||
"simple grammar with tokens",
|
||||
R"""(
|
||||
root ::= <[10]> content <[11]>
|
||||
content ::= (!<[11]>)*)""",
|
||||
// Passing strings
|
||||
{
|
||||
token(10) + "hello world" + token(11),
|
||||
token(10) + "text with " + token(12) + " other tokens " + token(13) + " mixed in" + token(11),
|
||||
token(10) + token(11),
|
||||
token(10) + token(12) + token(13) + token(14) + token(15) + token(11),
|
||||
token(10) + "a" + token(11),
|
||||
},
|
||||
// Failing strings
|
||||
{
|
||||
token(10) + "missing end token",
|
||||
token(10),
|
||||
"missing start token" + token(11),
|
||||
token(10) + token(11) + token(11), // double end token
|
||||
token(11) + "wrong order" + token(10),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
static void test_complex_grammar() {
|
||||
@@ -487,6 +564,34 @@ static void test_complex_grammar() {
|
||||
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
|
||||
}
|
||||
);
|
||||
|
||||
// Test case for a more complex grammar with tokens
|
||||
test_grammar(
|
||||
"complex grammar with tokens",
|
||||
R"""(
|
||||
root ::= reasoning+ content tool-call*
|
||||
reasoning ::= <[10]> (!<[11]>)* <[11]>
|
||||
content ::= <[20]> (!<[21]>)* <[21]>
|
||||
tool-call ::= <[12]> name <[13]> args <[14]>
|
||||
name ::= (!<[13]>)+
|
||||
args ::= (!<[14]>)*)""",
|
||||
// Passing strings
|
||||
{
|
||||
token(10) + "I am thinking" + token(11) + token(20) + "hello world!" + token(21) + token(12) + "search" + token(13) + "query=test" + token(14),
|
||||
token(10) + "reasoning 1" + token(11) + token(10) + "reasoning 2" + token(11) + token(20) + token(21) + token(12) + "tool" + token(13) + token(14),
|
||||
token(10) + token(11) + token(20) + "content" + token(21),
|
||||
token(10) + "think" + token(12) + " nested" + token(11) + token(20) + token(10) + "more content" + token(21) + token(12) + "fn" + token(13) + "x=1,y=2" + token(14) + token(12) + "fn2" + token(13) + token(14),
|
||||
token(10) + "reasoning" + token(11) + token(10) + "more" + token(11) + token(10) + "even more" + token(11) + token(20) + "text" + token(21) + token(12) + "a" + token(13) + "b" + token(14) + token(12) + "c" + token(13) + "d" + token(14),
|
||||
},
|
||||
// Failing strings
|
||||
{
|
||||
token(20) + "content only" + token(21),
|
||||
token(10) + "no closing reasoning",
|
||||
token(10) + token(11) + token(20) + "no closing content",
|
||||
token(10) + token(11) + token(20) + token(21) + token(12) + "incomplete tool",
|
||||
token(10) + token(11) + token(11) + token(20) + token(21),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
static void test_special_chars() {
|
||||
|
||||
@@ -515,5 +515,19 @@ int main()
|
||||
{LLAMA_GRETYPE_END, 0},
|
||||
});
|
||||
|
||||
// <[1000]> = "<think>"
|
||||
// <[1001]> = "</think>"
|
||||
verify_parsing(R"""(
|
||||
root ::= <[1000]> !<[1001]> <[1001]>
|
||||
)""", {
|
||||
{"root", 0}
|
||||
}, {
|
||||
// root (index 0)
|
||||
{LLAMA_GRETYPE_TOKEN, 1000},
|
||||
{LLAMA_GRETYPE_TOKEN_NOT, 1001},
|
||||
{LLAMA_GRETYPE_TOKEN, 1001},
|
||||
{LLAMA_GRETYPE_END, 0},
|
||||
});
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -202,7 +202,7 @@ int main()
|
||||
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
||||
cp[0] = 37 + i;
|
||||
cp[1] = 0;
|
||||
next_candidates[i] = {i, cp, {}};
|
||||
next_candidates[i] = {i, cp, {}, 0};
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
|
||||
|
||||
@@ -16,6 +16,7 @@ add_library(mtmd
|
||||
set_target_properties(mtmd PROPERTIES
|
||||
VERSION ${LLAMA_INSTALL_VERSION}
|
||||
SOVERSION 0
|
||||
MACHO_CURRENT_VERSION 0 # keep macOS linker from seeing oversized version number
|
||||
)
|
||||
|
||||
target_link_libraries (mtmd PUBLIC ggml llama)
|
||||
|
||||
Reference in New Issue
Block a user