mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 01:57:43 +02:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ba7654380a | |||
| 6ab2e4765a | |||
| 96e1280839 | |||
| 2c9f833d17 | |||
| 251364549f | |||
| 8acdacb3ea | |||
| 89b2b56e86 | |||
| e128a1bf5b | |||
| 6ef79a67ca | |||
| 4e39a3c332 | |||
| be421fc429 | |||
| 87c2630546 | |||
| 2b3a25c212 | |||
| 8352cdc87b | |||
| 1e2f78a004 | |||
| 0fd7ca7a21 | |||
| 6fefc05a7a | |||
| 7ab364390f | |||
| 7c7f3b7f43 | |||
| 102ac1891d | |||
| d6ae2fa061 | |||
| 68d0027f3d | |||
| ea002810a2 | |||
| 8fad3c7a7c | |||
| 7cf64f6bee | |||
| 5e2d57b2b2 | |||
| f1648e91cf | |||
| d6c95b0740 | |||
| d76a86d967 | |||
| 776f9e59cc | |||
| 3d652bfddf | |||
| 5220a16d18 | |||
| 3ffbbd5ce1 | |||
| 42994048a3 | |||
| e9b2f84f14 | |||
| e721c05c93 | |||
| 57b6abf85a | |||
| 94bb63e4f0 | |||
| f79243992c | |||
| ed4ce0dda2 | |||
| 07d1572347 | |||
| 5e43f104cc | |||
| 16e4b22c5e | |||
| 074c4fd39d | |||
| 669912d9a5 | |||
| fa31c438e0 |
@@ -467,6 +467,7 @@ jobs:
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||
-DGGML_HIP=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
@@ -476,6 +477,7 @@ jobs:
|
||||
cmake -B build2 -S . \
|
||||
-DCMAKE_C_COMPILER=hipcc \
|
||||
-DCMAKE_CXX_COMPILER=hipcc \
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||
-DGGML_HIP=ON
|
||||
cmake --build build2 --config Release -j $(nproc)
|
||||
|
||||
@@ -1202,6 +1204,11 @@ jobs:
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Clone rocWMMA repository
|
||||
id: clone_rocwmma
|
||||
run: |
|
||||
git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1
|
||||
|
||||
- name: Install
|
||||
id: depends
|
||||
run: |
|
||||
@@ -1231,8 +1238,10 @@ jobs:
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DGGML_HIP=ON `
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON `
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||
|
||||
@@ -1251,6 +1260,11 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Clone rocWMMA repository
|
||||
id: clone_rocwmma
|
||||
run: |
|
||||
git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
@@ -1280,8 +1294,10 @@ jobs:
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DAMDGPU_TARGETS=${{ matrix.gpu_target }} `
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON `
|
||||
-DGGML_HIP=ON `
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||
@@ -1320,6 +1336,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -1368,6 +1386,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||
name: llama-${{ steps.tag.outputs.name }}-xcframework
|
||||
|
||||
android-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# date: Tue Feb 4 13:04:05 EET 2025
|
||||
# date: Sat Mar 8 18:23:52 EET 2025
|
||||
# this file is auto-generated by scripts/gen-authors.sh
|
||||
|
||||
0cc4m <picard12@live.de>
|
||||
@@ -8,10 +8,12 @@
|
||||
3ooabkhxtn <31479382+3ooabkhxtn@users.noreply.github.com>
|
||||
44670 <44670@users.noreply.github.com>
|
||||
65a <10104049+65a@users.noreply.github.com>
|
||||
708-145 <40387547+708-145@users.noreply.github.com>
|
||||
AN Long <aisk@users.noreply.github.com>
|
||||
AT <manyoso@users.noreply.github.com>
|
||||
Aarni Koskela <akx@iki.fi>
|
||||
Aaron Miller <apage43@ninjawhale.com>
|
||||
Aaron Teo <57927438+taronaeo@users.noreply.github.com>
|
||||
Aaryaman Vasishta <aaryaman.vasishta@amd.com>
|
||||
Abheek Gulati <abheekg@hotmail.com>
|
||||
Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
|
||||
@@ -20,6 +22,7 @@ Adithya Balaji <adithya.b94@gmail.com>
|
||||
AdithyanI <adithyan.i4internet@gmail.com>
|
||||
Adrian <smith.adriane@gmail.com>
|
||||
Adrian Hesketh <a-h@users.noreply.github.com>
|
||||
Adrian Kretz <me@akretz.com>
|
||||
Adrien Gallouët <adrien@gallouet.fr>
|
||||
Adrien Gallouët <angt@huggingface.co>
|
||||
Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com>
|
||||
@@ -28,15 +31,18 @@ AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com>
|
||||
AidanBeltonS <aidan.belton@codeplay.com>
|
||||
Aisuko <urakiny@gmail.com>
|
||||
Akarshan Biswas <akarshan.biswas@gmail.com>
|
||||
Akarshan Biswas <akarshan@menlo.ai>
|
||||
Akarshan Biswas <akarshanbiswas@fedoraproject.org>
|
||||
Al Mochkin <14274697+amochkin@users.noreply.github.com>
|
||||
Albert Jin <albert.jin@gmail.com>
|
||||
Alberto <57916483+albbus-stack@users.noreply.github.com>
|
||||
Alberto Cabrera Pérez <alberto.cabrera@codeplay.com>
|
||||
Alberto Cabrera Pérez <alberto.cabrera@intel.com>
|
||||
Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com>
|
||||
Alex <awhill19@icloud.com>
|
||||
Alex Azarov <alex@azarov.by>
|
||||
Alex Azarov <alexander.azarov@mapbox.com>
|
||||
Alex Brooks <alex.brooks@ibm.com>
|
||||
Alex Klinkhamer <from.github.com.917@grencez.dev>
|
||||
Alex Klinkhamer <git@grencez.dev>
|
||||
Alex Nguyen <tiendung@users.noreply.github.com>
|
||||
@@ -67,6 +73,7 @@ Andrew Minh Nguyen <40281306+amqdn@users.noreply.github.com>
|
||||
Andy Salerno <andysalerno@gmail.com>
|
||||
Andy Tai <andy-tai@users.noreply.github.com>
|
||||
Anthony Van de Gejuchte <anthonyvdgent@gmail.com>
|
||||
Antoine Viallon <antoine@lesviallon.fr>
|
||||
Antonis Makropoulos <benuix@gmail.com>
|
||||
Arik Poznanski <arikpoz@users.noreply.github.com>
|
||||
Armen Kaleshian <kriation@users.noreply.github.com>
|
||||
@@ -83,6 +90,7 @@ Atsushi Tatsuma <yoshoku@outlook.com>
|
||||
Austin <77757836+teleprint-me@users.noreply.github.com>
|
||||
AustinMroz <austinmroz@utexas.edu>
|
||||
BADR <contact@pythops.com>
|
||||
BB-fat <45072480+BB-fat@users.noreply.github.com>
|
||||
Bach Le <bach@bullno1.com>
|
||||
Bailey Chittle <39804642+bachittle@users.noreply.github.com>
|
||||
BarfingLemurs <128182951+BarfingLemurs@users.noreply.github.com>
|
||||
@@ -101,6 +109,7 @@ Bert Wagner <github@bertwagner.com>
|
||||
Billel Mokeddem <billel.mokeddem.ml@gmail.com>
|
||||
Bingan <70050083+binganao@users.noreply.github.com>
|
||||
Bjarke Viksøe <164612031+bviksoe@users.noreply.github.com>
|
||||
Bodhi <3882561+BodhiHu@users.noreply.github.com>
|
||||
Bodo Graumann <mail@bodograumann.de>
|
||||
Bono Lv <lvscar@users.noreply.github.com>
|
||||
Borislav Stanimirov <b.stanimirov@abv.bg>
|
||||
@@ -128,6 +137,7 @@ CentricStorm <CentricStorm@users.noreply.github.com>
|
||||
Chad Brewbaker <crb002@gmail.com>
|
||||
Changyeon Kim <cyzero.kim@samsung.com>
|
||||
Chao Jiang <jc19chaoj@zoho.com>
|
||||
Charles Duffy <charles@dyfis.net>
|
||||
Charles Xu <63788048+chaxu01@users.noreply.github.com>
|
||||
Charles Xu <charles.xu@arm.com>
|
||||
Chen Xi <xi2.chen@intel.com>
|
||||
@@ -139,12 +149,14 @@ Chris Kuehl <ckuehl@ckuehl.me>
|
||||
Christian Demsar <christian@github.email.demsar.us>
|
||||
Christian Demsar <crasm@git.vczf.us>
|
||||
Christian Falch <875252+chrfalch@users.noreply.github.com>
|
||||
Christian Fillion <cfillion@users.noreply.github.com>
|
||||
Christian Kastner <ckk@kvr.at>
|
||||
Christian Kögler <ck3d@gmx.de>
|
||||
Christian Köhnenkamp <cvk5@me.com>
|
||||
Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com>
|
||||
Christopher Nielsen <62156882+mascguy@users.noreply.github.com>
|
||||
Clark Saben <76020733+csaben@users.noreply.github.com>
|
||||
Clauszy <zhangyub@uniontech.com>
|
||||
Clint Herron <hanclinto@gmail.com>
|
||||
Conrad Kramer <conrad@conradkramer.com>
|
||||
Corentin REGAL <corentin.regal@gmail.com>
|
||||
@@ -163,6 +175,7 @@ Daniel Hiltgen <dhiltgen@users.noreply.github.com>
|
||||
Daniel Illescas Romero <illescas.daniel@protonmail.com>
|
||||
Daniel Kleine <53251018+d-kleine@users.noreply.github.com>
|
||||
Daniele <57776841+daniandtheweb@users.noreply.github.com>
|
||||
Danny Milosavljevic <dannym@friendly-machines.com>
|
||||
DannyDaemonic <DannyDaemonic@gmail.com>
|
||||
Dat Quoc Nguyen <2412555+datquocnguyen@users.noreply.github.com>
|
||||
Dave <dave-fl@users.noreply.github.com>
|
||||
@@ -170,6 +183,7 @@ Dave Airlie <airlied@gmail.com>
|
||||
Dave Airlie <airlied@redhat.com>
|
||||
Dave Della Costa <ddellacosta+github@gmail.com>
|
||||
David Friehs <david@friehs.info>
|
||||
David Huang <1969802+hjc4869@users.noreply.github.com>
|
||||
David Kennedy <dakennedyd@gmail.com>
|
||||
David Pflug <david@pflug.email>
|
||||
David Renshaw <dwrenshaw@gmail.com>
|
||||
@@ -236,6 +250,7 @@ Felix <stenbackfelix@gmail.com>
|
||||
Finn Voorhees <finnvoorhees@gmail.com>
|
||||
Firat <firatkiral@gmail.com>
|
||||
FirstTimeEZ <179362031+FirstTimeEZ@users.noreply.github.com>
|
||||
Florent BENOIT <fbenoit@redhat.com>
|
||||
Folko-Ven <71110216+Folko-Ven@users.noreply.github.com>
|
||||
Foul-Tarnished <107711110+Foul-Tarnished@users.noreply.github.com>
|
||||
Francisco Melo <43780565+francis2tm@users.noreply.github.com>
|
||||
@@ -254,6 +269,7 @@ Gary Mulder <gjmulder@gmail.com>
|
||||
Gavin Zhao <gavinzhaojw@protonmail.com>
|
||||
Genkagaku.GPT <hlhr202@163.com>
|
||||
Georgi Gerganov <ggerganov@gmail.com>
|
||||
Gian-Carlo Pascutto <gcp@sjeng.org>
|
||||
Gilad S <giladgd@users.noreply.github.com>
|
||||
Gilad S. <7817232+giladgd@users.noreply.github.com>
|
||||
Giuseppe Scrivano <giuseppe@scrivano.org>
|
||||
@@ -267,7 +283,9 @@ Guspan Tanadi <36249910+guspan-tanadi@users.noreply.github.com>
|
||||
Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com>
|
||||
Haggai Nuchi <h.nuchi@gmail.com>
|
||||
Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com>
|
||||
Hale Chan <halechan@qq.com>
|
||||
Hamdoud Hakem <90524568+hamdoudhakem@users.noreply.github.com>
|
||||
Han Yin <han.yin@arm.com>
|
||||
HanishKVC <hanishkvc@gmail.com>
|
||||
Haohui Mai <ricetons@gmail.com>
|
||||
Haoxiang Fei <tonyfettes@tonyfettes.com>
|
||||
@@ -278,6 +296,7 @@ Haus1 <haus.xda@gmail.com>
|
||||
Henk Poley <HenkPoley@gmail.com>
|
||||
Henri Vasserman <henv@hot.ee>
|
||||
Henrik Forstén <henrik.forsten@gmail.com>
|
||||
Henry Linjamäki <henry.linjamaki@gmail.com>
|
||||
Herman Semenov <GermanAizek@yandex.ru>
|
||||
Hesen Peng <hesen.peng@gmail.com>
|
||||
HimariO <dsfhe49854@gmail.com>
|
||||
@@ -307,6 +326,7 @@ Ivan <nekotekina@gmail.com>
|
||||
Ivan Filipov <159561759+vanaka11@users.noreply.github.com>
|
||||
Ivan Komarov <Ivan.Komarov@dfyz.info>
|
||||
Ivan Stepanov <ivanstepanovftw@gmail.com>
|
||||
JC <43374599+MrSMlT@users.noreply.github.com>
|
||||
JFLFY2255 <JFLFY2255@163.com>
|
||||
JH23X <165871467+JH23X@users.noreply.github.com>
|
||||
Jack Mousseau <jack@software.inc>
|
||||
@@ -325,6 +345,7 @@ Jan Ploski <jpl@plosquare.com>
|
||||
Jannis Schönleber <joennlae@gmail.com>
|
||||
Jared Van Bortel <cebtenzzre@gmail.com>
|
||||
Jared Van Bortel <jared@nomic.ai>
|
||||
Jason C.H <ctrysbita@outlook.com>
|
||||
Jason McCartney <jmac@theroot.org>
|
||||
Jason Stillerman <jason.t.stillerman@gmail.com>
|
||||
Jean-Christophe Hoelt <hoelt@fovea.cc>
|
||||
@@ -342,6 +363,7 @@ Jiahao Li <liplus17@163.com>
|
||||
Jian Liao <jianliao@users.noreply.github.com>
|
||||
JidongZhang-THU <1119708529@qq.com>
|
||||
Jinwoo Jeong <33892306+williamjeong2@users.noreply.github.com>
|
||||
Jinyang He <hejinyang@loongson.cn>
|
||||
Jiří Podivín <66251151+jpodivin@users.noreply.github.com>
|
||||
Jiří Sejkora <Sejseloid@gmail.com>
|
||||
Joan Fontanals <jfontanalsmartinez@gmail.com>
|
||||
@@ -379,6 +401,7 @@ Justine Tunney <jtunney@mozilla.com>
|
||||
Juuso Alasuutari <juuso.alasuutari@gmail.com>
|
||||
KASR <karim.asrih@gmail.com>
|
||||
Kamil Tomšík <info@tomsik.cz>
|
||||
Kante Yin <kerthcet@gmail.com>
|
||||
Karol Kontny <82021046+kkontny@users.noreply.github.com>
|
||||
Karsten Weiss <knweiss@gmail.com>
|
||||
Karthick <j.karthic2004@gmail.com>
|
||||
@@ -419,6 +442,7 @@ LoganDark <github@logandark.mozmail.com>
|
||||
Loïc Carrère <loic.carrere@gmail.com>
|
||||
LostRuins <39025047+LostRuins@users.noreply.github.com>
|
||||
LostRuins Concedo <39025047+LostRuins@users.noreply.github.com>
|
||||
Lucas Moura Belo <lucas.belo@live.com>
|
||||
Luciano <lucianostrika44@gmail.com>
|
||||
Luo Tian <lt@basecity.com>
|
||||
Lyle Dean <dean@lyle.dev>
|
||||
@@ -463,6 +487,7 @@ Matthew Tejo <matthew.tejo@gmail.com>
|
||||
Matvey Soloviev <blackhole89@gmail.com>
|
||||
Max Krasnyansky <max.krasnyansky@gmail.com>
|
||||
Max Krasnyansky <quic_maxk@quicinc.com>
|
||||
Maxim Evtush <154841002+maximevtush@users.noreply.github.com>
|
||||
Maxime <672982+maximegmd@users.noreply.github.com>
|
||||
Maximilian Winter <maximilian.winter.91@gmail.com>
|
||||
Meng Zhang <meng@tabbyml.com>
|
||||
@@ -494,6 +519,7 @@ Miwa / Ensan <63481257+ensan-hcl@users.noreply.github.com>
|
||||
Mohammadreza Hendiani <hendiani.mohammadreza@gmail.com>
|
||||
Mohammadreza Hendiani <mohammad.r.hendiani@gmail.com>
|
||||
Molly Sophia <mollysophia379@gmail.com>
|
||||
MoonRide303 <130458190+MoonRide303@users.noreply.github.com>
|
||||
MorganRO8 <47795945+MorganRO8@users.noreply.github.com>
|
||||
Murilo Santana <mvrilo@gmail.com>
|
||||
Musab Gultekin <musabgultekin@users.noreply.github.com>
|
||||
@@ -524,6 +550,7 @@ Nikolas <127742645+nneubacher@users.noreply.github.com>
|
||||
Nindaleth <Nindaleth@users.noreply.github.com>
|
||||
Nuno <rare-magma@posteo.eu>
|
||||
OSecret <135510162+OLSecret@users.noreply.github.com>
|
||||
Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com>
|
||||
Oleksandr Nikitin <oleksandr@tvori.info>
|
||||
Oleksii Maryshchenko <oleksii.maryshchenko@gmail.com>
|
||||
Olivier Chafik <ochafik@users.noreply.github.com>
|
||||
@@ -533,6 +560,7 @@ PAB <pierreantoine.bannier@gmail.com>
|
||||
Pablo Duboue <pablo.duboue@gmail.com>
|
||||
Pascal Patry <ppatry@mtacitlabs.com>
|
||||
Patrice Ferlet <metal3d@gmail.com>
|
||||
Patrick Peng <retr0@retr0.blog>
|
||||
Paul Tsochantaris <ptsochantaris@icloud.com>
|
||||
Pavel Zloi <github.com@drteam.rocks>
|
||||
Pavol Rusnak <pavol@rusnak.io>
|
||||
@@ -549,6 +577,7 @@ Pieter Ouwerkerk <pieter.ouwerkerk@gmail.com>
|
||||
Plamen Minev <pacominev@gmail.com>
|
||||
Prashant Vithule <119530321+Vithulep@users.noreply.github.com>
|
||||
Przemysław Pawełczyk <przemoc@gmail.com>
|
||||
PureJourney <edward.pong@qq.com>
|
||||
Qin Yue Chen <71813199+chenqiny@users.noreply.github.com>
|
||||
Qingyou Meng <meng.qingyou@gmail.com>
|
||||
Qu Zongfu <43257352+yancaoweidaode@users.noreply.github.com>
|
||||
@@ -564,14 +593,17 @@ Rand Xie <randxiexyy29@gmail.com>
|
||||
Randall Fitzgerald <randall@dasaku.net>
|
||||
Random Fly <renfei8@live.cn>
|
||||
Reinforce-II <fate@eastal.com>
|
||||
Rémy O <remyoudompheng@gmail.com>
|
||||
Rémy Oudompheng <oudomphe@phare.normalesup.org>
|
||||
Ren Xuancheng <jklj077@users.noreply.github.com>
|
||||
Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com>
|
||||
Reza Kakhki <rezakakhki.de@gmail.com>
|
||||
Reza Rahemtola <49811529+RezaRahemtola@users.noreply.github.com>
|
||||
RhinoDevel <RhinoDevel@users.noreply.github.com>
|
||||
Riccardo Orlando <Riccorl@users.noreply.github.com>
|
||||
Riceball LEE <snowyu.lee@gmail.com>
|
||||
Rich Dougherty <rich@rd.nz>
|
||||
Richard <r-burton@hotmail.co.uk>
|
||||
Richard Kiss <him@richardkiss.com>
|
||||
Richard Roberson <richardr1126@gmail.com>
|
||||
Rick G <26732651+TheFlipbook@users.noreply.github.com>
|
||||
@@ -588,6 +620,7 @@ Robert Sung-wook Shin <edp1096@users.noreply.github.com>
|
||||
Robey Holderith <robey@flaminglunchbox.net>
|
||||
Robyn <robyngraf@users.noreply.github.com>
|
||||
Roger Meier <r.meier@siemens.com>
|
||||
Rohanjames1997 <rohan.james4@gmail.com>
|
||||
Roland <14355895+rbur0425@users.noreply.github.com>
|
||||
Romain Biessy <romain.biessy@codeplay.com>
|
||||
Romain D <90720+Artefact2@users.noreply.github.com>
|
||||
@@ -610,6 +643,7 @@ Ryan Landay <rlanday@gmail.com>
|
||||
Ryder Wishart <ryderwishart@gmail.com>
|
||||
Ryuei <louixs@users.noreply.github.com>
|
||||
Rőczey Barnabás <31726601+An0nie@users.noreply.github.com>
|
||||
SAMI <samuel.koesnadi@stud.uni-due.de>
|
||||
SRHMorris <69468379+SRHMorris@users.noreply.github.com>
|
||||
SXX <sxx1136965276@gmail.com>
|
||||
SakuraUmi <yukinon244@gmail.com>
|
||||
@@ -634,6 +668,8 @@ Shane A <shanea@allenai.org>
|
||||
Shangning Xu <32517059+xushangning@users.noreply.github.com>
|
||||
Shankar <gshankar.87@gmail.com>
|
||||
Shanshan Shen <467638484@qq.com>
|
||||
Shelby Jenkins <47464908+ShelbyJenkins@users.noreply.github.com>
|
||||
Sheldon Robinson <sheldon.robinson@live.com>
|
||||
Shijie <821898965@qq.com>
|
||||
Shintarou Okada <kokuzen@gmail.com>
|
||||
Shouzheng Liu <61452103+lshzh-ww@users.noreply.github.com>
|
||||
@@ -713,18 +749,24 @@ Victor Nogueira <felladrin@gmail.com>
|
||||
Victor Z. Peng <ziliangdotme@gmail.com>
|
||||
Viet-Anh NGUYEN (Andrew) <vietanh.dev@gmail.com>
|
||||
Vinesh Janarthanan <36610342+VJHack@users.noreply.github.com>
|
||||
Vitali Lovich <vlovich+github@gmail.com>
|
||||
Vivian <vynride@gmail.com>
|
||||
Vlad <spitfireage@gmail.com>
|
||||
Vladimir <bogdad@gmail.com>
|
||||
Vladimir Malyutin <first-leon@yandex.ru>
|
||||
Vladimir Vuksanovic <109677816+vvuksanovic@users.noreply.github.com>
|
||||
Vladimir Zorin <vladimir@deviant.guru>
|
||||
VoidIsVoid <343750470@qq.com>
|
||||
Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com>
|
||||
Wagner Bruna <wbruna@users.noreply.github.com>
|
||||
Wang Qin <37098874+wangqin0@users.noreply.github.com>
|
||||
Wang Ran (汪然) <wangr@smail.nju.edu.cn>
|
||||
WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com>
|
||||
Weird Constructor <weirdconstructor@gmail.com>
|
||||
Weizhao Ouyang <o451686892@gmail.com>
|
||||
Welby Seely <welbyseely@gmail.com>
|
||||
Wentai Zhang <rchardx@gmail.com>
|
||||
Wilken Gottwalt <12194808+wgottwalt@users.noreply.github.com>
|
||||
WillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com>
|
||||
William Tambellini <william.tambellini@gmail.com>
|
||||
William Tambellini <wtambellini@sdl.com>
|
||||
@@ -816,6 +858,8 @@ chaihahaha <chai836275709@gmail.com>
|
||||
chiranko <96988916+chiranko@users.noreply.github.com>
|
||||
clibdev <52199778+clibdev@users.noreply.github.com>
|
||||
clyang <clyang@clyang.net>
|
||||
cmdr2 <secondary.cmdr2@gmail.com>
|
||||
cmdr2 <shashank.shekhar.global@gmail.com>
|
||||
cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com>
|
||||
codezjx <code.zjx@gmail.com>
|
||||
coezbek <c.oezbek@gmail.com>
|
||||
@@ -835,6 +879,7 @@ deepdiffuser <112834445+deepdiffuser@users.noreply.github.com>
|
||||
devojony <61173062+devojony@users.noreply.github.com>
|
||||
ditsuke <ditsuke@protonmail.com>
|
||||
divinity76 <divinity76@gmail.com>
|
||||
dm4 <dm4@secondstate.io>
|
||||
dm4 <sunrisedm4@gmail.com>
|
||||
dotpy314 <33351922+dotpy314@users.noreply.github.com>
|
||||
drbh <david.richard.holtz@gmail.com>
|
||||
@@ -849,6 +894,7 @@ fairydreaming <166155368+fairydreaming@users.noreply.github.com>
|
||||
fengerhu1 <2748250768@qq.com>
|
||||
fj-y-saito <85871716+fj-y-saito@users.noreply.github.com>
|
||||
fraxy-v <65565042+fraxy-v@users.noreply.github.com>
|
||||
fxzjshm <11426482+fxzjshm@users.noreply.github.com>
|
||||
github-actions[bot] <github-actions[bot]@users.noreply.github.com>
|
||||
gliptic <gliptic@users.noreply.github.com>
|
||||
gn64 <yukikaze.jp@gmail.com>
|
||||
@@ -873,6 +919,7 @@ hydai <z54981220@gmail.com>
|
||||
iSma <ismail.senhaji@gmail.com>
|
||||
iacore <74560659+iacore@users.noreply.github.com>
|
||||
icppWorld <124377669+icppWorld@users.noreply.github.com>
|
||||
igardev <49397134+igardev@users.noreply.github.com>
|
||||
igarnier <igarnier@protonmail.com>
|
||||
intelmatt <61025942+intelmatt@users.noreply.github.com>
|
||||
iohub <rickyang.pro@gmail.com>
|
||||
@@ -880,6 +927,7 @@ issixx <46835150+issixx@users.noreply.github.com>
|
||||
jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com>
|
||||
jaime-m-p <167997752+jaime-m-p@users.noreply.github.com>
|
||||
jameswu2014 <545426914@qq.com>
|
||||
jason_w <jason.wang@126.com>
|
||||
jdomke <28772296+jdomke@users.noreply.github.com>
|
||||
jiahao su <damow890@gmail.com>
|
||||
jiez <373447296@qq.com>
|
||||
@@ -891,6 +939,7 @@ jon-chuang <9093549+jon-chuang@users.noreply.github.com>
|
||||
jp-x-g <jpxg-dev@protonmail.com>
|
||||
jukofyork <69222624+jukofyork@users.noreply.github.com>
|
||||
junchao-loongson <68935141+junchao-loongson@users.noreply.github.com>
|
||||
junchao-zhao <68935141+junchao-loongson@users.noreply.github.com>
|
||||
jwj7140 <32943891+jwj7140@users.noreply.github.com>
|
||||
k.h.lai <adrian.k.h.lai@outlook.com>
|
||||
kaizau <kaizau@users.noreply.github.com>
|
||||
@@ -925,6 +974,7 @@ ltoniazzi <61414566+ltoniazzi@users.noreply.github.com>
|
||||
luoyu-intel <yu.luo@intel.com>
|
||||
m3ndax <adrian.goessl@outlook.com>
|
||||
maddes8cht <55592906+maddes8cht@users.noreply.github.com>
|
||||
magicse <magicse@users.noreply.github.com>
|
||||
mahorozte <41834471+mahorozte@users.noreply.github.com>
|
||||
makomk <makosoft@googlemail.com>
|
||||
manikbhandari <mbbhandarimanik2@gmail.com>
|
||||
@@ -935,6 +985,7 @@ matt23654 <matthew.webber@protonmail.com>
|
||||
matteo <matteogeniaccio@yahoo.it>
|
||||
mdrokz <mohammadmunshi@gmail.com>
|
||||
mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com>
|
||||
midnight <midnightmagic@users.noreply.github.com>
|
||||
minarchist <minarchist@users.noreply.github.com>
|
||||
mj-shifu <77107165+mj-shifu@users.noreply.github.com>
|
||||
mmyjona <jonathan.gonse@gmail.com>
|
||||
@@ -958,10 +1009,12 @@ omahs <73983677+omahs@users.noreply.github.com>
|
||||
oobabooga <112222186+oobabooga@users.noreply.github.com>
|
||||
opparco <parco.opaai@gmail.com>
|
||||
ostix360 <55257054+ostix360@users.noreply.github.com>
|
||||
pascal-lc <49066376+pascal-lc@users.noreply.github.com>
|
||||
pculliton <phillipculliton@gmail.com>
|
||||
peidaqi <peidaqi@gmail.com>
|
||||
pengxin99 <pengxin.yuan@intel.com>
|
||||
perserk <perserk@gmail.com>
|
||||
petterreinholdtsen <pere-github@hungry.com>
|
||||
piDack <104877312+piDack@users.noreply.github.com>
|
||||
pmysl <piotr.myslinski@outlook.com>
|
||||
postmasters <namnguyen@google.com>
|
||||
@@ -983,6 +1036,7 @@ semidark <me@semidark.net>
|
||||
serhii-nakon <57632032+serhii-nakon@users.noreply.github.com>
|
||||
sharpHL <132747147+sharpHL@users.noreply.github.com>
|
||||
shibe2 <shibe@tuta.io>
|
||||
simon886212 <37953122+simon886212@users.noreply.github.com>
|
||||
singularity <12184989+singularity-s0@users.noreply.github.com>
|
||||
sjinzh <sjinzh@gmail.com>
|
||||
sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com>
|
||||
@@ -1000,10 +1054,12 @@ tarcey <cey.tarik@gmail.com>
|
||||
tc-mb <157115220+tc-mb@users.noreply.github.com>
|
||||
texmex76 <40733439+texmex76@users.noreply.github.com>
|
||||
thement <40525767+thement@users.noreply.github.com>
|
||||
theraininsky <76763719+theraininsky@users.noreply.github.com>
|
||||
thewh1teagle <61390950+thewh1teagle@users.noreply.github.com>
|
||||
tjohnman <tjohnman@users.noreply.github.com>
|
||||
toyer <2042519524@qq.com>
|
||||
tslmy <tslmy@users.noreply.github.com>
|
||||
tv1wnd <55383215+tv1wnd@users.noreply.github.com>
|
||||
ubik2 <ubik2@users.noreply.github.com>
|
||||
uint256_t <konndennsa@gmail.com>
|
||||
uint256_t <maekawatoshiki1017@gmail.com>
|
||||
@@ -1014,6 +1070,7 @@ valiray <133289098+valiray@users.noreply.github.com>
|
||||
vb <vaibhavs10@gmail.com>
|
||||
vik <vikhyatk@gmail.com>
|
||||
viric <viric@viric.name>
|
||||
vmobilis <75476228+vmobilis@users.noreply.github.com>
|
||||
vodkaslime <646329483@qq.com>
|
||||
vvhg1 <94630311+vvhg1@users.noreply.github.com>
|
||||
vxiiduu <73044267+vxiiduu@users.noreply.github.com>
|
||||
@@ -1028,6 +1085,8 @@ wzy <32936898+Freed-Wu@users.noreply.github.com>
|
||||
xaedes <xaedes@gmail.com>
|
||||
xaedes <xaedes@googlemail.com>
|
||||
xctan <axunlei@gmail.com>
|
||||
xiaobing318 <71554036+xiaobing318@users.noreply.github.com>
|
||||
xiaofei <hbuxiaofei@gmail.com>
|
||||
xloem <0xloem@gmail.com>
|
||||
yangli2 <yangli2@gmail.com>
|
||||
ymcki <84055651+ymcki@users.noreply.github.com>
|
||||
|
||||
@@ -836,7 +836,7 @@ ifdef GGML_MUSA
|
||||
else
|
||||
MUSA_PATH ?= /opt/musa
|
||||
endif
|
||||
MUSA_ARCHITECTURES ?= 21;22
|
||||
MUSA_ARCHITECTURES ?= 21;22;31
|
||||
|
||||
MK_CPPFLAGS += -DGGML_USE_MUSA -DGGML_USE_CUDA
|
||||
MK_LDFLAGS += -L$(MUSA_PATH)/lib -Wl,-rpath=$(MUSA_PATH)/lib
|
||||
|
||||
@@ -25,7 +25,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
||||
|
||||
- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggml-org/llama.cpp/pull/11427
|
||||
- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode
|
||||
- Universal tool call support in `llama-server`: https://github.com/ggml-org/llama.cpp/pull/9639
|
||||
- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
|
||||
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
|
||||
- Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123
|
||||
- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669
|
||||
@@ -157,6 +157,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- Guile Scheme: [guile_llama_cpp](https://savannah.nongnu.org/projects/guile-llama-cpp)
|
||||
- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift)
|
||||
- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
|
||||
- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -171,6 +172,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [eva](https://github.com/ylsdamxssjxxdd/eva) (MIT)
|
||||
- [iohub/collama](https://github.com/iohub/coLLaMA) (Apache-2.0)
|
||||
- [janhq/jan](https://github.com/janhq/jan) (AGPL)
|
||||
- [johnbean393/Sidekick](https://github.com/johnbean393/Sidekick) (MIT)
|
||||
- [KanTV](https://github.com/zhouwg/kantv?tab=readme-ov-file) (Apache-2.0)
|
||||
- [KodiBot](https://github.com/firatkiral/kodibot) (GPL)
|
||||
- [llama.vim](https://github.com/ggml-org/llama.vim) (MIT)
|
||||
|
||||
@@ -352,10 +352,10 @@ function gg_run_open_llama_7b_v2 {
|
||||
|
||||
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
|
||||
|
||||
(time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
|
||||
function check_ppl {
|
||||
qnt="$1"
|
||||
|
||||
+39
-8
@@ -1867,16 +1867,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_PASSKEY}));
|
||||
add_opt(common_arg(
|
||||
{"-o", "--output", "--output-file"}, "FNAME",
|
||||
string_format("output file (default: '%s')",
|
||||
ex == LLAMA_EXAMPLE_EXPORT_LORA
|
||||
? params.lora_outfile.c_str()
|
||||
: ex == LLAMA_EXAMPLE_CVECTOR_GENERATOR
|
||||
? params.cvector_outfile.c_str()
|
||||
: params.out_file.c_str()),
|
||||
string_format("output file (default: '%s')", params.out_file.c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.out_file = value;
|
||||
params.cvector_outfile = value;
|
||||
params.lora_outfile = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA}));
|
||||
add_opt(common_arg(
|
||||
@@ -2571,5 +2564,43 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--fim-qwen-7b-spec"},
|
||||
string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
|
||||
params.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
|
||||
params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.n_gpu_layers = 99;
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
params.n_cache_reuse = 256;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--fim-qwen-14b-spec"},
|
||||
string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
|
||||
params.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
|
||||
params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.n_gpu_layers = 99;
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
params.n_cache_reuse = 256;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
return ctx_arg;
|
||||
}
|
||||
|
||||
+372
-194
@@ -60,7 +60,9 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
||||
}
|
||||
msg.role = message.at("role");
|
||||
|
||||
if (message.contains("content")) {
|
||||
auto has_content = message.contains("content");
|
||||
auto has_tool_calls = message.contains("tool_calls");
|
||||
if (has_content) {
|
||||
const auto & content = message.at("content");
|
||||
if (content.is_string()) {
|
||||
msg.content = content;
|
||||
@@ -81,19 +83,8 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
||||
} else if (!content.is_null()) {
|
||||
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
|
||||
}
|
||||
if (message.contains("reasoning_content")) {
|
||||
msg.reasoning_content = message.at("reasoning_content");
|
||||
}
|
||||
if (message.contains("name")) {
|
||||
msg.tool_name = message.at("name");
|
||||
}
|
||||
if (message.contains("tool_call_id")) {
|
||||
msg.tool_call_id = message.at("tool_call_id");
|
||||
}
|
||||
if (message.contains("tool_calls")) {
|
||||
if (has_tool_calls) {
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
common_chat_tool_call tc;
|
||||
if (!tool_call.contains("type")) {
|
||||
@@ -118,6 +109,18 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
||||
msg.tool_calls.push_back(tc);
|
||||
}
|
||||
}
|
||||
if (!has_content && !has_tool_calls) {
|
||||
throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
|
||||
}
|
||||
if (message.contains("reasoning_content")) {
|
||||
msg.reasoning_content = message.at("reasoning_content");
|
||||
}
|
||||
if (message.contains("name")) {
|
||||
msg.tool_name = message.at("name");
|
||||
}
|
||||
if (message.contains("tool_call_id")) {
|
||||
msg.tool_call_id = message.at("tool_call_id");
|
||||
}
|
||||
|
||||
msgs.push_back(msg);
|
||||
}
|
||||
@@ -442,6 +445,7 @@ std::string common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
|
||||
default:
|
||||
@@ -449,12 +453,6 @@ std::string common_chat_format_name(common_chat_format format) {
|
||||
}
|
||||
}
|
||||
|
||||
const common_grammar_options grammar_options {
|
||||
/* .dotall = */ false,
|
||||
/* .compact_spaces = */ false,
|
||||
// /* .compact_spaces = */ true,
|
||||
};
|
||||
|
||||
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
|
||||
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||
@@ -500,6 +498,34 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
||||
}
|
||||
}
|
||||
|
||||
static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
||||
auto expected_it = expected.begin();
|
||||
auto tmp_it = it;
|
||||
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
||||
++tmp_it;
|
||||
++expected_it;
|
||||
}
|
||||
if (expected_it == expected.end()) {
|
||||
it = tmp_it;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
|
||||
std::smatch match;
|
||||
if (std::regex_match(it, end, match, expected)) {
|
||||
it = match.suffix().first;
|
||||
return match;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
|
||||
while (it != end && std::isspace(*it)) {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||
@@ -509,7 +535,8 @@ static common_chat_msg parse_json_tool_calls(
|
||||
const std::string& input,
|
||||
const std::optional<std::regex> & trigger_opt,
|
||||
const std::regex & function_regex,
|
||||
const std::regex & close_regex) {
|
||||
const std::regex & close_regex,
|
||||
bool allow_raw_python = false) {
|
||||
std::smatch match;
|
||||
|
||||
common_chat_msg result;
|
||||
@@ -540,14 +567,19 @@ static common_chat_msg parse_json_tool_calls(
|
||||
it = rit->suffix().first;
|
||||
|
||||
json arguments;
|
||||
if (!parse_json(it, end, arguments)) {
|
||||
if (parse_json(it, end, arguments)) {
|
||||
if (!std::regex_search(it, end, match, close_regex)) {
|
||||
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
||||
}
|
||||
it = match.suffix().first;
|
||||
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
||||
} else {
|
||||
if (allow_raw_python && name == "python") {
|
||||
result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
|
||||
break;
|
||||
}
|
||||
throw std::runtime_error("Failed to parse json tool call arguments: " + input);
|
||||
}
|
||||
if (!std::regex_search(it, end, match, close_regex)) {
|
||||
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
||||
}
|
||||
it = match.suffix().first;
|
||||
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
||||
}
|
||||
|
||||
if (!result.tool_calls.empty()) {
|
||||
@@ -559,29 +591,29 @@ static common_chat_msg parse_json_tool_calls(
|
||||
return result;
|
||||
}
|
||||
|
||||
static common_chat_tool_call process_tool_call(const json & tool_call) {
|
||||
const auto & arguments = tool_call.at("arguments");
|
||||
return {
|
||||
/* .name = */ tool_call.at("name"),
|
||||
/* .arguments = */ arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* .id = */ tool_call.contains("id") ? tool_call.at("id") : "",
|
||||
};
|
||||
}
|
||||
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
|
||||
auto content_end = input.find(prefix);
|
||||
size_t tc_start = std::string::npos;
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
const auto process_tool_calls = [&](const json & tool_calls) {
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
const auto & arguments = tool_call.at("arguments");
|
||||
result.tool_calls.push_back({
|
||||
tool_call.at("name"),
|
||||
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
tool_call.contains("id") ? tool_call.at("id") : "",
|
||||
});
|
||||
}
|
||||
};
|
||||
if (content_end == std::string::npos) {
|
||||
result.content = input;
|
||||
} else {
|
||||
tc_start = content_end + prefix.size() - rstrip_prefix;
|
||||
result.content = input.substr(0, content_end);
|
||||
auto tool_calls = json::parse(input.substr(tc_start));
|
||||
process_tool_calls(tool_calls);
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
result.tool_calls.emplace_back(process_tool_call(tool_call));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -700,7 +732,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
||||
data.grammar_lazy = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
builder.add_schema("root", schema);
|
||||
}, grammar_options);
|
||||
});
|
||||
|
||||
auto tweaked_messages = common_chat_template::add_system(
|
||||
inputs.messages,
|
||||
@@ -770,8 +802,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
||||
});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
|
||||
data.preserved_tokens = {
|
||||
"[TOOL_CALLS]",
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
||||
return data;
|
||||
@@ -813,14 +848,18 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
|
||||
});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
"<|START_ACTION|>",
|
||||
});
|
||||
data.preserved_tokens = {
|
||||
"<|START_ACTION|>",
|
||||
"<|END_ACTION|>",
|
||||
"<|START_RESPONSE|>",
|
||||
"<|END_RESPONSE|>",
|
||||
"<|START_THINKING|>",
|
||||
"<|END_THINKING|>",
|
||||
"<|END_ACTION|>",
|
||||
};
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
@@ -840,9 +879,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
|
||||
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
|
||||
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
|
||||
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
|
||||
static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
|
||||
static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
|
||||
static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
|
||||
|
||||
std::smatch match;
|
||||
|
||||
@@ -945,23 +984,23 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" space "
|
||||
"( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
|
||||
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
builder.add_schema(name + "-args", parameters) +
|
||||
" \"}\""));
|
||||
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
|
||||
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
||||
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
||||
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"}\" space"));
|
||||
});
|
||||
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
||||
});
|
||||
data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
}, grammar_options);
|
||||
});
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"tools_in_user_message", false},
|
||||
@@ -974,33 +1013,33 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
||||
}
|
||||
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
||||
// TODO: tighten & simplify the parser, don't accept leading text context.
|
||||
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
||||
static const std::regex function_regex(
|
||||
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
|
||||
static const std::regex close_regex("\\}\\s*");
|
||||
static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
|
||||
|
||||
if (with_builtin_tools) {
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, builtin_call_regex)) {
|
||||
auto name = match[1].str();
|
||||
auto raw_args = match[2].str();
|
||||
try {
|
||||
auto name = match[1].str();
|
||||
auto arg_name = match[2].str();
|
||||
auto arg_value_str = match[3].str();
|
||||
auto arg_value = json::parse(arg_value_str);
|
||||
|
||||
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
|
||||
auto it_eq = raw_args.find('=');
|
||||
auto arg_name = raw_args.substr(0, it_eq);
|
||||
auto arg_value_str = raw_args.substr(it_eq + 1);
|
||||
auto arg_value = json::parse(arg_value_str);
|
||||
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = match.prefix().str();
|
||||
msg.tool_calls.push_back({
|
||||
/* .name = */ name,
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
});
|
||||
return msg;
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.tool_calls.push_back({
|
||||
/* .name = */ name,
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
});
|
||||
return msg;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
||||
@@ -1017,10 +1056,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
|
||||
"```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||
"```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"```<|tool▁call▁end|>\""));
|
||||
});
|
||||
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
|
||||
// so we accept common variants (then it's all constrained)
|
||||
@@ -1029,18 +1068,20 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
||||
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
|
||||
"\"<|tool▁calls▁end|>\""
|
||||
" space");
|
||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<|tool▁calls▁begin|>",
|
||||
"<|tool▁call▁begin|>",
|
||||
"<|tool▁sep|>",
|
||||
"<|tool▁calls▁end|",
|
||||
"<|tool▁call▁end|>",
|
||||
"<|tool▁calls▁end|",
|
||||
};
|
||||
}, grammar_options);
|
||||
});
|
||||
}
|
||||
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
|
||||
@@ -1065,34 +1106,42 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
||||
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
|
||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
||||
static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
|
||||
static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
|
||||
std::smatch match;
|
||||
static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
|
||||
if (std::regex_match(input, match, reasoning_content_regex)) {
|
||||
std::string rest;
|
||||
auto rest = match[3].str();
|
||||
auto msg = rest_parser(rest);
|
||||
auto reasoning_content = string_strip(match[2].str());
|
||||
if (extract_reasoning) {
|
||||
msg.reasoning_content = string_strip(match[2].str());
|
||||
} else {
|
||||
msg.content = match[1].str();
|
||||
msg.reasoning_content = reasoning_content;
|
||||
} else if (!reasoning_content.empty()) {
|
||||
std::ostringstream content;
|
||||
content << "<think>" << reasoning_content << "</think>" << msg.content;
|
||||
msg.content = content.str();
|
||||
}
|
||||
rest = match[3].str();
|
||||
return msg;
|
||||
}
|
||||
return rest_parser(input);
|
||||
}
|
||||
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
|
||||
return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
|
||||
static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
||||
static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
|
||||
|
||||
if (std::regex_search(rest, match, tool_calls_regex)) {
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, tool_calls_regex)) {
|
||||
auto tool_calls = match[1].str();
|
||||
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
|
||||
msg.tool_calls = std::move(msg2.tool_calls);
|
||||
} else {
|
||||
msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
|
||||
msg.content = input;
|
||||
}
|
||||
} else {
|
||||
msg.content = input;
|
||||
}
|
||||
return msg;
|
||||
return msg;
|
||||
});
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
@@ -1129,8 +1178,11 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
|
||||
});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["});
|
||||
data.preserved_tokens = {
|
||||
" functools[",
|
||||
};
|
||||
data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
@@ -1158,11 +1210,28 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||
data.grammar_triggers.push_back({name, /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
regex_escape(name + "\n"),
|
||||
});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
|
||||
});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
regex_escape(">>>" + name + "\n"),
|
||||
});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
">>>assistant<|end_header_id|>\n" + name,
|
||||
});
|
||||
});
|
||||
data.preserved_tokens = {
|
||||
"<|end_header_id|>",
|
||||
};
|
||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||
if (inputs.parallel_tool_calls) {
|
||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||
@@ -1171,34 +1240,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
builder.add_rule("root", first_rule);
|
||||
}
|
||||
|
||||
}, grammar_options);
|
||||
});
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
||||
auto expected_it = expected.begin();
|
||||
auto tmp_it = it;
|
||||
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
||||
++tmp_it;
|
||||
++expected_it;
|
||||
}
|
||||
if (expected_it == expected.end()) {
|
||||
it = tmp_it;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
|
||||
static const std::regex close_regex(R"($|(?=>>>))");
|
||||
|
||||
std::string content;
|
||||
auto it = input.begin();
|
||||
const auto end = input.end();
|
||||
|
||||
if (consume(it, end, "all\n")) {
|
||||
if (parse_literal(it, end, "all\n")) {
|
||||
std::smatch match;
|
||||
if (std::regex_search(it, end, match, function_regex)) {
|
||||
auto fun_it = match.prefix().second;
|
||||
@@ -1213,7 +1268,7 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
||||
}
|
||||
// TODO: tighten & simplify.
|
||||
try {
|
||||
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
|
||||
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
|
||||
res.content = content + res.content;
|
||||
return res;
|
||||
} catch (const std::exception & e) {
|
||||
@@ -1266,12 +1321,13 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
||||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||
});
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
// TODO: if (has_raw_python)
|
||||
@@ -1280,7 +1336,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
||||
}
|
||||
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||
static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, python_tag_regex)) {
|
||||
auto code = match[1].str();
|
||||
@@ -1294,8 +1350,8 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
|
||||
});
|
||||
return msg;
|
||||
}
|
||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||
static std::regex close_regex(R"(</function>)");
|
||||
static const std::regex function_regex(R"(<function=(\w+)>)");
|
||||
static const std::regex close_regex(R"(</function>)");
|
||||
// TODO: tighten & simplify.
|
||||
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
||||
}
|
||||
@@ -1306,6 +1362,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
std::vector<std::string> tool_call_alts;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
@@ -1319,68 +1376,187 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
||||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
}));
|
||||
tool_call_alts.push_back(builder.add_rule(
|
||||
name + "-function-tag",
|
||||
"\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
|
||||
builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"</function>\" space"));
|
||||
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
"<function=" + name + ">",
|
||||
});
|
||||
auto escaped_name = regex_escape(name);
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
"<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
|
||||
});
|
||||
});
|
||||
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
|
||||
auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
|
||||
std::vector<std::string> alt_tags {
|
||||
any_tool_call,
|
||||
"\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
|
||||
// The rest is just to accommodate common "good bad" outputs.
|
||||
"\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
|
||||
"\"<response>\" space " + any_tool_call + " \"</response>\"",
|
||||
"\"<tools>\" space " + any_tool_call + " \"</tools>\"",
|
||||
"\"<json>\" space " + any_tool_call + " \"</json>\"",
|
||||
"\"<xml>\" space " + any_tool_call + " \"</xml>\"",
|
||||
"\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
|
||||
};
|
||||
auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
|
||||
tool_call_alts.push_back(wrappable_tool_call);
|
||||
tool_call_alts.push_back(
|
||||
"( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
|
||||
data.preserved_tokens = { "</tool_call>" };
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
|
||||
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
|
||||
});
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<function",
|
||||
"<tools>",
|
||||
"</tools>",
|
||||
"<response>",
|
||||
"</response>",
|
||||
"<function_call>",
|
||||
"</function_call>",
|
||||
"<json>",
|
||||
"</json>",
|
||||
"<JSON>",
|
||||
"</JSON>",
|
||||
"```",
|
||||
"```json",
|
||||
"```xml",
|
||||
};
|
||||
});
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
||||
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
|
||||
try {
|
||||
std::regex start_pattern(R"([\n\s]*<tool_call>)");
|
||||
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
|
||||
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
|
||||
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
|
||||
return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
|
||||
static const std::regex open_regex(
|
||||
"(?:"
|
||||
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
||||
"(<tool_call>" // match 2 (open_tag)
|
||||
"|<function_call>"
|
||||
"|<tool>"
|
||||
"|<tools>"
|
||||
"|<response>"
|
||||
"|<json>"
|
||||
"|<xml>"
|
||||
"|<JSON>"
|
||||
")?"
|
||||
"(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest)
|
||||
")"
|
||||
"|"
|
||||
"(?:<function=([^>]+)>" // match 4 (function name)
|
||||
"|<function name=\"([^\"]+)\">)" // match 5 (function name again)
|
||||
"([\\s\\S]*)" // match 6 (function arguments + rest)})"
|
||||
);
|
||||
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
try {
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
|
||||
auto end = input.end();
|
||||
std::sregex_iterator rend;
|
||||
std::sregex_iterator rit(input.begin(), end, start_pattern);
|
||||
if (rit == rend) {
|
||||
std::string::const_iterator it = input.begin();
|
||||
const std::string::const_iterator end = input.end();
|
||||
std::smatch match;
|
||||
|
||||
while (it != end) {
|
||||
if (std::regex_search(it, end, match, open_regex)) {
|
||||
// Add content before the match
|
||||
msg.content += std::string(it, match[0].first);
|
||||
|
||||
auto block_start = match[1].str();
|
||||
std::string block_end = block_start.empty() ? "" : "```";
|
||||
|
||||
auto open_tag = match[2].str();
|
||||
std::string close_tag;
|
||||
|
||||
if (match[3].matched) {
|
||||
close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
|
||||
auto json_it = match[3].first;
|
||||
json tool_call;
|
||||
if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
|
||||
|
||||
msg.tool_calls.emplace_back(process_tool_call(tool_call));
|
||||
it = json_it; // Move iterator past parsed JSON
|
||||
|
||||
// Handle close tags
|
||||
consume_spaces(it, end);
|
||||
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
||||
throw std::runtime_error("Failed to parse closing tag");
|
||||
}
|
||||
consume_spaces(it, end);
|
||||
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
||||
throw std::runtime_error("Failed to parse block end");
|
||||
}
|
||||
consume_spaces(it, end);
|
||||
} else {
|
||||
// Not a valid tool call, treat as content
|
||||
msg.content += std::string(match[0].first, match[0].second);
|
||||
it = match[0].second;
|
||||
}
|
||||
} else {
|
||||
auto function_name = match[4].str();
|
||||
if (function_name.empty()) {
|
||||
function_name = match[5].str();
|
||||
}
|
||||
GGML_ASSERT(!function_name.empty());
|
||||
|
||||
close_tag = "</function>";
|
||||
// Start parsing from after the opening tags
|
||||
auto json_it = match[6].first;
|
||||
json arguments;
|
||||
if (parse_json(json_it, end, arguments)) {
|
||||
msg.tool_calls.emplace_back(process_tool_call({
|
||||
{"name", function_name},
|
||||
{"arguments", arguments},
|
||||
}));
|
||||
it = json_it; // Move iterator past parsed JSON
|
||||
|
||||
// Handle close tags
|
||||
consume_spaces(it, end);
|
||||
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
||||
throw std::runtime_error("Failed to parse closing tag");
|
||||
}
|
||||
consume_spaces(it, end);
|
||||
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
||||
throw std::runtime_error("Failed to parse block end");
|
||||
}
|
||||
consume_spaces(it, end);
|
||||
} else {
|
||||
// Not a valid tool call, treat as content
|
||||
msg.content += std::string(match[0].first, match[0].second);
|
||||
it = match[0].second;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Add remaining content
|
||||
msg.content += std::string(it, end);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return msg;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = input;
|
||||
return msg;
|
||||
}
|
||||
|
||||
msg.content = rit->prefix();
|
||||
|
||||
auto it = rit->suffix().first;
|
||||
while (it != end) {
|
||||
json call;
|
||||
if (!parse_json(it, end, call)) {
|
||||
throw std::runtime_error("Failed to parse json tool call");
|
||||
}
|
||||
const auto & arguments = call.at("arguments");
|
||||
msg.tool_calls.push_back({
|
||||
call.at("name"),
|
||||
arguments.dump(),
|
||||
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
rit = {it, end, middle_pattern};
|
||||
if (rit != rend) {
|
||||
it = rit->suffix().first;
|
||||
} else {
|
||||
rit = {it, end, end_pattern};
|
||||
if (rit == rend) {
|
||||
throw std::runtime_error("Malformed input, missing </tool_call>");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return msg;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = input;
|
||||
return msg;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
@@ -1445,6 +1621,11 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_command_r7b(tmpl, params);
|
||||
}
|
||||
|
||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
@@ -1466,11 +1647,6 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||
if (src.find("<tool_call>") != std::string::npos) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||
}
|
||||
|
||||
// Functionary v3.1 (w/ tools)
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
@@ -1588,7 +1764,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
||||
return common_chat_parse_functionary_v3_1_llama_3_1(input);
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
|
||||
return common_chat_parse_hermes_2_pro(input);
|
||||
return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
|
||||
return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
||||
return common_chat_parse_firefunction_v2(input);
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
||||
|
||||
@@ -53,6 +53,7 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
|
||||
|
||||
|
||||
+27
-1
@@ -10,7 +10,6 @@
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -483,6 +482,11 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$0");
|
||||
}
|
||||
|
||||
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
|
||||
std::ostringstream result;
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
@@ -2026,3 +2030,25 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_grammar_trigger::to_json() const {
|
||||
json out {
|
||||
{"type", (int) type},
|
||||
{"value", value},
|
||||
};
|
||||
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
||||
out["token"] = (int) token;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
|
||||
common_grammar_trigger out;
|
||||
out.type = (common_grammar_trigger_type) in.at("type").get<int>();
|
||||
out.value = in.at("value").get<std::string>();
|
||||
if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
||||
out.token = (llama_token) in.at("token").get<int>();
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
+20
-9
@@ -110,9 +110,21 @@ enum common_conversation_mode {
|
||||
COMMON_CONVERSATION_MODE_AUTO = 2,
|
||||
};
|
||||
|
||||
enum common_grammar_trigger_type {
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
};
|
||||
|
||||
struct common_grammar_trigger {
|
||||
std::string word;
|
||||
bool at_start;
|
||||
common_grammar_trigger_type type;
|
||||
std::string value;
|
||||
llama_token token = LLAMA_TOKEN_NULL;
|
||||
|
||||
// T can only be nlohmann::ordered_json
|
||||
template <class T> T to_json() const;
|
||||
template <class T> static common_grammar_trigger from_json(const T & in);
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
@@ -163,8 +175,7 @@ struct common_params_sampling {
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
|
||||
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
|
||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
@@ -396,8 +407,6 @@ struct common_params {
|
||||
int32_t i_pos = -1; // position of the passkey in the junk text
|
||||
|
||||
// imatrix params
|
||||
std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
|
||||
|
||||
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
|
||||
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
|
||||
int32_t i_chunk = 0; // start processing from this chunk
|
||||
@@ -409,16 +418,16 @@ struct common_params {
|
||||
int n_pca_batch = 100;
|
||||
int n_pca_iterations = 1000;
|
||||
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
|
||||
std::string cvector_outfile = "control_vector.gguf";
|
||||
std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
|
||||
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
|
||||
|
||||
bool spm_infill = false; // suffix/prefix/middle pattern for infill
|
||||
|
||||
std::string lora_outfile = "ggml-lora-merged-f16.gguf";
|
||||
|
||||
// batched-bench params
|
||||
bool batched_bench_output_jsonl = false;
|
||||
|
||||
// common params
|
||||
std::string out_file; // output filename for all example programs
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@@ -458,6 +467,8 @@ std::string string_repeat(const std::string & str, size_t n);
|
||||
|
||||
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
|
||||
|
||||
std::string regex_escape(const std::string & s);
|
||||
|
||||
template<class T>
|
||||
static std::vector<T> string_split(const std::string & str, char delim) {
|
||||
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
|
||||
|
||||
@@ -264,7 +264,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
||||
throw std::runtime_error("At least one of min_value or max_value must be set");
|
||||
}
|
||||
|
||||
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
|
||||
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
|
||||
|
||||
struct BuiltinRule {
|
||||
std::string content;
|
||||
@@ -764,11 +764,10 @@ private:
|
||||
public:
|
||||
SchemaConverter(
|
||||
const std::function<json(const std::string &)> & fetch_json,
|
||||
bool dotall,
|
||||
bool compact_spaces)
|
||||
bool dotall)
|
||||
: _fetch_json(fetch_json), _dotall(dotall)
|
||||
{
|
||||
_rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
|
||||
_rules["space"] = SPACE_RULE;
|
||||
}
|
||||
|
||||
void resolve_refs(json & schema, const std::string & url) {
|
||||
@@ -1007,7 +1006,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
|
||||
common_grammar_builder builder {
|
||||
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
|
||||
return converter._add_rule(name, rule);
|
||||
|
||||
@@ -16,7 +16,6 @@ struct common_grammar_builder {
|
||||
|
||||
struct common_grammar_options {
|
||||
bool dotall = false;
|
||||
bool compact_spaces = false;
|
||||
};
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||
|
||||
+37
-5
@@ -1378,13 +1378,27 @@ struct ArgumentsExpression {
|
||||
}
|
||||
};
|
||||
|
||||
static std::string strip(const std::string & s) {
|
||||
auto start = s.find_first_not_of(" \t\n\r");
|
||||
static std::string strip(const std::string & s, const std::string & chars = "", bool left = true, bool right = true) {
|
||||
auto charset = chars.empty() ? " \t\n\r" : chars;
|
||||
auto start = left ? s.find_first_not_of(charset) : 0;
|
||||
if (start == std::string::npos) return "";
|
||||
auto end = s.find_last_not_of(" \t\n\r");
|
||||
auto end = right ? s.find_last_not_of(charset) : s.size() - 1;
|
||||
return s.substr(start, end - start + 1);
|
||||
}
|
||||
|
||||
static std::vector<std::string> split(const std::string & s, const std::string & sep) {
|
||||
std::vector<std::string> result;
|
||||
size_t start = 0;
|
||||
size_t end = s.find(sep);
|
||||
while (end != std::string::npos) {
|
||||
result.push_back(s.substr(start, end - start));
|
||||
start = end + sep.length();
|
||||
end = s.find(sep, start);
|
||||
}
|
||||
result.push_back(s.substr(start));
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string capitalize(const std::string & s) {
|
||||
if (s.empty()) return s;
|
||||
auto result = s;
|
||||
@@ -1467,8 +1481,26 @@ public:
|
||||
} else if (obj.is_string()) {
|
||||
auto str = obj.get<std::string>();
|
||||
if (method->get_name() == "strip") {
|
||||
vargs.expectArgs("strip method", {0, 0}, {0, 0});
|
||||
return Value(strip(str));
|
||||
vargs.expectArgs("strip method", {0, 1}, {0, 0});
|
||||
auto chars = vargs.args.empty() ? "" : vargs.args[0].get<std::string>();
|
||||
return Value(strip(str, chars));
|
||||
} else if (method->get_name() == "lstrip") {
|
||||
vargs.expectArgs("lstrip method", {0, 1}, {0, 0});
|
||||
auto chars = vargs.args.empty() ? "" : vargs.args[0].get<std::string>();
|
||||
return Value(strip(str, chars, /* left= */ true, /* right= */ false));
|
||||
} else if (method->get_name() == "rstrip") {
|
||||
vargs.expectArgs("rstrip method", {0, 1}, {0, 0});
|
||||
auto chars = vargs.args.empty() ? "" : vargs.args[0].get<std::string>();
|
||||
return Value(strip(str, chars, /* left= */ false, /* right= */ true));
|
||||
} else if (method->get_name() == "split") {
|
||||
vargs.expectArgs("split method", {1, 1}, {0, 0});
|
||||
auto sep = vargs.args[0].get<std::string>();
|
||||
auto parts = split(str, sep);
|
||||
Value result = Value::array();
|
||||
for (const auto& part : parts) {
|
||||
result.push_back(Value(part));
|
||||
}
|
||||
return result;
|
||||
} else if (method->get_name() == "capitalize") {
|
||||
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
|
||||
return Value(capitalize(str));
|
||||
|
||||
+44
-7
@@ -160,16 +160,53 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<const char *> trigger_words;
|
||||
trigger_words.reserve(params.grammar_trigger_words.size());
|
||||
for (const auto & str : params.grammar_trigger_words) {
|
||||
trigger_words.push_back(str.word.c_str());
|
||||
std::vector<std::string> patterns_at_start;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto & trigger : params.grammar_triggers) {
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
patterns_anywhere.push_back(regex_escape(word));
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
{
|
||||
const auto token = trigger.token;
|
||||
trigger_tokens.push_back(token);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown trigger type");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> trigger_patterns;
|
||||
if (!patterns_at_start.empty()) {
|
||||
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
if (!patterns_anywhere.empty()) {
|
||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
|
||||
std::vector<const char *> trigger_patterns_c;
|
||||
trigger_patterns_c.reserve(trigger_patterns.size());
|
||||
for (const auto & regex : trigger_patterns) {
|
||||
trigger_patterns_c.push_back(regex.c_str());
|
||||
}
|
||||
|
||||
grmr = params.grammar_lazy
|
||||
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
|
||||
trigger_words.data(), trigger_words.size(),
|
||||
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
|
||||
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size())
|
||||
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
}
|
||||
|
||||
|
||||
+42
-12
@@ -197,29 +197,53 @@ The following compilation options are also available to tweak performance:
|
||||
|
||||
## MUSA
|
||||
|
||||
This provides GPU acceleration using the MUSA cores of your Moore Threads MTT GPU. Make sure to have the MUSA SDK installed. You can download it from here: [MUSA SDK](https://developer.mthreads.com/sdk/download/musa).
|
||||
This provides GPU acceleration using a Moore Threads GPU. Make sure to have the [MUSA SDK](https://developer.mthreads.com/musa/musa-sdk) installed.
|
||||
|
||||
- Using `CMake`:
|
||||
#### Download directly from Moore Threads
|
||||
|
||||
```bash
|
||||
cmake -B build -DGGML_MUSA=ON
|
||||
cmake --build build --config Release
|
||||
You may find the official downloads here: [Moore Threads developer site](https://developer.mthreads.com/sdk/download/musa).
|
||||
|
||||
### Compilation
|
||||
|
||||
```bash
|
||||
cmake -B build -DGGML_MUSA=ON
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
#### Override Compute Capability Specifications
|
||||
|
||||
By default, all supported compute capabilities are enabled. To customize this behavior, you can specify the `MUSA_ARCHITECTURES` option in the CMake command:
|
||||
|
||||
```bash
|
||||
cmake -B build -DGGML_MUSA=ON -DMUSA_ARCHITECTURES="21"
|
||||
```
|
||||
|
||||
This configuration enables only compute capability `2.1` (MTT S80) during compilation, which can help reduce compilation time.
|
||||
|
||||
#### Compilation options
|
||||
|
||||
Most of the compilation options available for CUDA should also be available for MUSA, though they haven't been thoroughly tested yet.
|
||||
|
||||
- For static builds, add `-DBUILD_SHARED_LIBS=OFF` and `-DCMAKE_POSITION_INDEPENDENT_CODE=ON`:
|
||||
```
|
||||
|
||||
For static build:
|
||||
|
||||
```bash
|
||||
cmake -B build -DGGML_MUSA=ON \
|
||||
-DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
The environment variable [`MUSA_VISIBLE_DEVICES`](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) can be used to specify which GPU(s) will be used.
|
||||
### Runtime MUSA environmental variables
|
||||
|
||||
You may set the [musa environmental variables](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) at runtime.
|
||||
|
||||
```bash
|
||||
# Use `MUSA_VISIBLE_DEVICES` to hide the first compute device.
|
||||
MUSA_VISIBLE_DEVICES="-0" ./build/bin/llama-server --model /srv/models/llama.gguf
|
||||
```
|
||||
|
||||
### Unified Memory
|
||||
|
||||
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted.
|
||||
|
||||
Most of the compilation options available for CUDA should also be available for MUSA, though they haven't been thoroughly tested yet.
|
||||
|
||||
## HIP
|
||||
|
||||
This provides GPU acceleration on HIP-supported AMD GPUs.
|
||||
@@ -235,6 +259,12 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||
On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DGGML_HIP_UMA=ON`.
|
||||
However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs).
|
||||
|
||||
To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system.
|
||||
|
||||
The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager.
|
||||
|
||||
As an alternative, you can manually install the library by cloning it from the official [GitHub repository](https://github.com/ROCm/rocWMMA), checkout the corresponding version tag (e.g. `rocm-6.2.4`) and set `-DCMAKE_CXX_FLAGS="-I<path/to/rocwmma>/library/include/"` in CMake. This also works under Windows despite not officially supported by AMD.
|
||||
|
||||
Note that if you get the following error:
|
||||
```
|
||||
clang: error: cannot find ROCm device library; provide its path via '--rocm-path' or '--rocm-device-lib-path', or pass '-nogpulib' to build without ROCm device library
|
||||
|
||||
@@ -287,30 +287,32 @@ Here are some models known to work (w/ chat template override when needed):
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
|
||||
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
|
||||
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
|
||||
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
|
||||
|
||||
# Native support for DeepSeek R1 works best w/ our own template (official template buggy)
|
||||
# Native support for DeepSeek R1 works best w/ our template override (official template is buggy, although we do work around it)
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \
|
||||
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
|
||||
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \
|
||||
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
|
||||
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
|
||||
|
||||
# Native support requires the right template for these GGUFs:
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
|
||||
--chat-template-file models/templates/meetkai-functionary-medium-v3.2.jinja
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
|
||||
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
|
||||
--chat-template-file models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
|
||||
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
|
||||
--chat-template-file models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
|
||||
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
|
||||
--chat-template-file models/templates/fireworks-ai-llama-3-firefunction-v2.jinja
|
||||
|
||||
llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
|
||||
--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
|
||||
--chat-template-file models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja
|
||||
|
||||
# Generic format support
|
||||
llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
|
||||
@@ -318,6 +320,8 @@ llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
|
||||
llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
|
||||
```
|
||||
|
||||
To get the official template from original HuggingFace repos, you can use [scripts/get_chat_template.py](../scripts/get_chat_template.py) (see examples invocations in [models/templates/README.md](../models/templates/README.md))
|
||||
|
||||
> [!TIP]
|
||||
> If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills)
|
||||
|
||||
|
||||
@@ -394,6 +394,8 @@ static int prepare_entries(common_params & params, train_context & ctx_train) {
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.out_file = "control_vector.gguf";
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -498,7 +500,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// write output vectors to gguf
|
||||
export_gguf(ctx_train.v_final, params.cvector_outfile, model_hint);
|
||||
export_gguf(ctx_train.v_final, params.out_file, model_hint);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
|
||||
@@ -413,20 +413,22 @@ static void print_usage(int, char ** argv) {
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.out_file = "ggml-lora-merged-f16.gguf";
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
g_verbose = (params.verbosity > 1);
|
||||
try {
|
||||
lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.cpuparams.n_threads);
|
||||
lora_merge_ctx ctx(params.model, params.lora_adapters, params.out_file, params.cpuparams.n_threads);
|
||||
ctx.run_merge();
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s\n", err.what());
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
printf("done, output file is %s\n", params.lora_outfile.c_str());
|
||||
printf("done, output file is %s\n", params.out_file.c_str());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -206,9 +206,6 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
|
||||
|
||||
void IMatrixCollector::save_imatrix(int ncall) const {
|
||||
auto fname = m_params.out_file;
|
||||
if (fname.empty()) {
|
||||
fname = "imatrix.dat";
|
||||
}
|
||||
|
||||
if (ncall > 0) {
|
||||
fname += ".at_";
|
||||
@@ -583,6 +580,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.out_file = "imatrix.dat" ;
|
||||
|
||||
params.n_ctx = 512;
|
||||
params.logits_all = true;
|
||||
params.escape = false;
|
||||
|
||||
@@ -195,7 +195,7 @@ class BuiltinRule:
|
||||
self.deps = deps or []
|
||||
|
||||
# Constraining spaces to prevent model "running away".
|
||||
SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'
|
||||
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||
|
||||
@@ -361,7 +361,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
||||
const auto tokens_list = common_tokenize(context, text, true, parse_special);
|
||||
|
||||
auto n_ctx = llama_n_ctx(context);
|
||||
auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
|
||||
auto n_kv_req = tokens_list.size() + n_len;
|
||||
|
||||
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
|
||||
|
||||
|
||||
@@ -5,13 +5,25 @@ Currently, this readme only supports minicpm-omni's image capabilities, and we w
|
||||
|
||||
Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from huggingface to "MiniCPM-o-2_6" folder.
|
||||
|
||||
|
||||
### Build llama.cpp
|
||||
Readme modification time: 20250206
|
||||
|
||||
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
|
||||
|
||||
Clone llama.cpp:
|
||||
```bash
|
||||
git clone git@github.com:OpenBMB/llama.cpp.git
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
cd llama.cpp
|
||||
git checkout minicpm-omni
|
||||
```
|
||||
|
||||
Build llama.cpp using `CMake`:
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
|
||||
### Usage of MiniCPM-o 2.6
|
||||
|
||||
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us)
|
||||
@@ -22,25 +34,15 @@ python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-
|
||||
python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model
|
||||
|
||||
# quantize int4 version
|
||||
./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
./build/bin/llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
```
|
||||
|
||||
Build llama.cpp using `CMake`:
|
||||
https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md
|
||||
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
Inference on Linux or Mac
|
||||
```
|
||||
```bash
|
||||
# run f16 version
|
||||
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run quantized int4 version
|
||||
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# or run in interactive mode
|
||||
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
```
|
||||
|
||||
@@ -4,13 +4,26 @@
|
||||
|
||||
Download [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5) PyTorch model from huggingface to "MiniCPM-Llama3-V-2_5" folder.
|
||||
|
||||
|
||||
### Build llama.cpp
|
||||
Readme modification time: 20250206
|
||||
|
||||
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
|
||||
|
||||
Clone llama.cpp:
|
||||
```bash
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
```
|
||||
|
||||
### Usage
|
||||
Build llama.cpp using `CMake`:
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
|
||||
### Usage of MiniCPM-Llama3-V 2.5
|
||||
|
||||
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) by us)
|
||||
|
||||
@@ -20,80 +33,15 @@ python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-
|
||||
python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model
|
||||
|
||||
# quantize int4 version
|
||||
./llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
./build/bin/llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
```
|
||||
|
||||
Build for Linux or Mac
|
||||
|
||||
```bash
|
||||
make
|
||||
make llama-minicpmv-cli
|
||||
```
|
||||
|
||||
Inference on Linux or Mac
|
||||
```
|
||||
```bash
|
||||
# run f16 version
|
||||
./llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run quantized int4 version
|
||||
./llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# or run in interactive mode
|
||||
./llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
|
||||
```
|
||||
|
||||
### Android
|
||||
|
||||
#### Build on Android device using Termux
|
||||
We found that build on Android device would bring better runtime performance, so we recommend to build on device.
|
||||
|
||||
[Termux](https://github.com/termux/termux-app#installation) is a terminal app on Android device (no root required).
|
||||
|
||||
Install tools in Termux:
|
||||
```
|
||||
apt update && apt upgrade -y
|
||||
apt install git make cmake
|
||||
```
|
||||
|
||||
It's recommended to move your model inside the `~/` directory for best performance:
|
||||
```
|
||||
cd storage/downloads
|
||||
mv model.gguf ~/
|
||||
```
|
||||
|
||||
#### Building the Project using Android NDK
|
||||
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.
|
||||
|
||||
Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
|
||||
|
||||
```bash
|
||||
mkdir build-android
|
||||
cd build-android
|
||||
export NDK=/your_ndk_path
|
||||
cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
|
||||
make
|
||||
```
|
||||
|
||||
Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).
|
||||
|
||||
Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:
|
||||
|
||||
(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
|
||||
```
|
||||
$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/
|
||||
$cd /data/data/com.termux/files/home/bin
|
||||
$chmod +x ./*
|
||||
```
|
||||
|
||||
Download models and push them to `/sdcard/llama.cpp/`, then move it to `/data/data/com.termux/files/home/model/`
|
||||
|
||||
```
|
||||
$mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/
|
||||
$mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/
|
||||
```
|
||||
|
||||
Now, you can start chatting:
|
||||
```
|
||||
$cd /data/data/com.termux/files/home/bin
|
||||
$./llama-minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
```
|
||||
|
||||
@@ -4,13 +4,25 @@
|
||||
|
||||
Download [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch model from huggingface to "MiniCPM-V-2_6" folder.
|
||||
|
||||
|
||||
### Build llama.cpp
|
||||
Readme modification time: 20250206
|
||||
|
||||
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
|
||||
|
||||
Clone llama.cpp:
|
||||
```bash
|
||||
git clone git@github.com:OpenBMB/llama.cpp.git
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
cd llama.cpp
|
||||
git checkout minicpmv-main
|
||||
```
|
||||
|
||||
Build llama.cpp using `CMake`:
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
|
||||
### Usage of MiniCPM-V 2.6
|
||||
|
||||
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) by us)
|
||||
@@ -21,87 +33,15 @@ python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-
|
||||
python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model
|
||||
|
||||
# quantize int4 version
|
||||
./llama-quantize ../MiniCPM-V-2_6/model/ggml-model-f16.gguf ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
./build/bin/llama-quantize ../MiniCPM-V-2_6/model/ggml-model-f16.gguf ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
```
|
||||
|
||||
Build for Linux or Mac
|
||||
|
||||
```bash
|
||||
make
|
||||
make llama-minicpmv-cli
|
||||
```
|
||||
|
||||
Inference on Linux or Mac
|
||||
```
|
||||
```bash
|
||||
# run f16 version
|
||||
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run quantized int4 version
|
||||
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# or run in interactive mode
|
||||
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
|
||||
```
|
||||
|
||||
### Video
|
||||
Install FFmpeg
|
||||
```
|
||||
brew install ffmpeg
|
||||
brew install pkg-config
|
||||
```
|
||||
|
||||
### Android
|
||||
|
||||
#### Build on Android device using Termux
|
||||
We found that build on Android device would bring better runtime performance, so we recommend to build on device.
|
||||
|
||||
[Termux](https://github.com/termux/termux-app#installation) is a terminal app on Android device (no root required).
|
||||
|
||||
Install tools in Termux:
|
||||
```
|
||||
apt update && apt upgrade -y
|
||||
apt install git make cmake
|
||||
```
|
||||
|
||||
It's recommended to move your model inside the `~/` directory for best performance:
|
||||
```
|
||||
cd storage/downloads
|
||||
mv model.gguf ~/
|
||||
```
|
||||
|
||||
#### Building the Project using Android NDK
|
||||
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.
|
||||
|
||||
Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
|
||||
|
||||
```bash
|
||||
mkdir build-android
|
||||
cd build-android
|
||||
export NDK=/your_ndk_path
|
||||
cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
|
||||
make
|
||||
```
|
||||
|
||||
Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).
|
||||
|
||||
Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:
|
||||
|
||||
(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
|
||||
```
|
||||
$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/
|
||||
$cd /data/data/com.termux/files/home/bin
|
||||
$chmod +x ./*
|
||||
```
|
||||
|
||||
Download models and push them to `/sdcard/llama.cpp/`, then move it to `/data/data/com.termux/files/home/model/`
|
||||
|
||||
```
|
||||
$mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/
|
||||
$mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/
|
||||
```
|
||||
|
||||
Now, you can start chatting:
|
||||
```
|
||||
$cd /data/data/com.termux/files/home/bin
|
||||
$./llama-minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
./build/bin/llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
```
|
||||
|
||||
+76
-74
@@ -4,31 +4,12 @@
|
||||
// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
|
||||
#include "clip.h"
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "gguf.h"
|
||||
|
||||
//#ifdef GGML_USE_CUDA
|
||||
//#include "ggml-cuda.h"
|
||||
//#endif
|
||||
//
|
||||
//#ifdef GGML_USE_SYCL
|
||||
//#include "ggml-sycl.h"
|
||||
//#endif
|
||||
//
|
||||
//#ifdef GGML_USE_METAL
|
||||
//#include "ggml-metal.h"
|
||||
//#endif
|
||||
//
|
||||
//#ifdef GGML_USE_CANN
|
||||
//#include "ggml-cann.h"
|
||||
//#endif
|
||||
//
|
||||
//#ifdef GGML_USE_VULKAN
|
||||
//#include "ggml-vulkan.h"
|
||||
//#endif
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
@@ -600,18 +581,54 @@ struct clip_ctx {
|
||||
bool has_post_norm = false;
|
||||
bool has_patch_bias = false;
|
||||
|
||||
struct gguf_context * ctx_gguf;
|
||||
struct ggml_context * ctx_data;
|
||||
struct gguf_context * ctx_gguf = nullptr;
|
||||
struct ggml_context * ctx_data = nullptr;
|
||||
|
||||
std::vector<uint8_t> buf_compute_meta;
|
||||
|
||||
// memory buffers to evaluate the model
|
||||
ggml_backend_buffer_t params_buffer = NULL;
|
||||
std::vector<ggml_backend_t> backend_ptrs;
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
|
||||
ggml_backend_t backend = NULL;
|
||||
ggml_gallocr_t compute_alloc = NULL;
|
||||
ggml_backend_t backend = nullptr;
|
||||
ggml_backend_t backend_cpu = nullptr;
|
||||
ggml_backend_buffer_t buf = nullptr;
|
||||
|
||||
ggml_backend_sched_ptr sched;
|
||||
|
||||
struct clip_image_size * load_image_size;
|
||||
|
||||
clip_ctx(clip_context_params & ctx_params) {
|
||||
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
backend = ctx_params.use_gpu
|
||||
? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr)
|
||||
: nullptr;
|
||||
|
||||
if (backend) {
|
||||
LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
backend_ptrs.push_back(backend);
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
|
||||
} else {
|
||||
backend = backend_cpu;
|
||||
LOG_INF("%s: CLIP using CPU backend\n", __func__);
|
||||
}
|
||||
|
||||
backend_ptrs.push_back(backend_cpu);
|
||||
backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
|
||||
|
||||
sched.reset(
|
||||
ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false)
|
||||
);
|
||||
}
|
||||
|
||||
~clip_ctx() {
|
||||
ggml_free(ctx_data);
|
||||
gguf_free(ctx_gguf);
|
||||
ggml_backend_buffer_free(buf);
|
||||
ggml_backend_free(backend);
|
||||
if (backend_cpu != backend) {
|
||||
ggml_backend_free(backend_cpu);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
|
||||
@@ -1184,6 +1201,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
|
||||
// read and create ggml_context containing the tensors and their data
|
||||
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
return clip_init(fname, clip_context_params{
|
||||
/* use_gpu */ true,
|
||||
/* verbosity */ verbosity,
|
||||
});
|
||||
}
|
||||
|
||||
struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) {
|
||||
int verbosity = ctx_params.verbosity;
|
||||
struct ggml_context * meta = NULL;
|
||||
|
||||
struct gguf_init_params params = {
|
||||
@@ -1277,7 +1302,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
}
|
||||
}
|
||||
|
||||
clip_ctx * new_clip = new clip_ctx{};
|
||||
clip_ctx * new_clip = new clip_ctx(ctx_params);
|
||||
|
||||
// update projector type
|
||||
{
|
||||
@@ -1296,36 +1321,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
}
|
||||
}
|
||||
|
||||
//#ifdef GGML_USE_CUDA
|
||||
// new_clip->backend = ggml_backend_cuda_init(0);
|
||||
// LOG_INF("%s: CLIP using CUDA backend\n", __func__);
|
||||
//#endif
|
||||
//
|
||||
//#ifdef GGML_USE_METAL
|
||||
// new_clip->backend = ggml_backend_metal_init();
|
||||
// LOG_INF("%s: CLIP using Metal backend\n", __func__);
|
||||
//#endif
|
||||
//
|
||||
//#ifdef GGML_USE_CANN
|
||||
// new_clip->backend = ggml_backend_cann_init(0);
|
||||
// LOG_INF("%s: CLIP using CANN backend\n", __func__);
|
||||
//#endif
|
||||
//
|
||||
//#ifdef GGML_USE_VULKAN
|
||||
// new_clip->backend = ggml_backend_vk_init(0);
|
||||
// LOG_INF("%s: CLIP using Vulkan backend\n", __func__);
|
||||
//#endif
|
||||
//
|
||||
//#ifdef GGML_USE_SYCL
|
||||
// new_clip->backend = ggml_backend_sycl_init(0);
|
||||
// LOG_INF("%s: CLIP using SYCL backend\n", __func__);
|
||||
//#endif
|
||||
|
||||
if (!new_clip->backend) {
|
||||
new_clip->backend = ggml_backend_cpu_init();
|
||||
LOG_INF("%s: CLIP using CPU backend\n", __func__);
|
||||
}
|
||||
|
||||
// model size and capabilities
|
||||
{
|
||||
int idx = get_key_idx(ctx, KEY_HAS_TEXT_ENC);
|
||||
@@ -1378,6 +1373,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
|
||||
LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
|
||||
LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector);
|
||||
LOG_INF("%s: minicpmv_version: %d\n", __func__, new_clip->minicpmv_version);
|
||||
LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector);
|
||||
LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
|
||||
LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
|
||||
@@ -1420,7 +1416,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
}
|
||||
|
||||
// alloc memory and offload data
|
||||
new_clip->params_buffer = ggml_backend_alloc_ctx_tensors(new_clip->ctx_data, new_clip->backend);
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(new_clip->backend);
|
||||
new_clip->buf = ggml_backend_alloc_ctx_tensors_from_buft(new_clip->ctx_data, buft);
|
||||
ggml_backend_buffer_set_usage(new_clip->buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
for (int i = 0; i < n_tensors; ++i) {
|
||||
const char * name = gguf_get_tensor_name(ctx, i);
|
||||
struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name);
|
||||
@@ -1433,7 +1431,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
return nullptr;
|
||||
}
|
||||
int num_bytes = ggml_nbytes(cur);
|
||||
if (ggml_backend_buffer_is_host(new_clip->params_buffer)) {
|
||||
if (ggml_backend_buft_is_host(buft)) {
|
||||
// for the CPU and Metal backend, we can read directly into the tensor
|
||||
fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
|
||||
} else {
|
||||
@@ -1719,14 +1717,21 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
// measure mem requirement and allocate
|
||||
{
|
||||
new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
|
||||
new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
|
||||
clip_image_f32_batch batch;
|
||||
batch.size = 1;
|
||||
batch.data = nullptr;
|
||||
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
|
||||
ggml_gallocr_reserve(new_clip->compute_alloc, gf);
|
||||
size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
|
||||
LOG_INF("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
|
||||
ggml_backend_sched_reserve(new_clip->sched.get(), gf);
|
||||
for (size_t i = 0; i < new_clip->backend_ptrs.size(); ++i) {
|
||||
ggml_backend_t backend = new_clip->backend_ptrs[i];
|
||||
ggml_backend_buffer_type_t buft = new_clip->backend_buft[i];
|
||||
size_t size = ggml_backend_sched_get_buffer_size(new_clip->sched.get(), backend);
|
||||
if (size > 1) {
|
||||
LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
||||
ggml_backend_buft_name(buft),
|
||||
size / 1024.0 / 1024.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new_clip;
|
||||
@@ -2407,12 +2412,6 @@ ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
|
||||
}
|
||||
|
||||
void clip_free(clip_ctx * ctx) {
|
||||
ggml_free(ctx->ctx_data);
|
||||
gguf_free(ctx->ctx_gguf);
|
||||
|
||||
ggml_backend_buffer_free(ctx->params_buffer);
|
||||
ggml_backend_free(ctx->backend);
|
||||
ggml_gallocr_free(ctx->compute_alloc);
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
@@ -2608,8 +2607,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
|
||||
// build the inference graph
|
||||
ggml_backend_sched_reset(ctx->sched.get());
|
||||
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true);
|
||||
ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
|
||||
ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
|
||||
|
||||
// set inputs
|
||||
const auto & model = ctx->vision_model;
|
||||
@@ -2774,11 +2774,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_backend_is_cpu(ctx->backend)) {
|
||||
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
|
||||
}
|
||||
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
|
||||
|
||||
ggml_backend_graph_compute(ctx->backend, gf);
|
||||
auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
|
||||
if (status != GGML_STATUS_SUCCESS) {
|
||||
LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status);
|
||||
return false;
|
||||
}
|
||||
|
||||
// the last node is the embedding tensor
|
||||
struct ggml_tensor * embeddings = ggml_graph_node(gf, -1);
|
||||
|
||||
@@ -39,8 +39,15 @@ struct clip_image_f32_batch {
|
||||
size_t size;
|
||||
};
|
||||
|
||||
CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
|
||||
CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity);
|
||||
struct clip_context_params {
|
||||
bool use_gpu;
|
||||
int verbosity;
|
||||
};
|
||||
|
||||
// deprecated, use clip_init
|
||||
CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity);
|
||||
|
||||
CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params);
|
||||
|
||||
CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ def bytes_to_unicode():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
|
||||
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
|
||||
ap.add_argument('--bigendian', action="store_true", default=False, help="Model is executed on big-endian machine")
|
||||
ap.add_argument("--text-only", action="store_true", required=False,
|
||||
help="Save a text-only model. It can't be used to encode images")
|
||||
ap.add_argument("--vision-only", action="store_true", required=False,
|
||||
@@ -191,7 +192,7 @@ output_dir = args.output_dir if args.output_dir is not None else dir_model
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_prefix = os.path.basename(output_dir).replace("ggml_", "")
|
||||
fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
|
||||
fout = GGUFWriter(path=fname_out, arch="clip")
|
||||
fout = GGUFWriter(path=fname_out, arch="clip", endianess=GGUFEndian.LITTLE if not args.bigendian else GGUFEndian.BIG)
|
||||
|
||||
fout.add_bool("clip.has_text_encoder", has_text_encoder)
|
||||
fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
|
||||
|
||||
@@ -86,7 +86,11 @@ static struct clip_ctx * clip_init_context(common_params * params) {
|
||||
if (prompt.empty()) {
|
||||
prompt = "describe the image in detail.";
|
||||
}
|
||||
auto * ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
|
||||
struct clip_context_params clip_params = {
|
||||
/* use_gpu */ params->n_gpu_layers != 0,
|
||||
/* verbosity */ params->verbosity,
|
||||
};
|
||||
auto * ctx_clip = clip_init(clip_path, clip_params);
|
||||
return ctx_clip;
|
||||
}
|
||||
|
||||
@@ -148,19 +152,34 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
|
||||
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
|
||||
if (num_image_embeds > 1) {
|
||||
size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false);
|
||||
for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) {
|
||||
for (size_t j = 0; j < num_image_embeds_col; ++j) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("<image>").c_str(), params->n_batch, &n_past, false);
|
||||
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
|
||||
if (j == num_image_embeds_col - 1) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false);
|
||||
for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) {
|
||||
for (size_t j = 0; j < num_image_embeds_col; ++j) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("<image>").c_str(), params->n_batch, &n_past, false);
|
||||
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
|
||||
if (j == num_image_embeds_col - 1) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3 || has_minicpmv_projector == 4) {
|
||||
size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip);
|
||||
for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) {
|
||||
for (size_t j = 0; j < num_image_embeds_col; ++j) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false);
|
||||
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false);
|
||||
if (j == num_image_embeds_col - 1) {
|
||||
eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false);
|
||||
}
|
||||
LOG_INF("%s: image token past: %d\n", __func__, n_past);
|
||||
}
|
||||
|
||||
@@ -597,7 +597,6 @@ elif args.minicpmv_projector is not None:
|
||||
fname_middle = "mmproj-"
|
||||
has_text_encoder = False
|
||||
has_minicpmv_projector = True
|
||||
minicpmv_version = 4
|
||||
elif args.vision_only:
|
||||
fname_middle = "vision-"
|
||||
has_text_encoder = False
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first.
|
||||
const SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}';
|
||||
const SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}';
|
||||
|
||||
function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
|
||||
if (minItems === 0 && maxItems === 1) {
|
||||
|
||||
+50
-37
@@ -131,9 +131,9 @@ struct slot_params {
|
||||
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
|
||||
}
|
||||
|
||||
std::vector<std::string> grammar_trigger_words;
|
||||
for (const auto & trigger : sampling.grammar_trigger_words) {
|
||||
grammar_trigger_words.push_back(trigger.word);
|
||||
auto grammar_triggers = json::array();
|
||||
for (const auto & trigger : sampling.grammar_triggers) {
|
||||
grammar_triggers.push_back(trigger.to_json<json>());
|
||||
}
|
||||
|
||||
return json {
|
||||
@@ -170,8 +170,8 @@ struct slot_params {
|
||||
{"n_probs", sampling.n_probs},
|
||||
{"min_keep", sampling.min_keep},
|
||||
{"grammar", sampling.grammar},
|
||||
{"grammar_trigger_words", grammar_trigger_words},
|
||||
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
|
||||
{"grammar_lazy", sampling.grammar_lazy},
|
||||
{"grammar_triggers", grammar_triggers},
|
||||
{"preserved_tokens", sampling.preserved_tokens},
|
||||
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
|
||||
{"samplers", samplers},
|
||||
@@ -356,24 +356,6 @@ struct server_task {
|
||||
}
|
||||
|
||||
{
|
||||
const auto grammar_triggers = data.find("grammar_triggers");
|
||||
if (grammar_triggers != data.end()) {
|
||||
for (const auto & t : *grammar_triggers) {
|
||||
common_grammar_trigger trigger;
|
||||
trigger.word = t.at("word");
|
||||
trigger.at_start = t.at("at_start");
|
||||
|
||||
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
||||
params.sampling.preserved_tokens.insert(ids[0]);
|
||||
continue;
|
||||
}
|
||||
SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_words.push_back(trigger);
|
||||
}
|
||||
}
|
||||
const auto preserved_tokens = data.find("preserved_tokens");
|
||||
if (preserved_tokens != data.end()) {
|
||||
for (const auto & t : *preserved_tokens) {
|
||||
@@ -383,12 +365,39 @@ struct server_task {
|
||||
params.sampling.preserved_tokens.insert(ids[0]);
|
||||
} else {
|
||||
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
|
||||
SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
|
||||
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params.sampling.grammar_lazy) {
|
||||
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
|
||||
const auto grammar_triggers = data.find("grammar_triggers");
|
||||
if (grammar_triggers != data.end()) {
|
||||
for (const auto & t : *grammar_triggers) {
|
||||
auto ct = common_grammar_trigger::from_json(t);
|
||||
if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
||||
const auto & word = ct.value;
|
||||
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
auto token = ids[0];
|
||||
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
|
||||
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
|
||||
}
|
||||
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
|
||||
common_grammar_trigger trigger;
|
||||
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
|
||||
trigger.value = word;
|
||||
trigger.token = token;
|
||||
params.sampling.grammar_triggers.push_back(std::move(trigger));
|
||||
} else {
|
||||
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
|
||||
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar_triggers.push_back(ct);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
|
||||
throw std::runtime_error("Error: no triggers set for lazy grammar!");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -742,7 +751,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
}},
|
||||
{"id", tc.id},
|
||||
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// We only generate a random id for the ones that don't generate one by themselves
|
||||
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
{"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
});
|
||||
}
|
||||
message["tool_calls"] = tool_calls;
|
||||
@@ -1304,7 +1316,7 @@ struct server_slot {
|
||||
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
|
||||
}
|
||||
|
||||
bool can_batch_with(server_slot & other_slot) {
|
||||
bool can_batch_with(server_slot & other_slot) const {
|
||||
return is_non_causal() == other_slot.is_non_causal()
|
||||
&& are_lora_equal(lora, other_slot.lora);
|
||||
}
|
||||
@@ -1892,6 +1904,7 @@ struct server_context {
|
||||
try {
|
||||
common_chat_format_example(chat_templates.get(), params.use_jinja);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
|
||||
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||
chat_templates = common_chat_templates_init(model, "chatml");
|
||||
}
|
||||
@@ -2045,7 +2058,7 @@ struct server_context {
|
||||
|
||||
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
||||
// Might be better to reject the request with a 400 ?
|
||||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict);
|
||||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
|
||||
slot.params.n_predict = slot.n_predict;
|
||||
}
|
||||
|
||||
@@ -2148,14 +2161,6 @@ struct server_context {
|
||||
}
|
||||
|
||||
if (slot.has_new_line) {
|
||||
// if we have already seen a new line, we stop after a certain time limit
|
||||
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
|
||||
slot.stop = STOP_TYPE_LIMIT;
|
||||
slot.has_next_token = false;
|
||||
|
||||
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
|
||||
}
|
||||
|
||||
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
|
||||
if (slot.params.n_indent > 0) {
|
||||
// check the current indentation
|
||||
@@ -2194,6 +2199,14 @@ struct server_context {
|
||||
// check if there is a new line in the generated text
|
||||
if (result.text_to_send.find('\n') != std::string::npos) {
|
||||
slot.has_new_line = true;
|
||||
|
||||
// if we have seen a new line, we stop after a certain time limit, but only upon another new line
|
||||
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
|
||||
slot.stop = STOP_TYPE_LIMIT;
|
||||
slot.has_next_token = false;
|
||||
|
||||
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
|
||||
}
|
||||
}
|
||||
|
||||
// if context shift is disabled, we stop when it reaches the context limit
|
||||
|
||||
Regular → Executable
+139
-100
@@ -1,4 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
import pytest
|
||||
|
||||
# ensure grandparent path is in sys.path
|
||||
from pathlib import Path
|
||||
import sys
|
||||
path = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(path))
|
||||
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
@@ -66,15 +74,8 @@ WEATHER_TOOL = {
|
||||
}
|
||||
|
||||
|
||||
def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
|
||||
global server
|
||||
n_predict = 512
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
@@ -83,16 +84,15 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
|
||||
"tool_choice": "required",
|
||||
"tools": [tool],
|
||||
"parallel_tool_calls": False,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
**kwargs,
|
||||
})
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
@@ -108,7 +108,14 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
|
||||
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
||||
do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
|
||||
global server
|
||||
n_predict = 512
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@@ -130,10 +137,17 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
|
||||
("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
|
||||
("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
|
||||
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
|
||||
do_test_completion_with_required_tool_tiny(template_name, tool, argument_key)
|
||||
global server
|
||||
n_predict = 512
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@@ -142,25 +156,33 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
||||
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
@@ -176,10 +198,10 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
# TODO: fix these
|
||||
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
@@ -197,7 +219,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
@@ -215,7 +237,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
@@ -225,13 +247,8 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
||||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||
|
||||
|
||||
def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
@@ -239,9 +256,7 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too
|
||||
],
|
||||
"tools": tools if tools else None,
|
||||
"tool_choice": tool_choice,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
@@ -254,7 +269,12 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
|
||||
global server
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@@ -270,7 +290,12 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice)
|
||||
global server
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@@ -281,6 +306,12 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
@@ -324,48 +355,53 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
do_test_weather(server, max_tokens=n_predict)
|
||||
|
||||
|
||||
def do_test_weather(server: ServerProcess, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
|
||||
{"role": "user", "content": "What is the weather in Istanbul?"},
|
||||
],
|
||||
"tools": [WEATHER_TOOL],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
|
||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
|
||||
location = actual_arguments["location"]
|
||||
assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
|
||||
assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
|
||||
assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
|
||||
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
|
||||
("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
# n_predict = 512
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
@@ -379,10 +415,14 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
do_test_calc_result(server, result_override, n_predict)
|
||||
|
||||
|
||||
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things, and provide very concise answers. Do not explain your reasoning to the user. Provide any numerical values back to the user with at most two decimals."},
|
||||
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
|
||||
{"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -423,7 +463,8 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
@@ -434,19 +475,19 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
|
||||
if result_override is not None:
|
||||
assert re.match(result_override, content), f'Expected {result_override}, got {content}'
|
||||
else:
|
||||
assert re.match('^[\\s\\S]*?The (y[ -])?coordinate [\\s\\S]*?is (approximately )?0\\.56\\b|^0\\.56$', content), \
|
||||
assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \
|
||||
f'Expected something like "The y coordinate is 0.56.", got {content}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
|
||||
(128, 'deepseek', "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'none', "^I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
])
|
||||
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
@@ -464,7 +505,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the sum of 102 and 7?"},
|
||||
@@ -476,7 +517,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
||||
|
||||
content = choice["message"].get("content")
|
||||
if expect_content is None:
|
||||
assert content is None, f'Expected no content in {choice["message"]}'
|
||||
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
else:
|
||||
assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
|
||||
|
||||
@@ -488,46 +529,46 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
|
||||
(None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"),
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
|
||||
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
# ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
(None, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
|
||||
])
|
||||
def test_hello_world(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
global server
|
||||
n_predict = 512 # High because of DeepSeek R1
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = 512 # High because of DeepSeek R1
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
@@ -537,31 +578,29 @@ def test_hello_world(expected_arguments_override: str | None, hf_repo: str, temp
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 256,
|
||||
|
||||
do_test_hello_world(server, max_tokens=n_predict)
|
||||
|
||||
|
||||
def do_test_hello_world(server: ServerProcess, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "system", "content": "You are a tool-calling agent."},
|
||||
{"role": "user", "content": "say hello world with python"},
|
||||
],
|
||||
"tools": [PYTHON_TOOL],
|
||||
# Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n print("Hello, World!")\nhello_world()` which is correct but a pain to test.
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}'
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
if expected_arguments_override is not None:
|
||||
assert actual_arguments == expected_arguments_override
|
||||
else:
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
|
||||
code = actual_arguments["code"]
|
||||
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
|
||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
|
||||
code = actual_arguments["code"]
|
||||
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
|
||||
|
||||
@@ -67,6 +67,9 @@ class ServerProcess:
|
||||
id_slot: int | None = None
|
||||
cache_prompt: bool | None = None
|
||||
n_slots: int | None = None
|
||||
ctk: str | None = None
|
||||
ctv: str | None = None
|
||||
fa: bool | None = None
|
||||
server_continuous_batching: bool | None = False
|
||||
server_embeddings: bool | None = False
|
||||
server_reranking: bool | None = False
|
||||
@@ -84,6 +87,7 @@ class ServerProcess:
|
||||
reasoning_format: Literal['deepseek', 'none'] | None = None
|
||||
chat_template: str | None = None
|
||||
chat_template_file: str | None = None
|
||||
server_path: str | None = None
|
||||
|
||||
# session variables
|
||||
process: subprocess.Popen | None = None
|
||||
@@ -97,7 +101,9 @@ class ServerProcess:
|
||||
self.server_port = int(os.environ["PORT"])
|
||||
|
||||
def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
|
||||
if "LLAMA_SERVER_BIN_PATH" in os.environ:
|
||||
if self.server_path is not None:
|
||||
server_path = self.server_path
|
||||
elif "LLAMA_SERVER_BIN_PATH" in os.environ:
|
||||
server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
|
||||
elif os.name == "nt":
|
||||
server_path = "../../../build/bin/Release/llama-server.exe"
|
||||
@@ -151,6 +157,12 @@ class ServerProcess:
|
||||
server_args.extend(["--ctx-size", self.n_ctx])
|
||||
if self.n_slots:
|
||||
server_args.extend(["--parallel", self.n_slots])
|
||||
if self.ctk:
|
||||
server_args.extend(["-ctk", self.ctk])
|
||||
if self.ctv:
|
||||
server_args.extend(["-ctv", self.ctv])
|
||||
if self.fa is not None:
|
||||
server_args.append("-fa")
|
||||
if self.n_predict:
|
||||
server_args.extend(["--n-predict", self.n_predict])
|
||||
if self.slot_save_path:
|
||||
@@ -184,7 +196,7 @@ class ServerProcess:
|
||||
server_args.extend(["--chat-template-file", self.chat_template_file])
|
||||
|
||||
args = [str(arg) for arg in [server_path, *server_args]]
|
||||
print(f"bench: starting server with: {' '.join(args)}")
|
||||
print(f"tests: starting server with: {' '.join(args)}")
|
||||
|
||||
flags = 0
|
||||
if "nt" == os.name:
|
||||
@@ -215,6 +227,10 @@ class ServerProcess:
|
||||
return # server is ready
|
||||
except Exception as e:
|
||||
pass
|
||||
# Check if process died
|
||||
if self.process.poll() is not None:
|
||||
raise RuntimeError(f"Server process died with return code {self.process.returncode}")
|
||||
|
||||
print(f"Waiting for server to start...")
|
||||
time.sleep(0.5)
|
||||
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
|
||||
|
||||
@@ -435,6 +435,10 @@ static std::string gen_chatcmplid() {
|
||||
return "chatcmpl-" + random_string();
|
||||
}
|
||||
|
||||
static std::string gen_tool_call_id() {
|
||||
return random_string();
|
||||
}
|
||||
|
||||
//
|
||||
// other common utils
|
||||
//
|
||||
@@ -607,6 +611,7 @@ static json oaicompat_completion_params_parse(
|
||||
inputs.use_jinja = use_jinja;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
|
||||
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
@@ -620,10 +625,7 @@ static json oaicompat_completion_params_parse(
|
||||
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
|
||||
auto grammar_triggers = json::array();
|
||||
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||
grammar_triggers.push_back({
|
||||
{"word", trigger.word},
|
||||
{"at_start", trigger.at_start},
|
||||
});
|
||||
grammar_triggers.push_back(trigger.to_json<json>());
|
||||
}
|
||||
llama_params["grammar_triggers"] = grammar_triggers;
|
||||
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
|
||||
|
||||
@@ -106,6 +106,7 @@ option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable
|
||||
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
|
||||
option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
|
||||
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
|
||||
option(GGML_BMI2 "ggml: enable BMI2" ${INS_ENB})
|
||||
option(GGML_AVX512 "ggml: enable AVX512F" OFF)
|
||||
option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF)
|
||||
option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF)
|
||||
@@ -194,6 +195,8 @@ option(GGML_OPENCL "ggml: use OpenCL"
|
||||
option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF)
|
||||
option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON)
|
||||
option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON)
|
||||
set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
|
||||
"gmml: OpenCL API version to target")
|
||||
|
||||
# toolchain for vulkan-shaders-gen
|
||||
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
||||
|
||||
@@ -80,6 +80,7 @@ extern "C" {
|
||||
GGML_BACKEND_API int ggml_cpu_has_avx (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_avx2 (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_bmi2 (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_f16c (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_fma (void);
|
||||
GGML_BACKEND_API int ggml_cpu_has_avx512 (void);
|
||||
|
||||
@@ -236,7 +236,7 @@ add_library(ggml
|
||||
target_link_libraries(ggml PUBLIC ggml-base)
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
target_link_libraries(ggml PRIVATE dl)
|
||||
target_link_libraries(ggml PRIVATE dl stdc++fs)
|
||||
endif()
|
||||
|
||||
function(ggml_add_backend_library backend)
|
||||
@@ -289,7 +289,7 @@ function(ggml_add_cpu_backend_variant tag_name)
|
||||
set(GGML_CPU_TAG_NAME ${tag_name})
|
||||
# other: OPENMP LLAMAFILE CPU_HBM
|
||||
foreach (feat NATIVE
|
||||
AVX AVX2 AVX_VNNI FMA F16C
|
||||
AVX AVX2 BMI2 AVX_VNNI FMA F16C
|
||||
AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
|
||||
AMX_TILE AMX_INT8 AMX_BF16)
|
||||
set(GGML_${feat} OFF)
|
||||
@@ -309,13 +309,13 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
|
||||
endif()
|
||||
ggml_add_cpu_backend_variant(sandybridge AVX)
|
||||
ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 FMA)
|
||||
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
|
||||
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
|
||||
ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 BMI2 FMA)
|
||||
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 BMI2 FMA AVX512)
|
||||
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 BMI2 FMA AVX_VNNI)
|
||||
if (NOT MSVC)
|
||||
# MSVC doesn't support AMX
|
||||
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
endif()
|
||||
elseif (GGML_CPU)
|
||||
ggml_add_cpu_backend_variant_impl("")
|
||||
|
||||
@@ -76,7 +76,14 @@ namespace fs = std::filesystem;
|
||||
static std::string path_str(const fs::path & path) {
|
||||
std::string u8path;
|
||||
try {
|
||||
#if defined(__cpp_lib_char8_t)
|
||||
// C++20 and later: u8string() returns std::u8string
|
||||
std::u8string u8str = path.u8string();
|
||||
u8path = std::string(reinterpret_cast<const char*>(u8str.c_str()));
|
||||
#else
|
||||
// C++17: u8string() returns std::string
|
||||
u8path = path.u8string();
|
||||
#endif
|
||||
} catch (...) {
|
||||
}
|
||||
return u8path;
|
||||
@@ -490,7 +497,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
search_paths.push_back(get_executable_path());
|
||||
search_paths.push_back(fs::current_path());
|
||||
} else {
|
||||
search_paths.push_back(user_search_path);
|
||||
search_paths.push_back(fs::u8path(user_search_path));
|
||||
}
|
||||
|
||||
int best_score = 0;
|
||||
@@ -504,9 +511,9 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
for (const auto & entry : dir_it) {
|
||||
if (entry.is_regular_file()) {
|
||||
auto filename = entry.path().filename().native();
|
||||
auto ext = entry.path().extension().native();
|
||||
if (filename.find(file_prefix) == 0 && ext == file_extension) {
|
||||
auto filename = entry.path().filename();
|
||||
auto ext = entry.path().extension();
|
||||
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
|
||||
dl_handle_ptr handle { dl_load_library(entry) };
|
||||
if (!handle && !silent) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str());
|
||||
@@ -537,7 +544,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
// try to load the base backend
|
||||
for (const auto & search_path : search_paths) {
|
||||
fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
|
||||
fs::path path = search_path.native() + filename.native();
|
||||
fs::path path = search_path / filename;
|
||||
if (fs::exists(path)) {
|
||||
return get_reg().load_backend(path, silent);
|
||||
}
|
||||
|
||||
@@ -219,6 +219,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (GGML_AVX_VNNI)
|
||||
list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
|
||||
endif()
|
||||
if (GGML_BMI2)
|
||||
# MSVC does not define macro __BMI2__
|
||||
list(APPEND ARCH_DEFINITIONS __BMI2__ GGML_BMI2)
|
||||
endif()
|
||||
else ()
|
||||
if (GGML_NATIVE)
|
||||
list(APPEND ARCH_FLAGS -march=native)
|
||||
@@ -233,6 +237,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
list(APPEND ARCH_FLAGS -mfma)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_FMA)
|
||||
endif()
|
||||
if (GGML_BMI2)
|
||||
list(APPEND ARCH_FLAGS -mbmi2)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_BMI2)
|
||||
endif()
|
||||
if (GGML_AVX)
|
||||
list(APPEND ARCH_FLAGS -mavx)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX)
|
||||
|
||||
@@ -278,6 +278,10 @@ static int ggml_backend_cpu_x86_score() {
|
||||
if (!is.SSE42()) { return 0; }
|
||||
score += 1<<2;
|
||||
#endif
|
||||
#ifdef GGML_BMI2
|
||||
if (!is.BMI2()) { return 0; }
|
||||
score += 1<<3;
|
||||
#endif
|
||||
#ifdef GGML_AVX
|
||||
if (!is.AVX()) { return 0; }
|
||||
score += 1<<4;
|
||||
|
||||
@@ -11362,10 +11362,19 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
__m256i sumi = _mm256_setzero_si256();
|
||||
int sumi1 = 0;
|
||||
for (int ib = 0; ib < QK_K/32; ib += 2) {
|
||||
#ifdef __BMI2__
|
||||
const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib], 0x700070007000700ULL);
|
||||
const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib + 1], 0x700070007000700ULL);
|
||||
const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
|
||||
const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
|
||||
const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
|
||||
const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
|
||||
#else
|
||||
const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
|
||||
iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
|
||||
const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
|
||||
iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
|
||||
#endif
|
||||
qs += 8;
|
||||
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
@@ -11711,6 +11720,10 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
|
||||
const __m256i mask = _mm256_set1_epi16(0x7);
|
||||
const __m256i mone = _mm256_set1_epi16(1);
|
||||
const __m256i mone8 = _mm256_set1_epi8(1);
|
||||
const __m256i mtwo8 = _mm256_set1_epi8(2);
|
||||
// VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half.
|
||||
const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);
|
||||
|
||||
__m256 accum1 = _mm256_setzero_ps();
|
||||
__m256 accum2 = _mm256_setzero_ps();
|
||||
@@ -11722,10 +11735,33 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
||||
|
||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||
// Extract 3-bit scales (16 values)
|
||||
__m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc);
|
||||
scales = _mm256_srlv_epi64(scales, scales_shift);
|
||||
scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);
|
||||
|
||||
// Indices to repeat each scale 8 times.
|
||||
__m256i scales_idx1 = _mm256_set1_epi16(0x0100);
|
||||
__m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8));
|
||||
|
||||
__m256i sumi1 = _mm256_setzero_si256();
|
||||
__m256i sumi2 = _mm256_setzero_si256();
|
||||
for (int ib = 0; ib < QK_K/32; ib += 2) {
|
||||
#ifdef __BMI2__
|
||||
const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL)
|
||||
| _pdep_u64(*(const uint16_t*)(qh) & 0x7777, 0xf000f000f000f00ULL);
|
||||
const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL)
|
||||
| _pdep_u64(*(const uint16_t*)(qh + 2) & 0x7777, 0xf000f000f000f00ULL);
|
||||
const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
|
||||
const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
|
||||
const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
|
||||
const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
|
||||
|
||||
// Convert signs to bytes 0x81 (negative) or 0x01 (positive)
|
||||
const uint64_t delta_sign = _pdep_u64(*(const uint32_t*)(qh) & 0x88888888, 0xf0f0f0f0f0f0f0f0ULL);
|
||||
const __m256i delta1 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign)));
|
||||
const __m256i delta2 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign >> 32)));
|
||||
#else
|
||||
const __m256i q1b_1 = _mm256_set_epi64x(
|
||||
iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
|
||||
iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
|
||||
@@ -11734,11 +11770,6 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
|
||||
iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
|
||||
);
|
||||
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
|
||||
const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
|
||||
const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
|
||||
|
||||
const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
@@ -11748,15 +11779,21 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
|
||||
#endif
|
||||
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
|
||||
const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
|
||||
const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
|
||||
const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
|
||||
const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
|
||||
const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1));
|
||||
const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2));
|
||||
|
||||
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
|
||||
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
|
||||
__m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1);
|
||||
__m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2);
|
||||
|
||||
scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8);
|
||||
scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8);
|
||||
|
||||
scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
|
||||
scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
|
||||
const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
|
||||
const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
|
||||
const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
|
||||
|
||||
@@ -6648,6 +6648,135 @@ static void ggml_compute_forward_repeat_back(
|
||||
|
||||
// ggml_compute_forward_concat
|
||||
|
||||
static void ggml_compute_forward_concat_any(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
const size_t len = ggml_type_size(src0->type);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int32_t dim = ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||
|
||||
int64_t o[4] = {0, 0, 0, 0};
|
||||
o[dim] = src0->ne[dim];
|
||||
|
||||
const char * x;
|
||||
|
||||
// TODO: smarter multi-theading
|
||||
for (int i3 = 0; i3 < ne3; i3++) {
|
||||
for (int i2 = ith; i2 < ne2; i2 += nth) {
|
||||
for (int i1 = 0; i1 < ne1; i1++) {
|
||||
for (int i0 = 0; i0 < ne0; i0++) {
|
||||
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
||||
x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
|
||||
} else {
|
||||
x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
|
||||
}
|
||||
|
||||
char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
|
||||
|
||||
memcpy(y, x, len);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_concat_i8(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int32_t dim = ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||
|
||||
int64_t o[4] = {0, 0, 0, 0};
|
||||
o[dim] = src0->ne[dim];
|
||||
|
||||
const int8_t * x;
|
||||
|
||||
// TODO: smarter multi-theading
|
||||
for (int i3 = 0; i3 < ne3; i3++) {
|
||||
for (int i2 = ith; i2 < ne2; i2 += nth) {
|
||||
for (int i1 = 0; i1 < ne1; i1++) {
|
||||
for (int i0 = 0; i0 < ne0; i0++) {
|
||||
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
||||
x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
|
||||
} else {
|
||||
x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
|
||||
}
|
||||
|
||||
int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
*y = *x;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_concat_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int32_t dim = ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||
|
||||
int64_t o[4] = {0, 0, 0, 0};
|
||||
o[dim] = src0->ne[dim];
|
||||
|
||||
const ggml_fp16_t * x;
|
||||
|
||||
// TODO: smarter multi-theading
|
||||
for (int i3 = 0; i3 < ne3; i3++) {
|
||||
for (int i2 = ith; i2 < ne2; i2 += nth) {
|
||||
for (int i1 = 0; i1 < ne1; i1++) {
|
||||
for (int i0 = 0; i0 < ne0; i0++) {
|
||||
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
||||
x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
|
||||
} else {
|
||||
x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
|
||||
}
|
||||
|
||||
ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
*y = *x;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_concat_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
@@ -6655,7 +6784,7 @@ static void ggml_compute_forward_concat_f32(
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
@@ -6698,6 +6827,16 @@ static void ggml_compute_forward_concat(
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_I16:
|
||||
{
|
||||
ggml_compute_forward_concat_f16(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_I8:
|
||||
{
|
||||
ggml_compute_forward_concat_i8(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_I32:
|
||||
{
|
||||
@@ -6705,7 +6844,7 @@ static void ggml_compute_forward_concat(
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
ggml_compute_forward_concat_any(params, dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15440,6 +15579,14 @@ int ggml_cpu_has_amx_int8(void) {
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_bmi2(void) {
|
||||
#if defined(__BMI2__)
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_fma(void) {
|
||||
#if defined(__FMA__)
|
||||
return 1;
|
||||
|
||||
@@ -511,6 +511,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
||||
if (ggml_cpu_has_fma()) {
|
||||
features.push_back({ "FMA", "1" });
|
||||
}
|
||||
if (ggml_cpu_has_bmi2()) {
|
||||
features.push_back({ "BMI2", "1" });
|
||||
}
|
||||
if (ggml_cpu_has_avx512()) {
|
||||
features.push_back({ "AVX512", "1" });
|
||||
}
|
||||
|
||||
@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
|
||||
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
||||
if (cc == GGML_CUDA_CC_VOLTA) {
|
||||
if (fp16_mma_available(cc) && !new_mma_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -2571,7 +2571,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
|
||||
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
||||
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
|
||||
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
|
||||
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
|
||||
*(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
|
||||
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,12 +27,12 @@ configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
|
||||
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
|
||||
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
|
||||
|
||||
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
|
||||
if (GGML_METAL_EMBED_LIBRARY)
|
||||
enable_language(ASM)
|
||||
|
||||
add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
|
||||
|
||||
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
|
||||
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
|
||||
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
|
||||
|
||||
@@ -88,12 +88,11 @@ else()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
|
||||
COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
|
||||
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
|
||||
xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
|
||||
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
|
||||
DEPENDS ggml-metal.metal ggml-common.h
|
||||
DEPENDS ggml-metal.metal ${METALLIB_COMMON}
|
||||
COMMENT "Compiling Metal kernels"
|
||||
)
|
||||
|
||||
|
||||
@@ -285,4 +285,239 @@ typedef struct {
|
||||
float eps;
|
||||
} ggml_metal_kargs_rms_norm;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int32_t n_groups;
|
||||
float eps;
|
||||
} ggml_metal_kargs_group_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t IC;
|
||||
int32_t IL;
|
||||
int32_t K;
|
||||
int32_t s0;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
} ggml_metal_kargs_conv_transpose_1d;
|
||||
|
||||
typedef struct {
|
||||
uint64_t ofs0;
|
||||
uint64_t ofs1;
|
||||
int32_t IW;
|
||||
int32_t IH;
|
||||
int32_t CHW;
|
||||
int32_t s0;
|
||||
int32_t s1;
|
||||
int32_t p0;
|
||||
int32_t p1;
|
||||
int32_t d0;
|
||||
int32_t d1;
|
||||
int32_t N;
|
||||
int32_t KH;
|
||||
int32_t KW;
|
||||
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
|
||||
} ggml_metal_kargs_im2col;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t ne10;
|
||||
int64_t ne11;
|
||||
int64_t ne12;
|
||||
int64_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int64_t ne0;
|
||||
int64_t ne1;
|
||||
int64_t ne2;
|
||||
int64_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_sum_rows;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
uint32_t n_head_log2;
|
||||
} ggml_metal_kargs_soft_max;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int n_past;
|
||||
} ggml_metal_kargs_diag_mask_inf;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int64_t ne10;
|
||||
int64_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
int64_t ne0;
|
||||
int64_t ne1;
|
||||
int64_t ne2;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
} ggml_metal_kargs_ssm_conv;
|
||||
|
||||
typedef struct {
|
||||
int64_t d_state;
|
||||
int64_t d_inner;
|
||||
int64_t n_seq_tokens;
|
||||
int64_t n_seqs;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb20;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb30;
|
||||
uint64_t nb31;
|
||||
uint64_t nb40;
|
||||
uint64_t nb41;
|
||||
uint64_t nb42;
|
||||
uint64_t nb50;
|
||||
uint64_t nb51;
|
||||
uint64_t nb52;
|
||||
} ggml_metal_kargs_ssm_scan;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int64_t ne10;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
} ggml_metal_kargs_get_rows;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t ne0;
|
||||
int64_t ne1;
|
||||
int64_t ne2;
|
||||
int64_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
float sf0;
|
||||
float sf1;
|
||||
float sf2;
|
||||
float sf3;
|
||||
} ggml_metal_kargs_upscale;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t ne0;
|
||||
int64_t ne1;
|
||||
int64_t ne2;
|
||||
int64_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_pad;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t ne0;
|
||||
int64_t ne1;
|
||||
int64_t ne2;
|
||||
int64_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
int32_t p0;
|
||||
int32_t p1;
|
||||
} ggml_metal_kargs_pad_reflect_1d;
|
||||
|
||||
typedef struct {
|
||||
uint64_t nb1;
|
||||
int dim;
|
||||
int max_period;
|
||||
} ggml_metal_kargs_timestep_embedding;
|
||||
|
||||
typedef struct {
|
||||
float slope;
|
||||
} ggml_metal_kargs_leaky_relu;
|
||||
|
||||
typedef struct {
|
||||
int64_t ncols;
|
||||
int64_t ncols_pad;
|
||||
} ggml_metal_kargs_argsort;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne0;
|
||||
float start;
|
||||
float step;
|
||||
} ggml_metal_kargs_arange;
|
||||
|
||||
typedef struct {
|
||||
int32_t k0;
|
||||
int32_t k1;
|
||||
int32_t s0;
|
||||
int32_t s1;
|
||||
int32_t p0;
|
||||
int32_t p1;
|
||||
int64_t IH;
|
||||
int64_t IW;
|
||||
int64_t OH;
|
||||
int64_t OW;
|
||||
int64_t parallel_elements;
|
||||
} ggml_metal_kargs_pool_2d;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
|
||||
+409
-339
@@ -46,6 +46,7 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
|
||||
static struct ggml_backend_metal_device_context {
|
||||
id<MTLDevice> mtl_device;
|
||||
int mtl_device_ref_count;
|
||||
id<MTLLibrary> mtl_library;
|
||||
|
||||
bool has_simdgroup_reduction;
|
||||
bool has_simdgroup_mm;
|
||||
@@ -57,6 +58,7 @@ static struct ggml_backend_metal_device_context {
|
||||
} g_ggml_ctx_dev_main = {
|
||||
/*.mtl_device =*/ nil,
|
||||
/*.mtl_device_ref_count =*/ 0,
|
||||
/*.mtl_library =*/ nil,
|
||||
/*.has_simdgroup_reduction =*/ false,
|
||||
/*.has_simdgroup_mm =*/ false,
|
||||
/*.has_residency_sets =*/ false,
|
||||
@@ -108,6 +110,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
||||
ctx->mtl_device_ref_count--;
|
||||
|
||||
if (ctx->mtl_device_ref_count == 0) {
|
||||
if (ctx->mtl_library) {
|
||||
[ctx->mtl_library release];
|
||||
ctx->mtl_library = nil;
|
||||
}
|
||||
|
||||
if (ctx->mtl_device) {
|
||||
[ctx->mtl_device release];
|
||||
ctx->mtl_device = nil;
|
||||
@@ -467,11 +474,13 @@ struct ggml_backend_metal_context {
|
||||
// for now it is easier to work in a separate file
|
||||
// static NSString * const msl_library_source = @"see metal.metal";
|
||||
|
||||
#if !GGML_METAL_EMBED_LIBRARY
|
||||
// Here to assist with NSBundle Path Hack
|
||||
@interface GGMLMetalClass : NSObject
|
||||
@end
|
||||
@implementation GGMLMetalClass
|
||||
@end
|
||||
#endif
|
||||
|
||||
static void * ggml_metal_host_malloc(size_t n) {
|
||||
void * data = NULL;
|
||||
@@ -493,6 +502,139 @@ static void * ggml_metal_host_malloc(size_t n) {
|
||||
return data;
|
||||
}
|
||||
|
||||
// load library
|
||||
//
|
||||
// - first check if the library is embedded
|
||||
// - then check if the library is in the bundle
|
||||
// - if not found, load the source and compile it
|
||||
// - if that fails, return NULL
|
||||
static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfloat) {
|
||||
id<MTLLibrary> metal_library = nil;
|
||||
NSError * error = nil;
|
||||
NSString * src = nil;
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
|
||||
|
||||
extern const char ggml_metallib_start[];
|
||||
extern const char ggml_metallib_end[];
|
||||
|
||||
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
|
||||
|
||||
#else
|
||||
|
||||
#ifdef SWIFT_PACKAGE
|
||||
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
|
||||
#else
|
||||
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||
#endif
|
||||
|
||||
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
|
||||
if (path_lib == nil) {
|
||||
// Try to find the resource in the directory where the current binary located.
|
||||
NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
|
||||
NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
|
||||
NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
|
||||
if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
|
||||
GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
|
||||
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
|
||||
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
|
||||
// Optionally, if this is a symlink, try to resolve it.
|
||||
default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
|
||||
if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
|
||||
// It is a relative path, adding the binary directory as directory prefix.
|
||||
default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
|
||||
}
|
||||
if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
|
||||
// Link to the resource could not be resolved.
|
||||
default_metallib_path = nil;
|
||||
} else {
|
||||
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// The resource couldn't be found in the binary's directory.
|
||||
default_metallib_path = nil;
|
||||
}
|
||||
path_lib = default_metallib_path;
|
||||
}
|
||||
|
||||
if (path_lib != nil) {
|
||||
// pre-compiled library found
|
||||
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
|
||||
|
||||
metal_library = [device newLibraryWithURL:libURL error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
}
|
||||
} else {
|
||||
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
||||
|
||||
NSString * path_source;
|
||||
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
||||
|
||||
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
|
||||
|
||||
if (path_resource) {
|
||||
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
|
||||
} else {
|
||||
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
||||
}
|
||||
|
||||
if (path_source == nil) {
|
||||
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
|
||||
path_source = @"ggml-metal.metal";
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
|
||||
|
||||
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!metal_library) {
|
||||
@autoreleasepool {
|
||||
// dictionary of preprocessor macros
|
||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||
|
||||
if (use_bfloat) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
|
||||
}
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
|
||||
#endif
|
||||
|
||||
MTLCompileOptions * options = [MTLCompileOptions new];
|
||||
options.preprocessorMacros = prep;
|
||||
|
||||
//[options setFastMathEnabled:false];
|
||||
|
||||
metal_library = [device newLibraryWithSource:src options:options error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
#if !__has_feature(objc_arc)
|
||||
[options release];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[src release];
|
||||
#endif // GGML_METAL_EMBED_LIBRARY
|
||||
|
||||
return metal_library;
|
||||
}
|
||||
|
||||
static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
|
||||
GGML_LOG_INFO("%s: allocating\n", __func__);
|
||||
|
||||
@@ -520,137 +662,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
|
||||
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
||||
|
||||
id<MTLLibrary> metal_library;
|
||||
|
||||
// load library
|
||||
//
|
||||
// - first check if the library is embedded
|
||||
// - then check if the library is in the bundle
|
||||
// - if not found, load the source and compile it
|
||||
// - if that fails, return NULL
|
||||
{
|
||||
NSBundle * bundle = nil;
|
||||
#ifdef SWIFT_PACKAGE
|
||||
bundle = SWIFTPM_MODULE_BUNDLE;
|
||||
#else
|
||||
bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||
#endif
|
||||
|
||||
NSError * error = nil;
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
const bool try_metallib = false;
|
||||
#else
|
||||
const bool try_metallib = true;
|
||||
#endif
|
||||
|
||||
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
|
||||
if (path_lib == nil) {
|
||||
// Try to find the resource in the directory where the current binary located.
|
||||
NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
|
||||
NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
|
||||
NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
|
||||
if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
|
||||
GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
|
||||
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
|
||||
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
|
||||
// Optionally, if this is a symlink, try to resolve it.
|
||||
default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
|
||||
if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
|
||||
// It is a relative path, adding the binary directory as directory prefix.
|
||||
default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
|
||||
}
|
||||
if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
|
||||
// Link to the resource could not be resolved.
|
||||
default_metallib_path = nil;
|
||||
} else {
|
||||
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// The resource couldn't be found in the binary's directory.
|
||||
default_metallib_path = nil;
|
||||
}
|
||||
path_lib = default_metallib_path;
|
||||
}
|
||||
|
||||
if (try_metallib && path_lib != nil) {
|
||||
// pre-compiled library found
|
||||
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
|
||||
|
||||
metal_library = [device newLibraryWithURL:libURL error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
}
|
||||
} else {
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
|
||||
|
||||
extern const char ggml_metallib_start[];
|
||||
extern const char ggml_metallib_end[];
|
||||
|
||||
NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
|
||||
#else
|
||||
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
||||
|
||||
NSString * path_source;
|
||||
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
||||
|
||||
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
|
||||
|
||||
if (path_resource) {
|
||||
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
|
||||
} else {
|
||||
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
||||
}
|
||||
|
||||
if (path_source == nil) {
|
||||
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
|
||||
path_source = @"ggml-metal.metal";
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
|
||||
|
||||
NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
}
|
||||
#endif // GGML_METAL_EMBED_LIBRARY
|
||||
|
||||
@autoreleasepool {
|
||||
// dictionary of preprocessor macros
|
||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||
|
||||
if (ctx_dev->use_bfloat) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
|
||||
}
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
|
||||
#endif
|
||||
|
||||
MTLCompileOptions * options = [MTLCompileOptions new];
|
||||
options.preprocessorMacros = prep;
|
||||
|
||||
//[options setFastMathEnabled:false];
|
||||
|
||||
metal_library = [device newLibraryWithSource:src options:options error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
#if !__has_feature(objc_arc)
|
||||
[options release];
|
||||
#endif
|
||||
}
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[src release];
|
||||
#endif // GGML_METAL_EMBED_LIBRARY
|
||||
}
|
||||
if (ctx_dev->mtl_library == nil) {
|
||||
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
|
||||
}
|
||||
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
|
||||
if (metal_library == nil) {
|
||||
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// print MTL GPU family:
|
||||
@@ -724,7 +743,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
[metal_function release]; \
|
||||
if (error) { \
|
||||
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
||||
[metal_library release]; \
|
||||
return NULL; \
|
||||
} \
|
||||
} else { \
|
||||
@@ -1043,8 +1061,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
}
|
||||
|
||||
[metal_library release];
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
@@ -1944,34 +1960,38 @@ static void ggml_metal_encode_node(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.ne12 =*/ ne12,
|
||||
/*.ne13 =*/ ne13,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
@@ -2020,8 +2040,17 @@ static void ggml_metal_encode_node(
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
// TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
|
||||
ggml_metal_kargs_soft_max args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
/*.m1 =*/ m1,
|
||||
/*.n_head_log2 =*/ n_head_log2,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
if (id_src1) {
|
||||
@@ -2030,14 +2059,7 @@ static void ggml_metal_encode_node(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
||||
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
|
||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
@@ -2055,13 +2077,16 @@ static void ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
||||
}
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_diag_mask_inf args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.n_past =*/ n_past,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
|
||||
if (ne00%8 == 0) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
@@ -2080,27 +2105,30 @@ static void ggml_metal_encode_node(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_ssm_conv args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
@@ -2151,7 +2179,31 @@ static void ggml_metal_encode_node(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_ssm_scan args = {
|
||||
/*.d_state =*/ d_state,
|
||||
/*.d_inner =*/ d_inner,
|
||||
/*.n_seq_tokens =*/ n_seq_tokens,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb20 =*/ nb20,
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb30 =*/ nb30,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.nb40 =*/ nb40,
|
||||
/*.nb41 =*/ nb41,
|
||||
/*.nb42 =*/ nb42,
|
||||
/*.nb50 =*/ nb50,
|
||||
/*.nb51 =*/ nb51,
|
||||
/*.nb52 =*/ nb52,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
@@ -2160,30 +2212,7 @@ static void ggml_metal_encode_node(
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
|
||||
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
|
||||
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
|
||||
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
|
||||
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
||||
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
|
||||
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
|
||||
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
||||
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
|
||||
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
|
||||
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
|
||||
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
|
||||
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
|
||||
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:7];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
@@ -3040,19 +3069,22 @@ static void ggml_metal_encode_node(
|
||||
default: GGML_ABORT("not implemented");
|
||||
}
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_get_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
||||
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
} break;
|
||||
@@ -3109,18 +3141,21 @@ static void ggml_metal_encode_node(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_group_norm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.n_groups =*/ n_groups,
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
||||
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
@@ -3278,8 +3313,8 @@ static void ggml_metal_encode_node(
|
||||
|
||||
const int32_t CHW = IC * KH * KW;
|
||||
|
||||
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
||||
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||
const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
||||
const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
||||
|
||||
@@ -3301,27 +3336,30 @@ static void ggml_metal_encode_node(
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_im2col args = {
|
||||
/*.ofs0 =*/ ofs0,
|
||||
/*.ofs1 =*/ ofs1,
|
||||
/*.IW =*/ IW,
|
||||
/*.IH =*/ IH,
|
||||
/*.CHW =*/ CHW,
|
||||
/*.s0 =*/ s0,
|
||||
/*.s1 =*/ s1,
|
||||
/*.p0 =*/ p0,
|
||||
/*.p1 =*/ p1,
|
||||
/*.d0 =*/ d0,
|
||||
/*.d1 =*/ d1,
|
||||
/*.N =*/ N,
|
||||
/*.KH =*/ KH,
|
||||
/*.KW =*/ KW,
|
||||
/*.KHW =*/ KH * KW,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
|
||||
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
|
||||
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
|
||||
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
|
||||
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
|
||||
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
|
||||
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
|
||||
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
|
||||
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
|
||||
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
|
||||
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
|
||||
if (is_gt_mttpt) {
|
||||
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
||||
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
||||
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
||||
|
||||
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
||||
|
||||
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
||||
@@ -3361,16 +3399,20 @@ static void ggml_metal_encode_node(
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
ggml_metal_kargs_conv_transpose_1d args = {
|
||||
/*.IC =*/ IC,
|
||||
/*.IL =*/ IL,
|
||||
/*.K =*/ K,
|
||||
/*.s0 =*/ s0,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
|
||||
[encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
|
||||
[encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
|
||||
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
@@ -3385,30 +3427,33 @@ static void ggml_metal_encode_node(
|
||||
|
||||
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_upscale args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.sf0 =*/ sf0,
|
||||
/*.sf1 =*/ sf1,
|
||||
/*.sf2 =*/ sf2,
|
||||
/*.sf3 =*/ sf3
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
||||
[encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
|
||||
[encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
|
||||
[encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
|
||||
[encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
|
||||
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
||||
|
||||
@@ -3420,26 +3465,29 @@ static void ggml_metal_encode_node(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_pad args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
@@ -3454,24 +3502,31 @@ static void ggml_metal_encode_node(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
|
||||
|
||||
ggml_metal_kargs_pad_reflect_1d args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.p0 =*/ p0,
|
||||
/*.p1 =*/ p1
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
|
||||
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
|
||||
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
@@ -3489,12 +3544,15 @@ static void ggml_metal_encode_node(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_arange args = {
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.start =*/ start,
|
||||
/*.step =*/ step
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
|
||||
[encoder setBytes:&start length:sizeof(start) atIndex:2];
|
||||
[encoder setBytes:&step length:sizeof(step) atIndex:3];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:1];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
@@ -3511,13 +3569,16 @@ static void ggml_metal_encode_node(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_timestep_embedding args = {
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.dim =*/ dim,
|
||||
/*.max_period =*/ max_period
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
|
||||
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
|
||||
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
|
||||
const int nth = MIN(1024, half);
|
||||
|
||||
@@ -3550,12 +3611,15 @@ static void ggml_metal_encode_node(
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_argsort args = {
|
||||
/*.ncols =*/ ne00,
|
||||
/*.ncols_pad =*/ ne00_padded
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
|
||||
@@ -3569,11 +3633,14 @@ static void ggml_metal_encode_node(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_leaky_relu args = {
|
||||
/*.slope =*/ slope
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
@@ -4149,21 +4216,24 @@ static void ggml_metal_encode_node(
|
||||
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
||||
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
||||
|
||||
// TODO: add ggml_metal_kargs struct
|
||||
ggml_metal_kargs_pool_2d args_pool_2d = {
|
||||
/* .k0 = */ k0,
|
||||
/* .k1 = */ k1,
|
||||
/* .s0 = */ s0,
|
||||
/* .s1 = */ s1,
|
||||
/* .p0 = */ p0,
|
||||
/* .p1 = */ p1,
|
||||
/* .IH = */ IH,
|
||||
/* .IW = */ IW,
|
||||
/* .OH = */ OH,
|
||||
/* .OW = */ OW,
|
||||
/* .parallel_elements = */ parallel_elements
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
|
||||
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
|
||||
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
|
||||
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
|
||||
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
|
||||
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
|
||||
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
|
||||
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
|
||||
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
|
||||
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
|
||||
[encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||
} break;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,7 @@ if (MUSAToolkit_FOUND)
|
||||
message(STATUS "MUSA Toolkit found")
|
||||
|
||||
if (NOT DEFINED MUSA_ARCHITECTURES)
|
||||
set(MUSA_ARCHITECTURES "21;22")
|
||||
set(MUSA_ARCHITECTURES "21;22;31")
|
||||
endif()
|
||||
message(STATUS "Using MUSA architectures: ${MUSA_ARCHITECTURES}")
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ if (GGML_OPENCL_PROFILING)
|
||||
endif ()
|
||||
|
||||
add_compile_definitions(GGML_OPENCL_SOA_Q)
|
||||
add_compile_definitions(GGML_OPENCL_TARGET_VERSION=${GGML_OPENCL_TARGET_VERSION})
|
||||
|
||||
if (GGML_OPENCL_USE_ADRENO_KERNELS)
|
||||
message(STATUS "OpenCL will use matmul kernels optimized for Adreno")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#define CL_TARGET_OPENCL_VERSION 220
|
||||
#define CL_TARGET_OPENCL_VERSION GGML_OPENCL_TARGET_VERSION
|
||||
#define CL_USE_DEPRECATED_OPENCL_1_2_APIS
|
||||
|
||||
// suppress warnings in CL headers for GCC and Clang
|
||||
@@ -25,6 +25,8 @@
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <charconv>
|
||||
|
||||
#undef MIN
|
||||
#undef MAX
|
||||
@@ -62,6 +64,97 @@ enum ADRENO_GPU_GEN {
|
||||
X1E,
|
||||
};
|
||||
|
||||
struct ggml_cl_version {
|
||||
cl_uint major = 0;
|
||||
cl_uint minor = 0;
|
||||
};
|
||||
|
||||
// Parses a version string of form "XX.YY ". On an error returns ggml_cl_version with all zeroes.
|
||||
static ggml_cl_version parse_cl_version(std::string_view str) {
|
||||
size_t major_str_begin = 0;
|
||||
size_t major_str_end = str.find(".", major_str_begin);
|
||||
if (major_str_end == std::string::npos) {
|
||||
return {};
|
||||
}
|
||||
|
||||
size_t minor_str_begin = major_str_end + 1;
|
||||
size_t minor_str_end = str.find(" ", minor_str_begin);
|
||||
if (minor_str_end == std::string::npos) {
|
||||
return {};
|
||||
}
|
||||
|
||||
cl_uint version_major;
|
||||
if (std::from_chars(str.data() + major_str_begin, str.data() + major_str_end, version_major).ec != std::errc{}) {
|
||||
return {};
|
||||
}
|
||||
|
||||
cl_uint version_minor;
|
||||
if (std::from_chars(str.data() + minor_str_begin, str.data() + minor_str_end, version_minor).ec != std::errc{}) {
|
||||
return {};
|
||||
}
|
||||
return { version_major, version_minor };
|
||||
}
|
||||
|
||||
// Returns OpenCL platform's version. On an error returns ggml_cl_version with all zeroes.
|
||||
static ggml_cl_version get_opencl_platform_version(cl_platform_id platform) {
|
||||
size_t param_size;
|
||||
CL_CHECK(clGetPlatformInfo(platform, CL_PLATFORM_VERSION, 0, nullptr, ¶m_size));
|
||||
std::unique_ptr<char[]> param_storage(new char[param_size]);
|
||||
CL_CHECK(clGetPlatformInfo(platform, CL_PLATFORM_VERSION, param_size, param_storage.get(), nullptr));
|
||||
|
||||
auto param_value = std::string_view(param_storage.get(), param_size);
|
||||
const std::string version_prefix = "OpenCL "; // Suffix: "XX.YY <platform-specific-info>"
|
||||
if (param_value.find(version_prefix) != 0) {
|
||||
return {};
|
||||
}
|
||||
param_value.remove_prefix(version_prefix.length());
|
||||
return parse_cl_version(param_value);
|
||||
}
|
||||
|
||||
// Return a version to use in OpenCL C compilation. On an error returns ggml_cl_version with all zeroes.
|
||||
static ggml_cl_version get_opencl_c_version(ggml_cl_version platform_version, cl_device_id device) {
|
||||
size_t param_size;
|
||||
|
||||
#if CL_TARGET_OPENCL_VERSION >= 300
|
||||
if (platform_version.major >= 3) {
|
||||
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_ALL_VERSIONS, 0, nullptr, ¶m_size));
|
||||
if (!param_size) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::unique_ptr<cl_name_version[]> versions(new cl_name_version[param_size]);
|
||||
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_ALL_VERSIONS, param_size, versions.get(), nullptr));
|
||||
unsigned versions_count = param_size / sizeof(cl_name_version);
|
||||
|
||||
cl_version version_max = 0;
|
||||
for (unsigned i = 0; i < versions_count; i++) {
|
||||
version_max = std::max<cl_version>(versions[i].version, version_max);
|
||||
}
|
||||
|
||||
return { CL_VERSION_MAJOR(version_max), CL_VERSION_MINOR(version_max) };
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(platform_version);
|
||||
#endif // CL_TARGET_OPENCL_VERSION >= 300
|
||||
|
||||
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_VERSION, 0, nullptr, ¶m_size));
|
||||
if (!param_size) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::unique_ptr<char[]> param_storage(new char[param_size]);
|
||||
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_VERSION, param_size, param_storage.get(), nullptr));
|
||||
auto param_value = std::string_view(param_storage.get(), param_size);
|
||||
|
||||
const std::string version_prefix = "OpenCL C "; // Suffix: "XX.YY <platform-specific-info>"
|
||||
if (param_value.find(version_prefix) != 0) {
|
||||
return {};
|
||||
}
|
||||
param_value.remove_prefix(version_prefix.length());
|
||||
|
||||
return parse_cl_version(param_value);
|
||||
}
|
||||
|
||||
static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) {
|
||||
if (strstr(device_name, "730") ||
|
||||
strstr(device_name, "740") ||
|
||||
@@ -278,7 +371,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
|
||||
cl_int err;
|
||||
|
||||
#ifdef GGML_PROFILE_OPENCL
|
||||
#ifdef GGML_OPENCL_PROFILING
|
||||
GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n");
|
||||
#endif
|
||||
|
||||
@@ -470,16 +563,11 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
// A local ref of cl_device_id for convenience
|
||||
cl_device_id device = backend_ctx->device;
|
||||
|
||||
// Check device OpenCL version, OpenCL 2.0 or above is required
|
||||
size_t device_ver_str_size;
|
||||
clGetDeviceInfo(device, CL_DEVICE_VERSION, 0, NULL, &device_ver_str_size);
|
||||
char *device_ver_buffer = (char *)alloca(device_ver_str_size + 1);
|
||||
clGetDeviceInfo(device, CL_DEVICE_VERSION, device_ver_str_size, device_ver_buffer, NULL);
|
||||
device_ver_buffer[device_ver_str_size] = '\0';
|
||||
GGML_LOG_INFO("ggml_opencl: device OpenCL version: %s\n", device_ver_buffer);
|
||||
ggml_cl_version platform_version = get_opencl_platform_version(default_device->platform->id);
|
||||
|
||||
if (strstr(device_ver_buffer, "OpenCL 2") == NULL &&
|
||||
strstr(device_ver_buffer, "OpenCL 3") == NULL) {
|
||||
// Check device OpenCL version, OpenCL 2.0 or above is required
|
||||
ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device);
|
||||
if (opencl_c_version.major < 2) {
|
||||
GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n");
|
||||
return backend_ctx;
|
||||
}
|
||||
@@ -516,15 +604,17 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
|
||||
// If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes
|
||||
// optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x)
|
||||
if (strstr(device_ver_buffer, "OpenCL 3") &&
|
||||
strstr(ext_buffer, "cl_khr_subgroups") == NULL &&
|
||||
if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL &&
|
||||
strstr(ext_buffer, "cl_intel_subgroups") == NULL) {
|
||||
GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) "
|
||||
"(note that subgroups is an optional feature in OpenCL 3.0)\n");
|
||||
return backend_ctx;
|
||||
}
|
||||
|
||||
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &backend_ctx->alignment, NULL));
|
||||
cl_uint base_align_in_bits;
|
||||
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL));
|
||||
GGML_ASSERT(base_align_in_bits % 8u == 0);
|
||||
backend_ctx->alignment = base_align_in_bits / 8u;
|
||||
GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment);
|
||||
|
||||
clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
|
||||
@@ -578,9 +668,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
const std::string kernel_src = read_file("ggml-opencl.cl");
|
||||
#endif
|
||||
|
||||
std::string compile_opts =
|
||||
"-cl-std=CL2.0 -cl-mad-enable -cl-unsafe-math-optimizations "
|
||||
"-cl-finite-math-only -cl-fast-relaxed-math ";
|
||||
auto opencl_c_std =
|
||||
std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor);
|
||||
|
||||
std::string compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable -cl-unsafe-math-optimizations"
|
||||
" -cl-finite-math-only -cl-fast-relaxed-math";
|
||||
backend_ctx->program = build_program_from_source(context, device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
// Non matmul kernels.
|
||||
@@ -690,10 +783,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose_16, "kernel_transpose_16", &err), err));
|
||||
|
||||
// Gemv general
|
||||
std::string CL_gemv_compile_opts =
|
||||
" -cl-std=CL2.0 "
|
||||
" -cl-mad-enable "
|
||||
" -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
|
||||
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable "
|
||||
" -DSIMDGROUP_WIDTH=" +
|
||||
std::to_string(backend_ctx->adreno_wave_size);
|
||||
if (has_vector_subgroup_broadcast) {
|
||||
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
|
||||
}
|
||||
@@ -710,12 +803,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err));
|
||||
|
||||
// Gemv 2048, 16384
|
||||
CL_gemv_compile_opts =
|
||||
" -cl-std=CL2.0 "
|
||||
" -cl-mad-enable "
|
||||
" -DLINE_STRIDE_A=2048 "
|
||||
" -DBLOCK_STRIDE_A=16384 "
|
||||
" -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
|
||||
CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable "
|
||||
" -DLINE_STRIDE_A=2048 "
|
||||
" -DBLOCK_STRIDE_A=16384 "
|
||||
" -DSIMDGROUP_WIDTH=" +
|
||||
std::to_string(backend_ctx->adreno_wave_size);
|
||||
if (has_vector_subgroup_broadcast) {
|
||||
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
|
||||
}
|
||||
@@ -732,12 +825,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err));
|
||||
|
||||
// Gemv 2048, 16384
|
||||
CL_gemv_compile_opts =
|
||||
" -cl-std=CL2.0 "
|
||||
" -cl-mad-enable "
|
||||
" -DLINE_STRIDE_A=2048 "
|
||||
" -DBLOCK_STRIDE_A=16384 "
|
||||
" -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
|
||||
CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable "
|
||||
" -DLINE_STRIDE_A=2048 "
|
||||
" -DBLOCK_STRIDE_A=16384 "
|
||||
" -DSIMDGROUP_WIDTH=" +
|
||||
std::to_string(backend_ctx->adreno_wave_size);
|
||||
if (has_vector_subgroup_broadcast) {
|
||||
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
|
||||
}
|
||||
@@ -747,12 +840,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err));
|
||||
|
||||
// Gemv 5504, 44032
|
||||
CL_gemv_compile_opts =
|
||||
" -cl-std=CL2.0 "
|
||||
" -cl-mad-enable "
|
||||
" -DLINE_STRIDE_A=5504 "
|
||||
" -DBLOCK_STRIDE_A=44032 "
|
||||
" -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
|
||||
CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable "
|
||||
" -DLINE_STRIDE_A=5504 "
|
||||
" -DBLOCK_STRIDE_A=44032 "
|
||||
" -DSIMDGROUP_WIDTH=" +
|
||||
std::to_string(backend_ctx->adreno_wave_size);
|
||||
if (has_vector_subgroup_broadcast) {
|
||||
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
|
||||
}
|
||||
@@ -762,12 +855,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err));
|
||||
|
||||
// Gemv 16000, 128000
|
||||
CL_gemv_compile_opts =
|
||||
" -cl-std=CL2.0 "
|
||||
" -cl-mad-enable "
|
||||
" -DLINE_STRIDE_A=16000 "
|
||||
" -DBLOCK_STRIDE_A=128000 "
|
||||
" -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
|
||||
CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable "
|
||||
" -DLINE_STRIDE_A=16000 "
|
||||
" -DBLOCK_STRIDE_A=128000 "
|
||||
" -DSIMDGROUP_WIDTH=" +
|
||||
std::to_string(backend_ctx->adreno_wave_size);
|
||||
if (has_vector_subgroup_broadcast) {
|
||||
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
|
||||
}
|
||||
@@ -1004,17 +1097,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_MUL:
|
||||
return true;
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
@@ -1198,17 +1292,14 @@ struct ggml_backend_opencl_buffer_context {
|
||||
std::string name;
|
||||
};
|
||||
|
||||
static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000;
|
||||
|
||||
static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
return cl_ptr_base;
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer->buft->device);
|
||||
return (void *) (uintptr_t) backend_ctx->alignment;
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
||||
@@ -1241,7 +1332,7 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff
|
||||
tensor->extra = view_extra;
|
||||
} else {
|
||||
{
|
||||
size_t offset = (char *)tensor->data - (char *)cl_ptr_base;
|
||||
size_t offset = (char *) tensor->data - (char *) ggml_backend_opencl_buffer_get_base(buffer);
|
||||
|
||||
ggml_tensor_extra_cl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra();
|
||||
extra->offset = offset;
|
||||
@@ -2573,26 +2664,33 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
const int ne00 = src0 ? src0->ne[0] : 0;
|
||||
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
|
||||
const int ne01 = src0 ? src0->ne[1] : 0;
|
||||
const int ne02 = src0 ? src0->ne[2] : 0;
|
||||
const int ne03 = src0 ? src0->ne[3] : 0;
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
|
||||
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
|
||||
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
|
||||
|
||||
const int nth = MIN(64, ne00);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_norm;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL));
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL));
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
#ifdef GGML_OPENCL_PROFILING
|
||||
@@ -2630,16 +2728,19 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
const int ne00 = src0 ? src0->ne[0] : 0;
|
||||
const int ne01 = src0 ? src0->ne[1] : 0;
|
||||
const int ne02 = src0 ? src0->ne[2] : 0;
|
||||
const int ne03 = src0 ? src0->ne[3] : 0;
|
||||
|
||||
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
|
||||
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
|
||||
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
|
||||
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
|
||||
const int nth = MIN(64, ne00);
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_rms_norm;
|
||||
@@ -2654,15 +2755,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
sizeof(local_work_size), local_work_size,
|
||||
sizeof(size_t), &sgs, NULL));
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
|
||||
// This is local memory - the size depends on subgroup size.
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs, NULL));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL));
|
||||
|
||||
#ifdef GGML_OPENCL_PROFILING
|
||||
cl_event evt;
|
||||
@@ -3023,6 +3129,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
// enqueue kernel with profiling
|
||||
// <--------------------------------------------> //
|
||||
#ifdef GGML_OPENCL_PROFILING
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
|
||||
g_profiling_info.emplace_back();
|
||||
@@ -3764,10 +3871,10 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
const int ne02 = src0 ? src0->ne[2] : 0;
|
||||
const int ne03 = src0 ? src0->ne[3] : 0;
|
||||
|
||||
const int nb00 = src0 ? src0->nb[0] : 0;
|
||||
const int nb01 = src0 ? src0->nb[1] : 0;
|
||||
const int nb02 = src0 ? src0->nb[2] : 0;
|
||||
const int nb03 = src0 ? src0->nb[3] : 0;
|
||||
const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
|
||||
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
|
||||
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
|
||||
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
|
||||
|
||||
const int ne10 = src1 ? src1->ne[0] : 0;
|
||||
const int ne11 = src1 ? src1->ne[1] : 0; UNUSED(ne11);
|
||||
@@ -3779,10 +3886,10 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
const int ne2 = dst ? dst->ne[2] : 0;
|
||||
const int ne3 = dst ? dst->ne[3] : 0;
|
||||
|
||||
const int nb0 = dst ? dst->nb[0] : 0;
|
||||
const int nb1 = dst ? dst->nb[1] : 0;
|
||||
const int nb2 = dst ? dst->nb[2] : 0;
|
||||
const int nb3 = dst ? dst->nb[3] : 0;
|
||||
const cl_ulong nb0 = dst ? dst->nb[0] : 0;
|
||||
const cl_ulong nb1 = dst ? dst->nb[1] : 0;
|
||||
const cl_ulong nb2 = dst ? dst->nb[2] : 0;
|
||||
const cl_ulong nb3 = dst ? dst->nb[3] : 0;
|
||||
|
||||
GGML_ASSERT(ne10 % ne02 == 0);
|
||||
GGML_ASSERT(ne10 >= ne02);
|
||||
|
||||
@@ -506,14 +506,23 @@ kernel void kernel_norm(
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
float eps,
|
||||
local float * sum
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
dst = (global void*)((global char*)dst + offsetd);
|
||||
|
||||
global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01);
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
|
||||
// MEAN
|
||||
// parallel sum
|
||||
@@ -533,7 +542,7 @@ kernel void kernel_norm(
|
||||
|
||||
// recenter and VARIANCE
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
global float * y = dst + get_group_id(0)*ne00;
|
||||
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
y[i00] = x[i00] - mean;
|
||||
@@ -566,14 +575,23 @@ kernel void kernel_rms_norm(
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
float eps,
|
||||
local float * sum // Note, the size depends on number of subgroups
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01);
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
global float * x_scalar = (global float *) x;
|
||||
float4 sumf = 0;
|
||||
float all_sum = 0;
|
||||
@@ -607,7 +625,7 @@ kernel void kernel_rms_norm(
|
||||
const float mean = sum[0];
|
||||
const float scale = 1.0f/sqrt(mean + eps);
|
||||
|
||||
global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00);
|
||||
global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
global float * y_scalar = (global float *) y;
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
y[i00] = x[i00] * scale;
|
||||
|
||||
@@ -41,6 +41,7 @@
|
||||
#include "ggml-sycl/gemm.hpp"
|
||||
#include "ggml-sycl/sycl_hw.hpp"
|
||||
#include "ggml-sycl/getrows.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
static bool g_sycl_loaded = false;
|
||||
int g_ggml_sycl_debug = 0;
|
||||
@@ -3864,7 +3865,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -3981,23 +3982,24 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
return true;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
return true;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_LOG:
|
||||
return (op->src[0]->type == GGML_TYPE_F32);
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SCALE:
|
||||
return true;
|
||||
case GGML_OP_CONT:
|
||||
return op->src[0]->type != GGML_TYPE_BF16;
|
||||
|
||||
@@ -5,23 +5,24 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
|
||||
shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
|
||||
shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
|
||||
shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
uint csel = 0;
|
||||
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
|
||||
const uint y_idx = i * QUANT_K + y_offset;
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
csel ^= 1;
|
||||
|
||||
barrier();
|
||||
if (!all_threads) { // when we don't have enough blocks to use all threads
|
||||
if (i < num_blocks_per_row) {
|
||||
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
|
||||
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
|
||||
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
|
||||
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
|
||||
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
|
||||
}
|
||||
barrier();
|
||||
|
||||
@@ -29,8 +30,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
continue;
|
||||
} else {
|
||||
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
|
||||
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
|
||||
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
|
||||
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
|
||||
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
|
||||
barrier();
|
||||
}
|
||||
|
||||
@@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
|
||||
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int l = 0; l < 2; ++l) {
|
||||
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ],
|
||||
fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
|
||||
fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ],
|
||||
fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
|
||||
fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ],
|
||||
fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
|
||||
fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ],
|
||||
fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
|
||||
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im],
|
||||
fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
|
||||
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ],
|
||||
fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
|
||||
fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ],
|
||||
fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
|
||||
fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ],
|
||||
fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
|
||||
fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ],
|
||||
fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
|
||||
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im],
|
||||
fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
|
||||
}
|
||||
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
|
||||
}
|
||||
|
||||
@@ -5,20 +5,21 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
|
||||
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8];
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
uint csel = 0;
|
||||
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
|
||||
const uint y_idx = i * QUANT_K + y_offset;
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
csel ^= 1;
|
||||
|
||||
if (!all_threads) { // when we don't have enough blocks to use all threads
|
||||
barrier();
|
||||
if (i < num_blocks_per_row)
|
||||
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
|
||||
sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
|
||||
barrier();
|
||||
|
||||
if (i >= num_blocks_per_row)
|
||||
@@ -40,8 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
|
||||
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
||||
|
||||
if (all_threads) {
|
||||
barrier();
|
||||
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
|
||||
sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
|
||||
barrier();
|
||||
}
|
||||
|
||||
@@ -59,14 +59,14 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int l = 0; l < 2; ++l) {
|
||||
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
|
||||
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
|
||||
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
|
||||
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
|
||||
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
|
||||
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
|
||||
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
|
||||
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
|
||||
sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
|
||||
fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
|
||||
fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
|
||||
fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
|
||||
fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
|
||||
fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
|
||||
fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
|
||||
fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
|
||||
}
|
||||
temp[j][n] = fma(d, sum, temp[j][n]);
|
||||
}
|
||||
|
||||
@@ -6,20 +6,21 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
|
||||
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16];
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
uint csel = 0;
|
||||
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
|
||||
const uint y_idx = i * QUANT_K + y_offset;
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
csel ^= 1;
|
||||
|
||||
if (!all_threads) { // when we don't have enough blocks to use all threads
|
||||
barrier();
|
||||
if (i < num_blocks_per_row)
|
||||
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
|
||||
sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
|
||||
barrier();
|
||||
|
||||
if (i >= num_blocks_per_row)
|
||||
@@ -51,8 +52,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
|
||||
|
||||
if (all_threads) {
|
||||
barrier();
|
||||
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
|
||||
sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
|
||||
barrier();
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
|
||||
sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
|
||||
}
|
||||
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
|
||||
temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2332,6 +2332,7 @@ struct ggml_tensor * ggml_concat(
|
||||
struct ggml_tensor * b,
|
||||
int dim) {
|
||||
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
||||
GGML_ASSERT(a->type == b->type);
|
||||
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
for (int d = 0; d < GGML_MAX_DIMS; ++d) {
|
||||
|
||||
+17
-5
@@ -1205,17 +1205,29 @@ extern "C" {
|
||||
const char * grammar_str,
|
||||
const char * grammar_root);
|
||||
|
||||
/// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
|
||||
/// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future.
|
||||
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
|
||||
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
|
||||
size_t num_trigger_tokens),
|
||||
"use llama_sampler_init_grammar_lazy_patterns instead");
|
||||
|
||||
|
||||
/// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
|
||||
/// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group.
|
||||
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
const char ** trigger_patterns,
|
||||
size_t num_trigger_patterns,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
|
||||
|
||||
|
||||
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
|
||||
|
||||
@@ -9,7 +9,7 @@ These templates can be updated with the following commands:
|
||||
./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Qwen-32B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja
|
||||
./scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 > models/templates/fireworks-ai-llama-3-firefunction-v2.jinja
|
||||
./scripts/get_chat_template.py google/gemma-2-2b-it > models/templates/google-gemma-2-2b-it.jinja
|
||||
./scripts/get_chat_template.py meetkai/functionary-medium-v3. > models/templates/meetkai-functionary-medium-v3.jinja
|
||||
./scripts/get_chat_template.py meetkai/functionary-medium-v3.1 > models/templates/meetkai-functionary-medium-v3.1.jinja
|
||||
./scripts/get_chat_template.py meetkai/functionary-medium-v3.2 > models/templates/meetkai-functionary-medium-v3.2.jinja
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.1-8B-Instruct > models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja
|
||||
./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct > models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja
|
||||
|
||||
@@ -10,3 +10,4 @@
|
||||
-r ./requirements/requirements-convert_hf_to_gguf_update.txt
|
||||
-r ./requirements/requirements-convert_llama_ggml_to_gguf.txt
|
||||
-r ./requirements/requirements-convert_lora_to_gguf.txt
|
||||
-r ./requirements/requirements-tool_bench.txt
|
||||
|
||||
@@ -10,3 +10,4 @@
|
||||
-r ./requirements-convert_hf_to_gguf_update.txt
|
||||
-r ./requirements-convert_legacy_llama.txt
|
||||
-r ./requirements-convert_llama_ggml_to_gguf.txt
|
||||
-r ./requirements-tool_bench.txt
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
aiohttp~=3.9.3
|
||||
pytest~=8.3.3
|
||||
huggingface_hub~=0.23.2
|
||||
matplotlib~=3.10.0
|
||||
numpy~=1.26.4
|
||||
openai~=1.55.3
|
||||
pandas~=2.2.3
|
||||
prometheus-client~=0.20.0
|
||||
requests~=2.32.3
|
||||
wget~=3.2
|
||||
typer~=0.15.1
|
||||
seaborn~=0.13.2
|
||||
@@ -75,7 +75,7 @@ if __name__ == '__main__':
|
||||
logging.info(f' - {m.hf_repo} / {m.hf_file}')
|
||||
|
||||
cli_path = os.environ.get(
|
||||
'LLAMA_SERVER_BIN_PATH',
|
||||
'LLAMA_CLI_BIN_PATH',
|
||||
os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
|
||||
|
||||
@@ -1 +1 @@
|
||||
58ecf6b96d887e408b6869915863fa1126483d51
|
||||
c7dfe3d174f98b14801f9ed12f129179d3e7b638
|
||||
|
||||
Executable
+368
@@ -0,0 +1,368 @@
|
||||
#!/usr/bin/env uv run
|
||||
'''
|
||||
Simplistic tool call benchmarks for llama-server and ollama.
|
||||
|
||||
Essentially runs the tests at server/examples/server/tests/unit/test_tool_call.py N times, at different temperatures and on different backends (current llama-server, baseline llama-server and ollama),
|
||||
and plots the results of multiple runs (from same .jsonl file or multiple ones) as a success rate heatmap.
|
||||
|
||||
Simple usage example:
|
||||
|
||||
cmake -B build -DLLAMA_CURL=1 && cmake --build build --config Release -j -t llama-server
|
||||
|
||||
export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
|
||||
export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
|
||||
|
||||
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b
|
||||
|
||||
./scripts/tool_bench.py plot *.jsonl # Opens window w/ heatmap
|
||||
./scripts/tool_bench.py plot qwen*.jsonl --output qwen.png # Saves heatmap to qwen.png
|
||||
|
||||
(please see ./scripts/tool_bench.sh for a more complete example)
|
||||
'''
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "pytest",
|
||||
# "pandas",
|
||||
# "matplotlib",
|
||||
# "seaborn",
|
||||
# "requests",
|
||||
# "wget",
|
||||
# "typer",
|
||||
# ]
|
||||
# ///
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
import re
|
||||
from statistics import mean, median
|
||||
from typing import Annotated, Dict, List, Optional, Tuple
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import typer
|
||||
|
||||
sys.path.insert(0, Path(__file__).parent.parent.as_posix())
|
||||
if True:
|
||||
from examples.server.tests.utils import ServerProcess
|
||||
from examples.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
|
||||
|
||||
|
||||
@contextmanager
|
||||
def scoped_server(sp: ServerProcess):
|
||||
def stop():
|
||||
nonlocal sp
|
||||
if sp is not None:
|
||||
sp.stop()
|
||||
sp = None # type: ignore
|
||||
atexit.register(stop)
|
||||
yield sp
|
||||
stop()
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, server_regex: Optional[str] = None):
|
||||
|
||||
lines: List[Dict] = []
|
||||
for file in files:
|
||||
if not file.exists():
|
||||
logger.error(f"File not found: {file}")
|
||||
continue
|
||||
|
||||
try:
|
||||
with file.open() as f:
|
||||
raw_data = f.read()
|
||||
logger.info(f"Reading {file} ({len(raw_data)} bytes)")
|
||||
|
||||
for line_num, line in enumerate(raw_data.split('\n'), 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
record = json.loads(line)
|
||||
lines.append(record)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Invalid JSON at {file}:{line_num} - {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file}: {e}")
|
||||
|
||||
if not lines:
|
||||
raise Exception("No valid data was loaded")
|
||||
|
||||
data_dict: Dict[Tuple, float] = {}
|
||||
models: List[str] = []
|
||||
temps = set()
|
||||
tests = set()
|
||||
server_names = set()
|
||||
total_counts = set()
|
||||
for rec in lines:
|
||||
try:
|
||||
model = rec["model"]
|
||||
temp = rec["temp"]
|
||||
server_name = rec["server_name"]
|
||||
test = rec["test"]
|
||||
success = rec["success_ratio"]
|
||||
success_count = rec["success_count"]
|
||||
failure_count = rec["failure_count"]
|
||||
total_count = success_count + failure_count
|
||||
total_counts.add(total_count)
|
||||
|
||||
if test_regex and not re.search(test_regex, test):
|
||||
continue
|
||||
|
||||
if server_regex and not re.search(server_regex, server_name):
|
||||
continue
|
||||
|
||||
data_dict[(model, temp, server_name, test)] = success
|
||||
|
||||
if model not in models:
|
||||
models.append(model)
|
||||
temps.add(temp)
|
||||
tests.add(test)
|
||||
server_names.add(server_name)
|
||||
|
||||
except KeyError as e:
|
||||
logger.warning(f"Missing required field in record: {e}")
|
||||
|
||||
if len(total_counts) > 1:
|
||||
logger.warning(f"Total counts are not consistent: {total_counts}")
|
||||
|
||||
# Sort the collected values
|
||||
temps = list(sorted(temps, key=lambda x: x if x is not None else -1))
|
||||
tests = list(sorted(tests))
|
||||
server_names = list(sorted(server_names))
|
||||
|
||||
logger.info(f"Processed {len(lines)} lines")
|
||||
logger.info(f"Found {len(data_dict)} valid data points")
|
||||
logger.info(f"Models: {models}")
|
||||
logger.info(f"Temperatures: {temps}")
|
||||
logger.info(f"Tests: {tests}")
|
||||
logger.info(f"Servers: {server_names}")
|
||||
|
||||
matrix: list[list[float]] = []
|
||||
index: list[str] = []
|
||||
|
||||
all_cols = [
|
||||
(server_name, test)
|
||||
for server_name in server_names
|
||||
for test in tests
|
||||
]
|
||||
for model in models:
|
||||
for temp in temps:
|
||||
index.append(f"{model} @ {temp}")
|
||||
row_vals = [
|
||||
data_dict.get((model, temp, server_name, test), np.nan)
|
||||
for server_name, test in all_cols
|
||||
]
|
||||
matrix.append(row_vals)
|
||||
|
||||
columns: list[str] = [f"{server_name}\n{test}" for server_name, test in all_cols]
|
||||
|
||||
df = pd.DataFrame(matrix, index=np.array(index), columns=np.array(columns))
|
||||
|
||||
plt.figure(figsize=(12, 6))
|
||||
|
||||
sns.heatmap(
|
||||
df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5,
|
||||
cbar_kws={"label": "Success Ratio"},
|
||||
)
|
||||
|
||||
plt.title(f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Server & Test", pad=20)
|
||||
plt.xlabel("Server & Test", labelpad=10)
|
||||
plt.ylabel("Model @ Temperature", labelpad=10)
|
||||
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
plt.yticks(rotation=0)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if output:
|
||||
plt.savefig(output, dpi=300, bbox_inches='tight')
|
||||
logger.info(f"Plot saved to {output}")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
@app.command()
|
||||
def run(
|
||||
output: Annotated[Path, typer.Option(help="Output JSON file")],
|
||||
model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
|
||||
hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
|
||||
chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
|
||||
ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
|
||||
llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
|
||||
n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
|
||||
temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test")] = None,
|
||||
top_p: Annotated[Optional[float], typer.Option(help="top_p")] = None,
|
||||
top_k: Annotated[Optional[int], typer.Option(help="top_k")] = None,
|
||||
ctk: Annotated[Optional[str], typer.Option(help="ctk")] = None,
|
||||
ctv: Annotated[Optional[str], typer.Option(help="ctv")] = None,
|
||||
fa: Annotated[Optional[bool], typer.Option(help="fa")] = None,
|
||||
seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None,
|
||||
port: Annotated[int, typer.Option(help="llama-server port")] = 8084,
|
||||
force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False,
|
||||
append: Annotated[bool, typer.Option(help="Append to output file")] = False,
|
||||
|
||||
test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True,
|
||||
test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True,
|
||||
test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False,
|
||||
):
|
||||
# Check only one of output and append
|
||||
|
||||
n_predict = 512 # High because of DeepSeek R1
|
||||
# n_ctx = 8192
|
||||
n_ctx = 2048
|
||||
|
||||
assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"
|
||||
|
||||
with output.open('a' if append else 'w') as output_file:
|
||||
|
||||
def run(server: ServerProcess, *, server_name: str, model_id: str, temp: Optional[float] = None, output_kwargs={}, request_kwargs={}):
|
||||
request_kwargs = {**request_kwargs}
|
||||
if temp is not None:
|
||||
request_kwargs['temperature'] = temp
|
||||
if top_p is not None:
|
||||
request_kwargs['top_p'] = top_p
|
||||
if top_k is not None:
|
||||
request_kwargs['top_k'] = top_k
|
||||
if seed is not None:
|
||||
request_kwargs['seed'] = seed
|
||||
|
||||
request_kwargs['cache_prompt'] = False
|
||||
|
||||
tests = {}
|
||||
if test_hello_world:
|
||||
tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs)
|
||||
if test_weather:
|
||||
tests["weather"] = lambda server: do_test_weather(server, **request_kwargs)
|
||||
if test_calc_result:
|
||||
tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs)
|
||||
|
||||
for test_name, test in tests.items():
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
failures = []
|
||||
success_times = []
|
||||
failure_times = []
|
||||
logger.info(f"Running {test_name} ({server_name}, {model}): ")
|
||||
for i in range(n):
|
||||
start_time = time.time()
|
||||
|
||||
def elapsed():
|
||||
return time.time() - start_time
|
||||
|
||||
try:
|
||||
test(server)
|
||||
success_times.append(elapsed())
|
||||
success_count += 1
|
||||
logger.info('success')
|
||||
except Exception as e:
|
||||
logger.error(f'failure: {e}')
|
||||
failure_count += 1
|
||||
failure_times.append(elapsed())
|
||||
failures.append(str(e))
|
||||
# import traceback
|
||||
# traceback.print_exc()
|
||||
output_file.write(json.dumps({**output_kwargs, **dict(
|
||||
model=model,
|
||||
server_name=server_name,
|
||||
model_id=model_id,
|
||||
test=test_name,
|
||||
temp=t,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
ctk=ctk,
|
||||
ctv=ctv,
|
||||
seed=seed,
|
||||
success_ratio=float(success_count) / n,
|
||||
avg_time=mean(success_times + failure_times),
|
||||
median_time=median(success_times + failure_times),
|
||||
success_count=success_count,
|
||||
success_times=success_times,
|
||||
failure_count=failure_count,
|
||||
failure_times=failure_times,
|
||||
failures=list(set(failures)),
|
||||
)}) + '\n')
|
||||
output_file.flush()
|
||||
|
||||
for t in [None] if temp is None else [t if t >= 0 else None for t in temp]:
|
||||
if hf is not None:
|
||||
|
||||
servers: list[Tuple[str, Optional[str]]] = [('llama-server', None)]
|
||||
if llama_baseline is not None:
|
||||
servers.append(('llama-server (baseline)', llama_baseline))
|
||||
|
||||
for server_name, server_path in servers:
|
||||
server = ServerProcess()
|
||||
server.n_ctx = n_ctx
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.ctk = ctk
|
||||
server.ctv = ctv
|
||||
server.fa = fa
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf
|
||||
server.model_hf_file = None
|
||||
server.chat_template = chat_template
|
||||
server.server_path = server_path
|
||||
if port is not None:
|
||||
server.server_port = port
|
||||
# server.debug = True
|
||||
|
||||
with scoped_server(server):
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
for ignore_chat_grammar in [False]:
|
||||
run(
|
||||
server,
|
||||
server_name=server_name,
|
||||
model_id=hf,
|
||||
temp=t,
|
||||
output_kwargs=dict(
|
||||
chat_template=chat_template,
|
||||
),
|
||||
request_kwargs=dict(
|
||||
ignore_chat_grammar=ignore_chat_grammar,
|
||||
),
|
||||
)
|
||||
|
||||
if ollama is not None:
|
||||
server = ServerProcess()
|
||||
server.server_port = 11434
|
||||
server.server_host = "localhost"
|
||||
subprocess.check_call(["ollama", "pull", ollama])
|
||||
|
||||
with scoped_server(server):
|
||||
run(
|
||||
server,
|
||||
server_name="ollama",
|
||||
model_id=ollama,
|
||||
temp=t,
|
||||
output_kwargs=dict(
|
||||
chat_template=None,
|
||||
),
|
||||
request_kwargs=dict(
|
||||
model=ollama,
|
||||
max_tokens=n_predict,
|
||||
num_ctx = n_ctx,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
Executable
+66
@@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
cmake --build build -j
|
||||
|
||||
export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
|
||||
export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
|
||||
|
||||
if [ ! -x "$LLAMA_SERVER_BIN_PATH" ]; then
|
||||
echo "Could not find llama-server binary at $LLAMA_SERVER_BIN_PATH"
|
||||
exit 1
|
||||
fi
|
||||
if [ ! -d "$LLAMA_CACHE" ]; then
|
||||
echo "Could not find llama cache at $LLAMA_CACHE, please set LLAMA_CACHE explicitly."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export ARGS=(
|
||||
--llama-baseline="$(which llama-server)"
|
||||
--n 30
|
||||
--temp -1 # Leaves temperature parameter unset (use the server's default, e.g. 0.6 for ollama)
|
||||
--temp 0
|
||||
--temp 0.5
|
||||
--temp 0.75
|
||||
--temp 1
|
||||
--temp 1.5
|
||||
--temp 2
|
||||
--temp 5
|
||||
"$@"
|
||||
)
|
||||
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 0.5B Q4_K_M" --output ../qwenc0.5b.jsonl --hf bartowski/Qwen2.5-Coder-0.5B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:0.5b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 1.5B Q4_K_M" --output ../qwenc1.5b.jsonl --hf bartowski/Qwen2.5-Coder-1.5B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:1.5b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 3B Q4_K_M" --output ../qwenc3b.jsonl --hf bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:3b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --output ../qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:7b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 32B Q4_K_M" --output ../qwenc32b.jsonl --hf bartowski/Qwen2.5-Coder-32B-Instruct-GGUF:Q4_K_M --ollama qwen2.5-coder:32B-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --output ../qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M --ollama qwen2.5:1.5b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 3B Q4_K_M" --output ../qwen3b.jsonl --hf bartowski/Qwen2.5-3B-Instruct-GGUF:Q4_K_M --ollama qwen2.5:3b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --output ../qwen7b.jsonl --hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M --ollama qwen2.5:7b-instruct-q4_K_M
|
||||
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --output ../llama1b.jsonl --hf bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M --ollama llama3.2:1b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --output ../llama3b.jsonl --hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M --ollama llama3.2:3b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --output ../llama8b.jsonl --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M --ollama llama3.1:8b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.3 70B Q4_K_M" --output ../llama70b.jsonl --hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
|
||||
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Mistral Nemo Q4_K_M" --output ../nemo.jsonl --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M --ollama mistral-nemo:12b-instruct-2407-q4_K_M
|
||||
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Hermes 3 Llama 3.1 8B Q4_K_M" --output ../hermes3.jsonl --hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M --ollama hermes3:8b-llama3.1-q4_K_M --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Hermes 2 Pro Llama 3 8B Q4_K_M" --output ../hermes2.jsonl --hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M --ollama hermes2:8b-llama3-q4_K_M --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
|
||||
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Functionary Small V3.2 Q4_K_M" --output ../funct3.2.jsonl --hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "FireFunction V2 IQ1_M" --output ../firef2.jsonl --hf bartowski/firefunction-v2-GGUF:IQ1_M --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
|
||||
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Command R7B 12-2024 Q6_K_L" --output ../c4ai.jsonl --hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L --chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
|
||||
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Gemma 2 2B Q8_0" --output ../gemma2.jsonl --hf bartowski/gemma-2-2b-it-GGUF:Q8_0
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Phi 4 Instruct Q4_K_M" --output ../phi4.jsonl --hf bartowski/phi-4-GGUF:Q4_K_M # --ollama phi4
|
||||
./scripts/tool_bench.py run ${ARGS[@]} --model "Phi 3.5 Mini Instruct Q4_K_M" --output ../phi3.5.jsonl --hf bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M # --ollama phi3.5:3.8b-mini-instruct-q4_K_M
|
||||
|
||||
# ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 7B Q6_K_L" --output ../dsqw7.jsonl --hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L --chat-template-file <( python scripts/get_chat_template.py NousResearch/DeepSeek-R1-Distill-Qwen-7B tool_use )
|
||||
# ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 32B Q4_K_M" --output ../dsqw32.jsonl --hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M --chat-template-file <( python scripts/get_chat_template.py NousResearch/DeepSeek-R1-Distill-Qwen-32B tool_use )
|
||||
|
||||
|
||||
for f in ../*.jsonl; do
|
||||
./scripts/tool_bench.py plot "$f" --output ${f%.jsonl}.png || true
|
||||
done
|
||||
+22
-22
@@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
/* .awaiting_trigger = */ false,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_tokens = */ {},
|
||||
/* .trigger_words = */ {},
|
||||
/* .trigger_patterns = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -978,19 +978,15 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const char ** trigger_patterns,
|
||||
size_t num_trigger_patterns,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens) {
|
||||
llama_grammar_parser parser;
|
||||
|
||||
// if there is a grammar, parse it
|
||||
if (!parser.parse(grammar_str)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (parser.rules.empty()) {
|
||||
// rules will be empty (default) if there are parse errors
|
||||
if (!parser.parse(grammar_str) || parser.rules.empty()) {
|
||||
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -1054,14 +1050,16 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
} while (true);
|
||||
|
||||
std::vector<llama_token> vec_trigger_tokens;
|
||||
std::vector<std::string> vec_trigger_words;
|
||||
std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
|
||||
for (size_t i = 0; i < num_trigger_tokens; i++) {
|
||||
GGML_ASSERT(trigger_tokens != nullptr);
|
||||
vec_trigger_tokens.push_back(trigger_tokens[i]);
|
||||
}
|
||||
for (size_t i = 0; i < num_trigger_words; i++) {
|
||||
GGML_ASSERT(trigger_words != nullptr);
|
||||
vec_trigger_words.push_back(trigger_words[i]);
|
||||
for (size_t i = 0; i < num_trigger_patterns; i++) {
|
||||
GGML_ASSERT(trigger_patterns != nullptr);
|
||||
auto & trigger = vec_trigger_patterns.emplace_back();
|
||||
trigger.pattern = trigger_patterns[i];
|
||||
trigger.regex = std::regex(trigger.pattern);
|
||||
}
|
||||
|
||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||
@@ -1076,7 +1074,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
/* .awaiting_trigger = */ lazy,
|
||||
/* .trigger_buffer = */ "",
|
||||
std::move(vec_trigger_tokens),
|
||||
std::move(vec_trigger_words),
|
||||
std::move(vec_trigger_patterns),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1089,7 +1087,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
||||
}
|
||||
|
||||
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
||||
llama_grammar * result = new llama_grammar {
|
||||
auto * result = new llama_grammar {
|
||||
grammar.vocab,
|
||||
grammar.rules,
|
||||
grammar.stacks,
|
||||
@@ -1098,7 +1096,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
||||
grammar.awaiting_trigger,
|
||||
grammar.trigger_buffer,
|
||||
grammar.trigger_tokens,
|
||||
grammar.trigger_words,
|
||||
grammar.trigger_patterns,
|
||||
};
|
||||
|
||||
// redirect elements in stacks to point to new rules
|
||||
@@ -1173,16 +1171,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
|
||||
return;
|
||||
} else {
|
||||
// TODO: consider a smarter incremental substring search algorithm (store last position to search from).
|
||||
grammar.trigger_buffer += piece;
|
||||
for (const auto & word : grammar.trigger_words) {
|
||||
auto pos = grammar.trigger_buffer.find(word);
|
||||
if (pos != std::string::npos) {
|
||||
|
||||
std::smatch match;
|
||||
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
||||
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
||||
grammar.awaiting_trigger = false;
|
||||
auto constrained_str = grammar.trigger_buffer.substr(pos);
|
||||
// get from the first match to the end of the string
|
||||
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
|
||||
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, constrained_str);
|
||||
LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
|
||||
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
+12
-3
@@ -3,6 +3,7 @@
|
||||
#include "llama.h"
|
||||
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -105,6 +106,11 @@ struct llama_grammar_parser {
|
||||
void print(FILE * file);
|
||||
};
|
||||
|
||||
struct llama_grammar_trigger_pattern {
|
||||
std::string pattern;
|
||||
std::regex regex;
|
||||
};
|
||||
|
||||
struct llama_grammar {
|
||||
// note: allow null vocab for testing (not great)
|
||||
const llama_vocab * vocab;
|
||||
@@ -122,7 +128,10 @@ struct llama_grammar {
|
||||
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
||||
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
|
||||
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
|
||||
std::vector<std::string> trigger_words;
|
||||
std::vector<llama_grammar_trigger_pattern>
|
||||
trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
|
||||
// string, and the grammar will be given the string from the first match group onwards.
|
||||
|
||||
};
|
||||
|
||||
//
|
||||
@@ -141,8 +150,8 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const char ** trigger_patterns,
|
||||
size_t num_trigger_patterns,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
|
||||
|
||||
|
||||
+43
-10
@@ -1449,7 +1449,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
|
||||
size_t num_trigger_tokens,
|
||||
const char ** trigger_patterns,
|
||||
size_t num_trigger_patterns);
|
||||
|
||||
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
||||
@@ -1457,12 +1459,14 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<const char *> trigger_words;
|
||||
for (auto & word : ctx->grammar->trigger_words) {
|
||||
trigger_words.push_back(word.c_str());
|
||||
std::vector<const char *> trigger_patterns_c;
|
||||
trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
|
||||
for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
|
||||
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
|
||||
}
|
||||
|
||||
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
||||
ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
|
||||
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
|
||||
|
||||
llama_grammar_free_impl(ctx->grammar);
|
||||
@@ -1472,7 +1476,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
|
||||
|
||||
auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
|
||||
auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||
|
||||
// copy the state
|
||||
{
|
||||
@@ -1516,15 +1520,33 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens) {
|
||||
size_t num_trigger_tokens,
|
||||
const char ** trigger_patterns,
|
||||
size_t num_trigger_patterns) {
|
||||
auto * ctx = new llama_sampler_grammar;
|
||||
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
// TODO: remove trigger_words support.
|
||||
if (trigger_words != nullptr && num_trigger_words > 0) {
|
||||
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
|
||||
std::string trigger_pattern("[\\s\\S]*?(");
|
||||
for (size_t i = 0; i < num_trigger_words; ++i) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
if (i > 0) {
|
||||
trigger_pattern += "|";
|
||||
}
|
||||
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
|
||||
}
|
||||
trigger_pattern += ")[\\s\\S]*";
|
||||
auto trigger_pattern_c = trigger_pattern.c_str();
|
||||
trigger_patterns = &trigger_pattern_c;
|
||||
num_trigger_patterns = 1;
|
||||
}
|
||||
*ctx = {
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_str = */ grammar_str,
|
||||
/* .grammar_root = */ grammar_root,
|
||||
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
|
||||
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||
};
|
||||
} else {
|
||||
*ctx = {
|
||||
@@ -1545,7 +1567,7 @@ struct llama_sampler * llama_sampler_init_grammar(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root) {
|
||||
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
|
||||
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_grammar_lazy(
|
||||
@@ -1556,7 +1578,18 @@ struct llama_sampler * llama_sampler_init_grammar_lazy(
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens) {
|
||||
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
|
||||
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
const char ** trigger_patterns,
|
||||
size_t num_trigger_patterns,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens) {
|
||||
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
|
||||
}
|
||||
|
||||
// penalties
|
||||
|
||||
+150
-11
@@ -237,12 +237,35 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
|
||||
auto earliest_trigger_pos = std::string::npos;
|
||||
auto constrained = data.delta;
|
||||
for (const auto & trigger : data.params.grammar_triggers) {
|
||||
auto pos = constrained.find(trigger.word);
|
||||
if (pos == std::string::npos) {
|
||||
continue;
|
||||
size_t pos = std::string::npos;
|
||||
std::smatch match;
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
pos = constrained.find(word);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
if (std::regex_search(constrained, match, std::regex(pattern))) {
|
||||
pos = match.position();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
|
||||
pos = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unknown trigger type");
|
||||
}
|
||||
if (pos > 0 && trigger.at_start) {
|
||||
fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
|
||||
if (pos == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
||||
@@ -260,7 +283,8 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
|
||||
|
||||
if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
||||
"\n\nGrammar: " + data.params.grammar);
|
||||
"\n\nConstrained: " + constrained +
|
||||
"\n\nGrammar: " + data.params.grammar);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -456,6 +480,21 @@ static void test_msgs_oaicompat_json_conversion() {
|
||||
"]"
|
||||
),
|
||||
common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
|
||||
|
||||
auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]"));
|
||||
assert_equals<size_t>(1, res.size());
|
||||
assert_equals<std::string>(res[0].role, "assistant");
|
||||
assert_equals(true, res[0].content.empty());
|
||||
assert_equals(true, res[0].tool_calls.empty());
|
||||
|
||||
try {
|
||||
common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\"}]"));
|
||||
throw std::runtime_error("Expected exception");
|
||||
} catch (const std::exception & e) {
|
||||
if (std::string(e.what()).find("'content'") == std::string::npos) {
|
||||
throw std::runtime_error("Expected exception about missing 'content'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void test_tools_oaicompat_json_conversion() {
|
||||
@@ -640,6 +679,106 @@ static void test_template_output_parsers() {
|
||||
inputs_tools)
|
||||
.format);
|
||||
|
||||
// Test parsing
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<function=special_function>{\"arg1\": 1}</function>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<function name=\"special_function\">\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"</function>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<tool>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<tools>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tools>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<response>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</response>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```xml\n"
|
||||
"<response>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</response>\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```xml\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```json\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```json\n"
|
||||
"\n"
|
||||
" <function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
|
||||
" </function_call> \n"
|
||||
"``` ",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<json>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</json>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<xml>\n"
|
||||
" {\n"
|
||||
" \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
|
||||
" }\n"
|
||||
"</xml>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<JSON>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</JSON>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
|
||||
assert_msg_equals(message_assist_thoughts_unparsed_think,
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_thoughts_unparsed_think,
|
||||
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_thoughts,
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
|
||||
assert_msg_equals(message_assist_thoughts,
|
||||
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
|
||||
|
||||
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
||||
"<tool_call>\n"
|
||||
@@ -789,7 +928,7 @@ static void test_template_output_parsers() {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
try {
|
||||
// try {
|
||||
#ifndef _WIN32
|
||||
if (argc > 1) {
|
||||
common_chat_templates_inputs inputs;
|
||||
@@ -827,8 +966,8 @@ int main(int argc, char ** argv) {
|
||||
std::cout << "\n[chat] All tests passed!" << '\n';
|
||||
}
|
||||
return 0;
|
||||
} catch (const std::exception & e) {
|
||||
std::cerr << "Error: " << e.what() << '\n';
|
||||
return 1;
|
||||
}
|
||||
// } catch (const std::exception & e) {
|
||||
// std::cerr << "Error: " << e.what() << '\n';
|
||||
// return 1;
|
||||
// }
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([0] | [1-9] [0-9]{0,15}) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -104,7 +104,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1-9] [0-9]{0,15}) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -117,7 +117,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15}) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -130,7 +130,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15}) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -143,7 +143,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15}) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -156,7 +156,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15}) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -169,7 +169,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0")) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -182,7 +182,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15}) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -195,7 +195,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15}) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -208,7 +208,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15})) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -221,7 +221,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-1]) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -234,7 +234,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100") space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -248,7 +248,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([0-9] | ([1] [0-9] | [2] [0-3])) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -262,7 +262,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00")) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -276,7 +276,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([5-9] | ([1-2] [0-9] | [3] "0")) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -290,7 +290,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2])) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -304,7 +304,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-9] | "10") | [0-9] | "10") space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -340,7 +340,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
@@ -363,7 +363,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
date-time ::= date "T" time
|
||||
date-time-string ::= "\"" date-time "\"" space
|
||||
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
|
||||
time-string ::= "\"" time "\"" space
|
||||
tuple-0 ::= date-string
|
||||
@@ -382,7 +382,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char* "\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -396,7 +396,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char+ "\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -410,7 +410,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{3,} "\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -424,7 +424,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{0,3} "\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -439,7 +439,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{1,4} "\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -451,7 +451,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("true" | "false") space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -464,7 +464,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -476,7 +476,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"foo\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -488,7 +488,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "123" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -500,7 +500,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]") space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -514,7 +514,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "[" space (string ("," space string)*)? "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -531,7 +531,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
null ::= "null" space
|
||||
root ::= alternative-0 | null
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -545,7 +545,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "[" space string "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -562,7 +562,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "[" space string "," space number "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -577,7 +577,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -593,7 +593,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space boolean ("," space boolean)+ "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -609,7 +609,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space boolean? "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -625,7 +625,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space (boolean ("," space boolean)?)? "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -646,7 +646,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
item ::= number | integer
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -665,7 +665,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -684,7 +684,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -697,7 +697,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -710,7 +710,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("[]{}()|+*?") "\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -723,7 +723,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("\"") "\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -736,7 +736,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -751,7 +751,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
dot ::= [^\x0A\x0D]
|
||||
root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
|
||||
root-1 ::= [0-9]
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -779,7 +779,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -799,7 +799,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space (a-kv )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -823,7 +823,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -849,7 +849,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
d-kv ::= "\"d\"" space ":" space string
|
||||
d-rest ::= ( "," space c-kv )?
|
||||
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -869,7 +869,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (additional-kv ( "," space additional-kv )* )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -891,7 +891,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
@@ -913,7 +913,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
@@ -928,7 +928,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "{" space "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -952,7 +952,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -977,7 +977,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1004,7 +1004,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1030,7 +1030,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) space
|
||||
root0 ::= "{" space (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1055,7 +1055,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integer ::= ("-"? integral-part) space
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= "{" space (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1080,7 +1080,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integer ::= ("-"? integral-part) space
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= "{" space (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1109,7 +1109,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
foo ::= "{" space foo-a-kv "}" space
|
||||
foo-a-kv ::= "\"a\"" space ":" space string
|
||||
root ::= foo
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
)"""
|
||||
});
|
||||
@@ -1143,7 +1143,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= alternative-0 | alternative-1
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1187,7 +1187,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1235,7 +1235,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
number-number-kv ::= "\"number\"" space ":" space number-number
|
||||
number-number-root-kv ::= "\"root\"" space ":" space number
|
||||
root ::= "{" space number-kv "}" space
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
}
|
||||
|
||||
@@ -120,13 +120,7 @@ int main(int argc, char * argv[]) {
|
||||
generate_data(0.0, test_data.size(), test_data.data());
|
||||
generate_data(1.0, test_data2.size(), test_data2.data());
|
||||
|
||||
// Initialize GGML, ensures float conversion tables are initialized
|
||||
struct ggml_init_params ggml_params = {
|
||||
/* .mem_size = */ 1*1024,
|
||||
/* .mem_buffer = */ NULL,
|
||||
/* .no_alloc = */ true,
|
||||
};
|
||||
struct ggml_context * ctx = ggml_init(ggml_params);
|
||||
ggml_cpu_init();
|
||||
|
||||
int num_failed = 0;
|
||||
bool failed = false;
|
||||
@@ -188,7 +182,5 @@ int main(int argc, char * argv[]) {
|
||||
printf("%d tests failed\n", num_failed);
|
||||
}
|
||||
|
||||
ggml_free(ctx);
|
||||
|
||||
return num_failed > 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user