mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 16:17:40 +02:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e7ea94afcb | |||
| 96183e9820 | |||
| 487a6cc164 | |||
| 5a6a0dd7e1 | |||
| ded1561b42 | |||
| 9df06805ee | |||
| 2f18fe13c5 | |||
| c16c35b814 | |||
| 1a87dcdc45 | |||
| e7e3f35090 | |||
| b11f7c16bc | |||
| f818065d75 | |||
| 960d628f46 | |||
| 5c7c22c3e1 | |||
| beac5309f1 | |||
| 9d5d882d8c | |||
| 1ec44d178d | |||
| c7cddefcbd | |||
| e9d1b76d0a | |||
| 099bf06952 | |||
| 60bc8866b1 | |||
| e8ecce53b8 | |||
| 683b04cc4a | |||
| f728adab68 | |||
| 3e61ea0e2f | |||
| fdbd6abee2 | |||
| e12a0128ab | |||
| b3ce5cedf4 | |||
| e9fb3b3fc0 | |||
| 9c10954865 | |||
| fdb2c11c70 | |||
| 09cedfd699 |
+26
-19
@@ -35,8 +35,20 @@ AMD ZenDNN:
|
||||
documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "**/*.md"
|
||||
- docs/**
|
||||
- media/**
|
||||
examples:
|
||||
- all:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- app/**
|
||||
- examples/**
|
||||
- tools/**
|
||||
- all-globs-to-all-files:
|
||||
- '!tools/server/**'
|
||||
- '!tools/mtmd/**'
|
||||
- '!tools/ui/**'
|
||||
testing:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
@@ -47,28 +59,12 @@ build:
|
||||
- cmake/**
|
||||
- CMakeLists.txt
|
||||
- CMakePresets.json
|
||||
examples:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- examples/**
|
||||
- tools/**
|
||||
devops:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- .devops/**
|
||||
- .github/**
|
||||
- ci/**
|
||||
python:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "**/*.py"
|
||||
- requirements/**
|
||||
- gguf-py/**
|
||||
- .flake8
|
||||
script:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- scripts/**
|
||||
android:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
@@ -81,9 +77,20 @@ server:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/**
|
||||
|
||||
|
||||
|
||||
mtmd:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- tools/mtmd/**
|
||||
conversion:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- conversion/**
|
||||
- convert_*.py
|
||||
- gguf-py/**
|
||||
vendor:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- vendor/**
|
||||
ggml:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
|
||||
@@ -222,6 +222,16 @@ if (LLAMA_BUILD_APP)
|
||||
add_subdirectory(app)
|
||||
endif()
|
||||
|
||||
# Standalone libmtmd build without pulling in the rest of the tools/ tree.
|
||||
# Useful when packaging just the mtmd library for language bindings (e.g. an
|
||||
# Apple XCFramework, or a WASM build). When the full tools build is enabled,
|
||||
# mtmd is already built by the tools/ subdirectory above; this hook only fires
|
||||
# when LLAMA_BUILD_TOOLS is OFF to avoid double-adding the target.
|
||||
option(LLAMA_BUILD_MTMD "llama: build tools/mtmd library standalone" OFF)
|
||||
if (LLAMA_BUILD_MTMD AND NOT (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS))
|
||||
add_subdirectory(tools/mtmd)
|
||||
endif()
|
||||
|
||||
#
|
||||
# install
|
||||
#
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
set(TARGET llama-app)
|
||||
|
||||
add_executable(${TARGET} llama.cpp)
|
||||
add_executable(${TARGET} llama.cpp download.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <filesystem>
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv) {
|
||||
printf(
|
||||
"\nexamples:\n"
|
||||
" %s -hf ggml-org/gemma-3-4b-it-qat-GGUF\n"
|
||||
" %s -hf ggml-org/gemma-3-4b-it-qat-GGUF:Q4_K_M\n"
|
||||
" %s -hf ggml-org/models -hff model.gguf\n"
|
||||
" %s -mu https://example.com/model.gguf -m model.gguf\n"
|
||||
"\n",
|
||||
argv[0], argv[0], argv[0], argv[0]
|
||||
);
|
||||
}
|
||||
|
||||
int llama_download(int argc, char ** argv);
|
||||
|
||||
int llama_download(int argc, char ** argv) {
|
||||
common_init();
|
||||
|
||||
common_params params;
|
||||
params.verbosity = LOG_LEVEL_ERROR;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DOWNLOAD, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const bool has_source = !params.model.hf_repo.empty() || !params.model.url.empty() ||
|
||||
!params.model.path.empty() || !params.model.docker_repo.empty();
|
||||
if (!has_source) {
|
||||
fprintf(stderr, "error: no model source specified (use --hf-repo, --model-url, --model or --docker-repo)\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
try {
|
||||
common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_DOWNLOAD);
|
||||
common_models_handler_apply(handler, params);
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "error: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!params.models_preset.empty()) {
|
||||
// -hf pointed at a preset repo: print the preset path and stop
|
||||
printf("%s\n", params.models_preset.c_str());
|
||||
return 0;
|
||||
}
|
||||
if (params.model.path.empty()) {
|
||||
fprintf(stderr, "error: model download failed\n");
|
||||
return 1;
|
||||
}
|
||||
if (!std::filesystem::exists(params.model.path)) {
|
||||
fprintf(stderr, "error: model file does not exist: %s\n", params.model.path.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
printf("%s\n", params.model.path.c_str());
|
||||
if (!params.mmproj.path.empty()) {
|
||||
printf("%s\n", params.mmproj.path.c_str());
|
||||
}
|
||||
if (!params.speculative.draft.mparams.path.empty()) {
|
||||
printf("%s\n", params.speculative.draft.mparams.path.c_str());
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -19,6 +19,7 @@ int llama_batched_bench(int argc, char ** argv);
|
||||
int llama_fit_params(int argc, char ** argv);
|
||||
int llama_quantize(int argc, char ** argv);
|
||||
int llama_perplexity(int argc, char ** argv);
|
||||
int llama_download(int argc, char ** argv);
|
||||
|
||||
// Self-update is only supported for binaries built with llama-install.sh
|
||||
static int llama_update(int argc, char ** argv) {
|
||||
@@ -61,6 +62,7 @@ static const command cmds[] = {
|
||||
{"serve", "HTTP API server", {"server"}, false, llama_server },
|
||||
{"cli", "Command-line interactive interface", {"client"}, false, llama_cli },
|
||||
{"update", "Update llama to the latest release", {}, UPDATE_HIDDEN, llama_update },
|
||||
{"download", "Download a model", {"get"}, false, llama_download },
|
||||
{"completion", "Text completion", {"complete"}, true, llama_completion },
|
||||
{"bench", "Benchmark prompt processing and text generation", {}, true, llama_bench },
|
||||
{"batched-bench", "Benchmark batched decoding performance", {}, true, llama_batched_bench},
|
||||
|
||||
@@ -13,6 +13,7 @@ LLAMA_BUILD_EXAMPLES=OFF
|
||||
LLAMA_BUILD_TOOLS=OFF
|
||||
LLAMA_BUILD_TESTS=OFF
|
||||
LLAMA_BUILD_SERVER=OFF
|
||||
LLAMA_BUILD_MTMD=ON
|
||||
GGML_METAL=ON
|
||||
GGML_METAL_EMBED_LIBRARY=ON
|
||||
GGML_BLAS_DEFAULT=ON
|
||||
@@ -39,6 +40,7 @@ COMMON_CMAKE_ARGS=(
|
||||
-DLLAMA_BUILD_TOOLS=${LLAMA_BUILD_TOOLS}
|
||||
-DLLAMA_BUILD_TESTS=${LLAMA_BUILD_TESTS}
|
||||
-DLLAMA_BUILD_SERVER=${LLAMA_BUILD_SERVER}
|
||||
-DLLAMA_BUILD_MTMD=${LLAMA_BUILD_MTMD}
|
||||
-DGGML_METAL_EMBED_LIBRARY=${GGML_METAL_EMBED_LIBRARY}
|
||||
-DGGML_BLAS_DEFAULT=${GGML_BLAS_DEFAULT}
|
||||
-DGGML_METAL=${GGML_METAL}
|
||||
@@ -126,6 +128,8 @@ setup_framework_structure() {
|
||||
cp ggml/include/ggml-cpu.h ${header_path}
|
||||
cp ggml/include/ggml-blas.h ${header_path}
|
||||
cp ggml/include/gguf.h ${header_path}
|
||||
cp tools/mtmd/mtmd.h ${header_path}
|
||||
cp tools/mtmd/mtmd-helper.h ${header_path}
|
||||
|
||||
# Create module map (common for all platforms)
|
||||
cat > ${module_path}module.modulemap << EOF
|
||||
@@ -247,6 +251,7 @@ combine_static_libraries() {
|
||||
"${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-cpu.a"
|
||||
"${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a"
|
||||
"${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a"
|
||||
"${base_dir}/${build_dir}/tools/mtmd/${release_dir}/libmtmd.a"
|
||||
)
|
||||
|
||||
# Create temporary directory for processing
|
||||
@@ -410,6 +415,7 @@ cmake -B build-ios-sim -G Xcode \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-ios-sim --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -424,6 +430,7 @@ cmake -B build-ios-device -G Xcode \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-ios-device --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -450,6 +457,7 @@ cmake -B build-visionos -G Xcode \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -465,6 +473,7 @@ cmake -B build-visionos-sim -G Xcode \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos-sim --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -481,6 +490,7 @@ cmake -B build-tvos-sim -G Xcode \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-tvos-sim --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -496,6 +506,7 @@ cmake -B build-tvos-device -G Xcode \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-tvos-device --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
|
||||
+212
-123
@@ -297,60 +297,6 @@ struct handle_model_result {
|
||||
std::string preset_path;
|
||||
};
|
||||
|
||||
static handle_model_result common_params_handle_model(struct common_params_model & model,
|
||||
const common_download_opts & opts) {
|
||||
handle_model_result result;
|
||||
|
||||
// TODO @ngxson : refactor this into a new common_model_download_context
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
// If -m was used with -hf, treat the model "path" as the hf_file to download
|
||||
if (model.hf_file.empty() && !model.path.empty()) {
|
||||
model.hf_file = model.path;
|
||||
model.path = "";
|
||||
}
|
||||
common_download_opts hf_opts = opts;
|
||||
auto download_result = common_download_model(model, hf_opts);
|
||||
|
||||
if (!download_result.preset_path.empty()) {
|
||||
result.found_preset = true;
|
||||
result.preset_path = download_result.preset_path;
|
||||
return result; // skip everything else if preset.ini is used
|
||||
}
|
||||
|
||||
if (download_result.model_path.empty()) {
|
||||
throw std::runtime_error("failed to download model from Hugging Face");
|
||||
}
|
||||
|
||||
model.path = download_result.model_path;
|
||||
|
||||
if (!download_result.mmproj_path.empty()) {
|
||||
result.found_mmproj = true;
|
||||
result.mmproj.path = download_result.mmproj_path;
|
||||
}
|
||||
|
||||
if (!download_result.mtp_path.empty()) {
|
||||
result.found_mtp = true;
|
||||
result.mtp.path = download_result.mtp_path;
|
||||
}
|
||||
} else if (!model.url.empty()) {
|
||||
if (model.path.empty()) {
|
||||
auto f = string_split<std::string>(model.url, '#').front();
|
||||
f = string_split<std::string>(f, '?').front();
|
||||
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
|
||||
auto download_result = common_download_model(model, opts);
|
||||
if (download_result.model_path.empty()) {
|
||||
throw std::runtime_error("failed to download model from " + model.url);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::vector<ggml_type> kv_cache_types = {
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F16,
|
||||
@@ -395,77 +341,204 @@ static bool parse_bool_value(const std::string & value) {
|
||||
}
|
||||
|
||||
//
|
||||
// CLI argument parsing functions
|
||||
// common_models_handler
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
static std::string get_default_local_path(const std::string & url) {
|
||||
auto f = string_split<std::string>(url, '#').front();
|
||||
f = string_split<std::string>(f, '?').front();
|
||||
return fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
|
||||
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) {
|
||||
common_download_hf_plan plan;
|
||||
common_download_opts opts;
|
||||
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
|
||||
// only download mmproj if the current example is using it
|
||||
bool use_mmproj = false;
|
||||
for (const auto & ex : mmproj_examples) {
|
||||
if (curr_ex == ex) {
|
||||
use_mmproj = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.offline = params.offline;
|
||||
opts.skip_download = params.skip_download;
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
opts.preset_only = handle_params.preset_only;
|
||||
opts.download_mmproj = use_mmproj && !params.no_mmproj
|
||||
&& params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
|
||||
if (handle_params.callback) {
|
||||
opts.callback = handle_params.callback;
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
plan = common_download_get_hf_plan(params.model, opts);
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
// so we should not auto-discover mtp/mmproj siblings for them
|
||||
common_download_opts sub_opts = opts;
|
||||
sub_opts.download_mtp = false;
|
||||
sub_opts.download_mmproj = false;
|
||||
return common_models_handler{plan, opts};
|
||||
}
|
||||
|
||||
try {
|
||||
auto res = common_params_handle_model(params.model, opts);
|
||||
if (res.found_preset) {
|
||||
if (!params.models_preset.empty()) {
|
||||
throw std::invalid_argument("cannot use both --models-preset and -hf with a preset.ini file");
|
||||
bool common_models_handler_is_preset_repo(const common_models_handler & handler) {
|
||||
return !handler.plan.preset.url.empty();
|
||||
}
|
||||
|
||||
static std::vector<common_download_task> build_url_tasks(const common_params_model & model, common_download_opts opts) {
|
||||
auto parts = common_download_get_all_parts(model.url);
|
||||
std::vector<common_download_task> tasks;
|
||||
|
||||
// single-part: download straight to model.path if the user gave one (-m), else the cache default
|
||||
if (parts.size() == 1) {
|
||||
common_download_task task;
|
||||
task.url = parts[0];
|
||||
task.local_path = model.path.empty() ? get_default_local_path(parts[0]) : model.path;
|
||||
task.opts = opts;
|
||||
tasks.push_back(std::move(task));
|
||||
return tasks;
|
||||
}
|
||||
|
||||
// multi-part: place each part under the user's -m directory (if given), else the cache default
|
||||
std::string base_dir;
|
||||
if (!model.path.empty()) {
|
||||
auto pos = model.path.rfind('/');
|
||||
base_dir = pos == std::string::npos ? std::string(".") : model.path.substr(0, pos);
|
||||
}
|
||||
|
||||
for (const auto & part : parts) {
|
||||
common_download_task task;
|
||||
task.url = part;
|
||||
task.opts = opts;
|
||||
|
||||
std::string local = get_default_local_path(part);
|
||||
if (!base_dir.empty()) {
|
||||
auto pos = local.rfind('/');
|
||||
std::string name = pos == std::string::npos ? local : local.substr(pos + 1);
|
||||
local = base_dir + "/" + name;
|
||||
}
|
||||
task.local_path = local;
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
return tasks;
|
||||
}
|
||||
|
||||
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) {
|
||||
std::vector<common_download_task> tasks;
|
||||
|
||||
auto & plan = handler.plan;
|
||||
|
||||
auto opts = handler.opts; // copy
|
||||
opts.callback = callback;
|
||||
|
||||
// handle plain "url" if needed
|
||||
auto handle_url = [&](common_params_model & model) {
|
||||
if (!model.url.empty()) {
|
||||
if (model.path.empty()) {
|
||||
model.path = get_default_local_path(model.url);
|
||||
}
|
||||
}
|
||||
};
|
||||
handle_url(params.model);
|
||||
handle_url(params.mmproj);
|
||||
handle_url(params.vocoder.model);
|
||||
handle_url(params.speculative.draft.mparams);
|
||||
|
||||
// optionally, if docker repo is set, resolve it
|
||||
if (!params.model.docker_repo.empty()) {
|
||||
params.model.url = common_docker_resolve_model(params.model.docker_repo);
|
||||
params.model.path = get_default_local_path(params.model.url);
|
||||
}
|
||||
|
||||
// handle plain "url" tasks (non-hf)
|
||||
if (!params.model.url.empty()) {
|
||||
auto url_tasks = build_url_tasks(params.model, opts);
|
||||
// the first part is what gets loaded, so point params.model.path at it
|
||||
if (!url_tasks.empty()) {
|
||||
std::string first_path = url_tasks.front().local_path;
|
||||
url_tasks.front().on_done = [&]() { params.model.path = first_path; };
|
||||
}
|
||||
for (auto & task : url_tasks) {
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
}
|
||||
if (!params.mmproj.url.empty()) {
|
||||
common_download_task task;
|
||||
task.url = params.mmproj.url;
|
||||
task.local_path = params.mmproj.path;
|
||||
task.opts = opts;
|
||||
tasks.push_back(task);
|
||||
}
|
||||
if (!params.vocoder.model.url.empty()) {
|
||||
common_download_task task;
|
||||
task.url = params.vocoder.model.url;
|
||||
task.local_path = params.vocoder.model.path;
|
||||
task.opts = opts;
|
||||
tasks.push_back(task);
|
||||
}
|
||||
if (!params.speculative.draft.mparams.url.empty()) {
|
||||
common_download_task task;
|
||||
task.url = params.speculative.draft.mparams.url;
|
||||
task.local_path = params.speculative.draft.mparams.path;
|
||||
task.opts = opts;
|
||||
tasks.push_back(task);
|
||||
}
|
||||
|
||||
// handle hf_plan tasks
|
||||
if (!plan.model_files.empty()) {
|
||||
for (size_t i = 0; i < plan.model_files.size(); ++i) {
|
||||
auto & model_file = plan.model_files[i];
|
||||
bool is_first = (i == 0);
|
||||
tasks.emplace_back(model_file, opts, [&, is_first]() {
|
||||
if (is_first) {
|
||||
// only use first part as model path
|
||||
params.model.path = hf_cache::finalize_file(model_file);
|
||||
} else {
|
||||
hf_cache::finalize_file(model_file);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
if (!plan.mmproj.local_path.empty()) {
|
||||
tasks.emplace_back(plan.mmproj, opts, [&]() {
|
||||
params.mmproj.path = hf_cache::finalize_file(plan.mmproj);
|
||||
});
|
||||
}
|
||||
if (!plan.mtp.local_path.empty()) {
|
||||
tasks.emplace_back(plan.mtp, opts, [&]() {
|
||||
// only fall back to the discovered MTP head when no draft was explicitly provided
|
||||
if (params.speculative.draft.mparams.empty()) {
|
||||
params.speculative.draft.mparams.path = hf_cache::finalize_file(plan.mtp);
|
||||
} else {
|
||||
hf_cache::finalize_file(plan.mtp);
|
||||
}
|
||||
});
|
||||
}
|
||||
if (!plan.preset.local_path.empty()) {
|
||||
tasks.emplace_back(plan.preset, opts, [&]() {
|
||||
// if HF repo is a preset repo, we simply run server in router mode with the preset.ini file
|
||||
params.models_preset_hf = params.model.hf_repo; // only for showing a warning
|
||||
params.models_preset = res.preset_path;
|
||||
params.models_preset = hf_cache::finalize_file(plan.preset);
|
||||
params.model = common_params_model{}; // make sure to clear model, so server starts in router mode
|
||||
return true;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (params.no_mmproj) {
|
||||
params.mmproj = {};
|
||||
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
|
||||
// optionally, handle mmproj model when -hf is specified
|
||||
params.mmproj = res.mmproj;
|
||||
}
|
||||
// only download mmproj if the current example is using it
|
||||
for (const auto & ex : mmproj_examples) {
|
||||
if (curr_ex == ex) {
|
||||
common_params_handle_model(params.mmproj, sub_opts);
|
||||
break;
|
||||
}
|
||||
}
|
||||
// run all tasks in parallel
|
||||
if (!params.offline) {
|
||||
common_download_run_tasks(tasks);
|
||||
}
|
||||
|
||||
// when --spec-type mtp is set and no draft model was provided explicitly,
|
||||
// fall back to the MTP head discovered alongside the -hf model
|
||||
if (spec_type_draft_mtp && res.found_mtp &&
|
||||
params.speculative.draft.mparams.path.empty() &&
|
||||
params.speculative.draft.mparams.hf_repo.empty() &&
|
||||
params.speculative.draft.mparams.url.empty()) {
|
||||
params.speculative.draft.mparams.path = res.mtp.path;
|
||||
// download successful, update params with the downloaded paths
|
||||
for (const auto & task : tasks) {
|
||||
if (task.on_done) {
|
||||
task.on_done();
|
||||
}
|
||||
common_params_handle_model(params.speculative.draft.mparams, sub_opts);
|
||||
common_params_handle_model(params.vocoder.model, sub_opts);
|
||||
return true;
|
||||
} catch (const common_skip_download_exception &) {
|
||||
return false;
|
||||
} catch (const std::exception &) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {
|
||||
common_params & params = ctx_arg.params;
|
||||
|
||||
@@ -594,12 +667,15 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
const bool skip_model_download =
|
||||
// server will call common_params_handle_models() later, so we skip it here
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
|
||||
// download calls common_params_handle_models() itself and prints the paths
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_DOWNLOAD ||
|
||||
// export_graph_ops loads only metadata
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
common_params_handle_models(params, ctx_arg.ex, {});
|
||||
common_models_handler handler = common_models_handler_init(params, ctx_arg.ex);
|
||||
common_models_handler_apply(handler, params);
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
@@ -671,15 +747,19 @@ static void common_params_print_usage(common_params_context & ctx_arg) {
|
||||
common_options.push_back(&opt);
|
||||
}
|
||||
}
|
||||
printf("----- common params -----\n\n");
|
||||
print_options(common_options);
|
||||
printf("\n\n----- sampling params -----\n\n");
|
||||
print_options(sampling_options);
|
||||
printf("\n\n----- speculative params -----\n\n");
|
||||
print_options(spec_options);
|
||||
// TODO: maybe convert enum llama_example to string
|
||||
printf("\n\n----- example-specific params -----\n\n");
|
||||
print_options(specific_options);
|
||||
bool first = true;
|
||||
auto print_section = [&](const char * header, std::vector<common_arg *> & options) {
|
||||
if (options.empty()) {
|
||||
return;
|
||||
}
|
||||
printf("%s----- %s -----\n\n", first ? "" : "\n\n", header);
|
||||
first = false;
|
||||
print_options(options);
|
||||
};
|
||||
print_section("common params", common_options);
|
||||
print_section("sampling params", sampling_options);
|
||||
print_section("speculative params", spec_options);
|
||||
print_section("example-specific params", specific_options);
|
||||
}
|
||||
|
||||
static void common_params_print_completion(common_params_context & ctx_arg) {
|
||||
@@ -1079,7 +1159,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
* - if both {LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_*,} are set, we will prioritize the LLAMA_EXAMPLE_* matching current example
|
||||
*/
|
||||
auto add_opt = [&](common_arg arg) {
|
||||
if ((arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) && !arg.is_exclude(ex)) {
|
||||
// download only exposes the handful of args explicitly tagged for it
|
||||
const bool inherit_common = ex != LLAMA_EXAMPLE_DOWNLOAD;
|
||||
if ((arg.in_example(ex) || (inherit_common && arg.in_example(LLAMA_EXAMPLE_COMMON))) && !arg.is_exclude(ex)) {
|
||||
ctx_arg.options.push_back(std::move(arg));
|
||||
}
|
||||
};
|
||||
@@ -1090,7 +1172,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.usage = true;
|
||||
}
|
||||
));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}));
|
||||
add_opt(common_arg(
|
||||
{"--version"},
|
||||
"show version and build info",
|
||||
@@ -2212,7 +2294,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, bool value) {
|
||||
params.no_mmproj = !value;
|
||||
}
|
||||
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_AUTO"));
|
||||
).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MMPROJ_AUTO"));
|
||||
add_opt(common_arg(
|
||||
{"--mmproj-offload"},
|
||||
{"--no-mmproj-offload"},
|
||||
@@ -2611,14 +2693,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.path = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MODEL"));
|
||||
add_opt(common_arg(
|
||||
{"-mu", "--model-url"}, "MODEL_URL",
|
||||
"model download url (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.url = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_MODEL_URL"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MODEL_URL"));
|
||||
add_opt(common_arg(
|
||||
{ "-dr", "--docker-repo" }, "[<repo>/]<model>[:quant]",
|
||||
"Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.\n"
|
||||
@@ -2627,7 +2709,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.docker_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_DOCKER_REPO"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_DOCKER_REPO"));
|
||||
add_opt(common_arg(
|
||||
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
|
||||
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
|
||||
@@ -2637,14 +2719,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.hf_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HF_REPO"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_HF_REPO"));
|
||||
add_opt(common_arg(
|
||||
{"-hff", "--hf-file"}, "FILE",
|
||||
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.hf_file = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HF_FILE"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_HF_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"-hfv", "-hfrv", "--hf-repo-v"}, "<user>/<model>[:quant]",
|
||||
"Hugging Face model repository for the vocoder model (default: unused)",
|
||||
@@ -2665,7 +2747,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.hf_token = value;
|
||||
}
|
||||
).set_env("HF_TOKEN"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("HF_TOKEN"));
|
||||
add_opt(common_arg(
|
||||
{"--mtp"},
|
||||
"also download the multi-token prediction (MTP) head, if available (default: unused)",
|
||||
[](common_params & params) {
|
||||
params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_DRAFT_MTP);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_DOWNLOAD}));
|
||||
add_opt(common_arg(
|
||||
{"--context-file"}, "FNAME",
|
||||
"file to load context from (use comma-separated values to specify multiple files)",
|
||||
|
||||
+12
-11
@@ -8,6 +8,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
|
||||
// pseudo-env variable to identify preset-only arguments
|
||||
#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
|
||||
@@ -130,19 +131,19 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/18163
|
||||
void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
|
||||
struct common_params_handle_models_params {
|
||||
common_download_callback * callback = nullptr;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
struct common_models_handler {
|
||||
common_download_hf_plan plan;
|
||||
common_download_opts opts;
|
||||
};
|
||||
|
||||
// populate model paths (main model, mmproj, etc) from -hf if necessary
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
const common_params_handle_models_params & handle_params);
|
||||
// initialize downloading opts and hf_plan if needed, but does not download anything yet
|
||||
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex);
|
||||
|
||||
// check if the model is a preset repo (i.e. has a preset file)
|
||||
bool common_models_handler_is_preset_repo(const common_models_handler & handler);
|
||||
|
||||
// download and update params with the downloaded model path
|
||||
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback = nullptr);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
@@ -2758,5 +2758,9 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
|
||||
GGML_ASSERT(chat_templates != nullptr);
|
||||
GGML_ASSERT(chat_templates->template_default != nullptr);
|
||||
if (chat_templates->template_tool_use != nullptr) {
|
||||
// take the more expressive template when available
|
||||
return chat_templates->template_tool_use->caps.to_map();
|
||||
}
|
||||
return chat_templates->template_default->caps.to_map();
|
||||
}
|
||||
|
||||
+12
-8
@@ -96,6 +96,7 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_FIT_PARAMS,
|
||||
LLAMA_EXAMPLE_RESULTS,
|
||||
LLAMA_EXAMPLE_EXPORT_GRAPH_OPS,
|
||||
LLAMA_EXAMPLE_DOWNLOAD,
|
||||
|
||||
LLAMA_EXAMPLE_COUNT,
|
||||
};
|
||||
@@ -290,13 +291,13 @@ struct common_params_sampling {
|
||||
};
|
||||
|
||||
struct common_params_model {
|
||||
std::string path = ""; // model local path // NOLINT
|
||||
std::string url = ""; // model url to download // NOLINT
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
std::string docker_repo = ""; // Docker repo // NOLINT
|
||||
std::string path = ""; // model local path
|
||||
std::string url = ""; // model url to download
|
||||
std::string hf_repo = ""; // HF repo
|
||||
std::string hf_file = ""; // HF file
|
||||
std::string docker_repo = ""; // Docker repo
|
||||
|
||||
std::string get_name() {
|
||||
std::string get_name() const {
|
||||
if (!hf_repo.empty()) {
|
||||
return hf_repo;
|
||||
}
|
||||
@@ -305,6 +306,10 @@ struct common_params_model {
|
||||
}
|
||||
return path;
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return get_name().empty();
|
||||
}
|
||||
};
|
||||
|
||||
// draft-model-based speculative decoding parameters
|
||||
@@ -367,7 +372,7 @@ struct common_params_speculative {
|
||||
common_params_speculative_ngram_cache ngram_cache;
|
||||
|
||||
bool has_dft() const {
|
||||
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
|
||||
return !draft.mparams.empty();
|
||||
}
|
||||
|
||||
uint32_t need_n_rs_seq() const {
|
||||
@@ -519,7 +524,6 @@ struct common_params {
|
||||
int32_t control_vector_layer_start = -1; // layer range for control vector
|
||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||
bool offline = false;
|
||||
bool skip_download = false; // skip model file downloading
|
||||
|
||||
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
||||
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
||||
|
||||
+22
-118
@@ -292,10 +292,6 @@ static int common_download_file_single_online(const std::string & url,
|
||||
|
||||
const bool file_exists = std::filesystem::exists(path);
|
||||
|
||||
if (!file_exists && opts.skip_download) {
|
||||
return -2; // file is missing and download is disabled
|
||||
}
|
||||
|
||||
if (file_exists && skip_etag) {
|
||||
LOG_DBG("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
@@ -362,9 +358,6 @@ static int common_download_file_single_online(const std::string & url,
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
// pass this point, the file exists but is different from the server version, so we need to redownload it
|
||||
if (opts.skip_download) {
|
||||
return -2; // special code to indicate that the download was skipped due to etag mismatch
|
||||
}
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return -1;
|
||||
@@ -691,19 +684,8 @@ static void list_available_gguf_files(const hf_cache::hf_files & files) {
|
||||
}
|
||||
}
|
||||
|
||||
struct hf_plan {
|
||||
hf_cache::hf_file primary;
|
||||
hf_cache::hf_files model_files;
|
||||
hf_cache::hf_file mmproj;
|
||||
hf_cache::hf_file mtp;
|
||||
hf_cache::hf_file preset; // if set, only this file is downloaded
|
||||
};
|
||||
|
||||
static hf_plan get_hf_plan(const common_params_model & model,
|
||||
const common_download_opts & opts,
|
||||
bool download_mmproj,
|
||||
bool download_mtp) {
|
||||
hf_plan plan;
|
||||
common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts) {
|
||||
common_download_hf_plan plan;
|
||||
hf_cache::hf_files all;
|
||||
|
||||
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
|
||||
@@ -752,127 +734,49 @@ static hf_plan get_hf_plan(const common_params_model & model,
|
||||
plan.primary = primary;
|
||||
plan.model_files = get_split_files(all, primary);
|
||||
|
||||
if (download_mmproj) {
|
||||
if (opts.download_mmproj) {
|
||||
plan.mmproj = find_best_mmproj(all, primary.path);
|
||||
}
|
||||
|
||||
if (download_mtp) {
|
||||
if (opts.download_mtp) {
|
||||
plan.mtp = find_best_mtp(all, primary.path);
|
||||
}
|
||||
|
||||
return plan;
|
||||
}
|
||||
|
||||
struct download_task {
|
||||
std::string url;
|
||||
std::string path;
|
||||
};
|
||||
|
||||
static std::vector<download_task> get_url_tasks(const common_params_model & model) {
|
||||
auto split = get_gguf_split_info(model.url);
|
||||
|
||||
if (split.count <= 1) {
|
||||
return {{model.url, model.path}};
|
||||
}
|
||||
|
||||
auto filename = split.prefix;
|
||||
if (auto pos = split.prefix.rfind('/'); pos != std::string::npos) {
|
||||
filename = split.prefix.substr(pos + 1);
|
||||
}
|
||||
|
||||
auto parent_path = std::filesystem::path(model.path).parent_path();
|
||||
auto prefix_path = (parent_path / filename).string();
|
||||
|
||||
std::vector<download_task> tasks;
|
||||
for (int i = 1; i <= split.count; i++) {
|
||||
auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count);
|
||||
tasks.push_back({split.prefix + suffix, prefix_path + suffix});
|
||||
}
|
||||
return tasks;
|
||||
}
|
||||
|
||||
common_download_model_result common_download_model(const common_params_model & model,
|
||||
const common_download_opts & opts) {
|
||||
common_download_model_result result;
|
||||
std::vector<download_task> tasks;
|
||||
hf_plan hf;
|
||||
|
||||
bool download_mmproj = opts.download_mmproj;
|
||||
bool download_mtp = opts.download_mtp;
|
||||
bool preset_only = opts.preset_only;
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini exists, only download that file alone
|
||||
tasks.push_back({hf.preset.url, hf.preset.local_path});
|
||||
} else if (!preset_only) {
|
||||
// only add other files if we're NOT in preset-only mode (normal run, non-router)
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
|
||||
}
|
||||
if (!hf.mtp.path.empty()) {
|
||||
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
|
||||
}
|
||||
}
|
||||
} else if (!model.url.empty()) {
|
||||
tasks = get_url_tasks(model);
|
||||
} else {
|
||||
result.model_path = model.path;
|
||||
return result;
|
||||
}
|
||||
|
||||
if (tasks.empty()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_download_run_tasks(const std::vector<common_download_task> & tasks) {
|
||||
std::vector<std::future<int>> futures;
|
||||
for (const auto & task : tasks) {
|
||||
futures.push_back(std::async(std::launch::async,
|
||||
[&task, &opts, is_hf]() {
|
||||
return common_download_file_single(task.url, task.path, opts, is_hf);
|
||||
[&task]() {
|
||||
return common_download_file_single(task.url, task.local_path, task.opts, task.is_hf);
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
for (auto & f : futures) {
|
||||
int status = f.get();
|
||||
if (status == -2 && opts.skip_download) {
|
||||
throw common_skip_download_exception();
|
||||
}
|
||||
for (size_t i = 0; i < futures.size(); ++i) {
|
||||
std::string url = tasks[i].url;
|
||||
int status = futures[i].get();
|
||||
bool is_ok = is_http_status_ok(status);
|
||||
if (!is_ok) {
|
||||
return {};
|
||||
throw std::runtime_error(string_format("Download '%s' failed with status code: %d", url.c_str(), status));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (is_hf) {
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini is used, do not set other paths
|
||||
result.preset_path = hf_cache::finalize_file(hf.preset);
|
||||
} else {
|
||||
for (const auto & f : hf.model_files) {
|
||||
hf_cache::finalize_file(f);
|
||||
}
|
||||
result.model_path = hf.primary.final_path;
|
||||
std::vector<std::string> common_download_get_all_parts(const std::string & url) {
|
||||
auto split = get_gguf_split_info(url);
|
||||
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
|
||||
}
|
||||
|
||||
if (!hf.mtp.path.empty()) {
|
||||
result.mtp_path = hf_cache::finalize_file(hf.mtp);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.model_path = model.path;
|
||||
if (split.count <= 1) {
|
||||
return {url};
|
||||
}
|
||||
|
||||
return result;
|
||||
std::vector<std::string> parts;
|
||||
for (int i = 1; i <= split.count; i++) {
|
||||
auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count);
|
||||
parts.push_back(split.prefix + suffix);
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
+28
-43
@@ -1,7 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "hf-cache.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
struct common_params_model;
|
||||
|
||||
@@ -47,67 +50,40 @@ struct common_cached_model_info {
|
||||
}
|
||||
};
|
||||
|
||||
// Options for common_download_model and common_download_file_single
|
||||
// Options for common_download_file_single
|
||||
struct common_download_opts {
|
||||
std::string bearer_token;
|
||||
common_header_list headers;
|
||||
bool offline = false;
|
||||
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
|
||||
bool download_mmproj = false;
|
||||
bool download_mtp = false;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
common_download_callback * callback = nullptr;
|
||||
};
|
||||
|
||||
// Result of common_download_model
|
||||
struct common_download_model_result {
|
||||
std::string model_path;
|
||||
std::string mmproj_path;
|
||||
std::string mtp_path;
|
||||
std::string preset_path;
|
||||
struct common_download_task {
|
||||
common_download_opts opts;
|
||||
std::string url;
|
||||
std::string local_path;
|
||||
std::function<void()> on_done;
|
||||
bool is_hf = false;
|
||||
|
||||
common_download_task() = default;
|
||||
common_download_task(hf_cache::hf_file f,
|
||||
const common_download_opts & opts,
|
||||
std::function<void()> on_done = nullptr)
|
||||
: opts(opts), url(f.url), local_path(f.local_path), on_done(on_done), is_hf(true) {}
|
||||
};
|
||||
|
||||
// throw if the file is missing or invalid (e.g. ETag check failed)
|
||||
struct common_skip_download_exception : public std::runtime_error {
|
||||
common_skip_download_exception() : std::runtime_error("skip download") {}
|
||||
};
|
||||
void common_download_run_tasks(const std::vector<common_download_task> & tasks);
|
||||
|
||||
// Download model from HuggingFace repo or URL
|
||||
//
|
||||
// input (via model struct):
|
||||
// - model.hf_repo: HF repo with optional tag, see common_download_split_repo_tag
|
||||
// - model.hf_file: specific file in the repo (requires hf_repo)
|
||||
// - model.url: simple download (used if hf_repo is empty)
|
||||
// - model.path: local file path
|
||||
//
|
||||
// tag matching (for HF repos without model.hf_file):
|
||||
// - if tag is specified, searches for GGUF matching that quantization
|
||||
// - if no tag, searches for Q4_K_M, then Q4_0, then first available GGUF
|
||||
//
|
||||
// split GGUF: multi-part files like "model-00001-of-00003.gguf" are automatically
|
||||
// detected and all parts are downloaded
|
||||
//
|
||||
// caching:
|
||||
// - HF repos: uses HuggingFace cache
|
||||
// - URLs: uses ETag-based caching
|
||||
//
|
||||
// when opts.offline=true, no network requests are made
|
||||
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
|
||||
// then with the closest quantization bits
|
||||
// when download_mtp=true, applies the same sibling search for an MTP-head GGUF
|
||||
//
|
||||
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
|
||||
common_download_model_result common_download_model(
|
||||
const common_params_model & model,
|
||||
const common_download_opts & opts = {}
|
||||
);
|
||||
// if url is a multi-part GGUF file, returns all parts, otherwise returns the single file
|
||||
std::vector<std::string> common_download_get_all_parts(const std::string & url);
|
||||
|
||||
// returns list of cached models
|
||||
std::vector<common_cached_model_info> common_list_cached_models();
|
||||
|
||||
// download single file from url to local path
|
||||
// returns status code or -1 on error
|
||||
// returns -2 if the download was skipped due to ETag mismatch (file outdated, skip_download=true)
|
||||
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
@@ -124,3 +100,12 @@ std::string common_docker_resolve_model(const std::string & docker);
|
||||
// - if tag is present, removes only files matching that tag (and orphaned blobs)
|
||||
// returns true if anything was removed
|
||||
bool common_download_remove(const std::string & hf_repo_with_tag);
|
||||
|
||||
struct common_download_hf_plan {
|
||||
hf_cache::hf_file primary;
|
||||
hf_cache::hf_files model_files;
|
||||
hf_cache::hf_file mmproj;
|
||||
hf_cache::hf_file mtp;
|
||||
hf_cache::hf_file preset; // if set, only this file is downloaded
|
||||
};
|
||||
common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts);
|
||||
|
||||
@@ -136,6 +136,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"LlamaModel": "llama",
|
||||
"Eagle3DraftModel": "llama",
|
||||
"Eagle3Speculator": "llama",
|
||||
"Eagle3LlamaForCausalLM": "llama",
|
||||
"LlamaForCausalLMEagle3": "llama",
|
||||
"LlavaForConditionalGeneration": "llama",
|
||||
"LlavaStableLMEpochForCausalLM": "stablelm",
|
||||
|
||||
@@ -23,6 +23,7 @@ from .base import ModelBase, TextModel, gguf, logger
|
||||
"LlavaForConditionalGeneration",
|
||||
"VoxtralForConditionalGeneration",
|
||||
"LlamaForCausalLMEagle3",
|
||||
"Eagle3LlamaForCausalLM",
|
||||
"Eagle3Speculator",
|
||||
"Eagle3DraftModel",
|
||||
"IQuestCoderForCausalLM",
|
||||
|
||||
+3
-4
@@ -114,7 +114,8 @@ class Mamba2Model(TextModel):
|
||||
hparams["text_config"] = hparams["llm_config"]
|
||||
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
|
||||
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
|
||||
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
|
||||
self.expand = self.find_hparam(["mamba_expand", "expand"], optional=True) or 2
|
||||
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or self.expand * self.d_model
|
||||
self.n_group = self.find_hparam(["n_groups"], optional=True) or 1
|
||||
|
||||
def set_vocab(self):
|
||||
@@ -144,11 +145,9 @@ class Mamba2Model(TextModel):
|
||||
|
||||
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
|
||||
|
||||
# Fail early for models which don't have a block expansion factor of 2
|
||||
# TODO: does this really matter?
|
||||
# skip the assertion for FalconH1 Model
|
||||
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
|
||||
assert self.d_inner == 2 * self.d_model
|
||||
assert self.d_inner == self.expand * self.d_model
|
||||
assert self.d_inner % head_dim == 0
|
||||
|
||||
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
|
||||
|
||||
@@ -413,6 +413,15 @@ In two device selection modes, the default SYCL backend is level_zero, you can c
|
||||
|------------------|----------------------------------------|
|
||||
| Single device | --split-mode none --main-gpu DEVICE_ID |
|
||||
| Multiple devices | --split-mode layer (default) |
|
||||
| Multiple devices | --split-mode tensor (tensor parallelism) |
|
||||
|
||||
`--split-mode tensor` (tensor parallelism) shards each layer across the selected
|
||||
GPUs. It requires flash attention, which is auto-enabled when `--flash-attn` is
|
||||
left at its default `auto`, so `--split-mode tensor` works out of the box.
|
||||
Passing `--flash-attn off` together with `--split-mode tensor` is rejected at
|
||||
context creation. The default `f16` KV cache is recommended. Tensor parallelism
|
||||
is currently optimized for 2 GPUs; other device counts fall back to a generic
|
||||
all-reduce.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -715,6 +724,15 @@ In two device selection modes, the default SYCL backend is level_zero, you can c
|
||||
|------------------|----------------------------------------|
|
||||
| Single device | --split-mode none --main-gpu DEVICE_ID |
|
||||
| Multiple devices | --split-mode layer (default) |
|
||||
| Multiple devices | --split-mode tensor (tensor parallelism) |
|
||||
|
||||
`--split-mode tensor` (tensor parallelism) shards each layer across the selected
|
||||
GPUs. It requires flash attention, which is auto-enabled when `--flash-attn` is
|
||||
left at its default `auto`, so `--split-mode tensor` works out of the box.
|
||||
Passing `--flash-attn off` together with `--split-mode tensor` is rejected at
|
||||
context creation. The default `f16` KV cache is recommended. Tensor parallelism
|
||||
is currently optimized for 2 GPUs; other device counts fall back to a generic
|
||||
all-reduce.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
+41
-1
@@ -13,6 +13,45 @@ The `llama-server` application supports several implementations of speculative d
|
||||
A much smaller model (called the _draft model_) generates drafts.
|
||||
A draft model is the most used approach in speculative decoding.
|
||||
|
||||
### EAGLE-3 (`draft-eagle3`)
|
||||
|
||||
EAGLE-3 uses a small draft model that reads the target model's hidden states to predict the next tokens, so it
|
||||
reaches higher acceptance than a standalone draft model of the same size. The draft is a one-layer transformer
|
||||
trained for a specific target model; it shares the target model's tokenizer and, optionally, uses a reduced draft
|
||||
vocabulary with its own `lm_head`, which is mapped back using a `d2t` table.
|
||||
|
||||
Convert the EAGLE-3 checkpoint with `--target-model-dir` so it inherits the target's tokenizer and the layer
|
||||
indices to read. Both the SpecForge `LlamaForCausalLMEagle3` and the vLLM/AngelSlim `Eagle3LlamaForCausalLM`
|
||||
checkpoint formats are supported (for example [`AngelSlim/Qwen3-4B_eagle3`](https://huggingface.co/AngelSlim/Qwen3-4B_eagle3)
|
||||
for `Qwen/Qwen3-4B`):
|
||||
|
||||
```bash
|
||||
python convert_hf_to_gguf.py AngelSlim/Qwen3-4B_eagle3 \
|
||||
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-eagle3.gguf
|
||||
|
||||
llama-server -m Qwen3-4B.gguf -md Qwen3-4B-eagle3.gguf --spec-type draft-eagle3
|
||||
```
|
||||
|
||||
Supported EAGLE-3 draft models include:
|
||||
|
||||
- [yuhuili/EAGLE3-LLaMA3.1-Instruct-8B](https://huggingface.co/yuhuili/EAGLE3-LLaMA3.1-Instruct-8B)
|
||||
- [yuhuili/EAGLE3-LLaMA3.3-Instruct-70B](https://huggingface.co/yuhuili/EAGLE3-LLaMA3.3-Instruct-70B)
|
||||
- [RedHatAI/gemma-4-31B-it-speculator.eagle3](https://huggingface.co/RedHatAI/gemma-4-31B-it-speculator.eagle3)
|
||||
- [RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3](https://huggingface.co/RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3)
|
||||
- [Tengyunw/qwen3_8b_eagle3](https://huggingface.co/Tengyunw/qwen3_8b_eagle3)
|
||||
- [Tengyunw/qwen3_30b_moe_eagle3](https://huggingface.co/Tengyunw/qwen3_30b_moe_eagle3)
|
||||
- [AngelSlim/Qwen3-1.7B_eagle3](https://huggingface.co/AngelSlim/Qwen3-1.7B_eagle3)
|
||||
- [AngelSlim/Qwen3-4B_eagle3](https://huggingface.co/AngelSlim/Qwen3-4B_eagle3)
|
||||
- [AngelSlim/Qwen3-8B_eagle3](https://huggingface.co/AngelSlim/Qwen3-8B_eagle3)
|
||||
- [AngelSlim/Qwen3-14B_eagle3](https://huggingface.co/AngelSlim/Qwen3-14B_eagle3)
|
||||
- [AngelSlim/Qwen3-32B_eagle3](https://huggingface.co/AngelSlim/Qwen3-32B_eagle3)
|
||||
- [AngelSlim/Qwen3-a3B_eagle3](https://huggingface.co/AngelSlim/Qwen3-a3B_eagle3)
|
||||
- [RedHatAI/gpt-oss-20b-speculator.eagle3](https://huggingface.co/RedHatAI/gpt-oss-20b-speculator.eagle3)
|
||||
- [lmsys/EAGLE3-gpt-oss-120b-bf16](https://huggingface.co/lmsys/EAGLE3-gpt-oss-120b-bf16)
|
||||
- [nvidia/gpt-oss-120b-Eagle3-long-context](https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-long-context)
|
||||
|
||||
For the full and up-to-date list of supported models, see #18039.
|
||||
|
||||
### n-gram Cache (`ngram-cache`)
|
||||
|
||||
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
|
||||
@@ -108,7 +147,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
|
||||
### General Speculative Parameters
|
||||
|
||||
```
|
||||
--spec-type [none|draft-simple|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
comma-separated list of types of speculative decoding to use
|
||||
(default: none)
|
||||
(env: LLAMA_ARG_SPEC_TYPE)
|
||||
@@ -247,6 +286,7 @@ Specifies a comma-separated list of speculative decoding types to use.
|
||||
|------|-------------|
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `draft-simple` | Use a simple draft model for speculation |
|
||||
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
|
||||
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
|
||||
| `ngram-cache` | Use n-gram cache lookup |
|
||||
| `ngram-simple` | Use simple n-gram pattern matching |
|
||||
|
||||
+1
-1
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 15)
|
||||
set(GGML_VERSION_PATCH 2)
|
||||
set(GGML_VERSION_PATCH 3)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
@@ -27,6 +27,14 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int de
|
||||
// split tensor buffer that splits matrices by rows across multiple devices
|
||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
|
||||
|
||||
// Tensor parallelism (--split-mode tensor): comm_init/free/allreduce_tensor
|
||||
// trio queried by the meta-backend via ggml_backend_reg_get_proc_address.
|
||||
// See typedefs in ggml/include/ggml-backend.h. Mirrors the CUDA backend's
|
||||
// pattern (ggml_backend_cuda_comm_*).
|
||||
GGML_BACKEND_API void * ggml_backend_sycl_comm_init(ggml_backend_t * backends, size_t n_backends);
|
||||
GGML_BACKEND_API void ggml_backend_sycl_comm_free(void * comm_ctx);
|
||||
GGML_BACKEND_API bool ggml_backend_sycl_comm_allreduce_tensor(void * comm_ctx, struct ggml_tensor ** tensors);
|
||||
|
||||
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
|
||||
|
||||
|
||||
@@ -75,12 +75,12 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
|
||||
ay1 = GGML_F32_VEC_LOAD(y + i);
|
||||
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
|
||||
}
|
||||
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
|
||||
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmla on available elements only
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b32(np2, n);
|
||||
ax1 = svld1_f32(pg, x + np2);
|
||||
ay1 = svld1_f32(pg, y + np2);
|
||||
sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
|
||||
sum1 = svmla_f32_m(pg, sum1, ax1, ay1);
|
||||
}
|
||||
// reduce sum1,sum2 to sum1
|
||||
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
|
||||
|
||||
@@ -34,26 +34,26 @@ template <float (*bin_op)(const float, const float),
|
||||
static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const src1_t * src1,
|
||||
dst_t * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const uint32_t ne0,
|
||||
const uint32_t ne1,
|
||||
const uint32_t ne2,
|
||||
const uint3 ne3,
|
||||
const uint3 ne10,
|
||||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
/*const uint32_t s0,*/
|
||||
const uint32_t s1,
|
||||
const uint32_t s2,
|
||||
const uint32_t s3,
|
||||
const uint32_t s00,
|
||||
const uint32_t s01,
|
||||
const uint32_t s02,
|
||||
const uint32_t s03,
|
||||
const uint32_t s10,
|
||||
const uint32_t s11,
|
||||
const uint32_t s12,
|
||||
const uint32_t s13,
|
||||
src1_ptrs... src1s) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
@@ -61,7 +61,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
|
||||
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
|
||||
|
||||
if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
|
||||
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -69,25 +69,32 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const uint32_t i12 = fastmodulo(i2, ne12);
|
||||
const uint32_t i13 = fastmodulo(i3, ne13);
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01;
|
||||
const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11;
|
||||
const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1;
|
||||
|
||||
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const uint32_t s0 = blockDim.x * gridDim.x;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
|
||||
for (uint32_t i0 = i0s; i0 < ne0; i0 += s0) {
|
||||
const uint32_t i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
|
||||
// protect i0 from overflow
|
||||
if (ne0 - i0 <= s0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,19 +117,19 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
const uint32_t s1,
|
||||
const uint32_t s2,
|
||||
const uint32_t s3,
|
||||
const uint32_t s00,
|
||||
const uint32_t s01,
|
||||
const uint32_t s02,
|
||||
const uint32_t s03,
|
||||
const uint32_t s10,
|
||||
const uint32_t s11,
|
||||
const uint32_t s12,
|
||||
const uint32_t s13,
|
||||
src1_ptrs... src1s) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const uint32_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
const uint32_t i3 = fastdiv(i, prod_012);
|
||||
const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
|
||||
@@ -133,25 +140,25 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const int i11 = fastmodulo(i1, ne11);
|
||||
const int i12 = fastmodulo(i2, ne12);
|
||||
const int i13 = fastmodulo(i3, ne13);
|
||||
const uint32_t i11 = fastmodulo(i1, ne11);
|
||||
const uint32_t i12 = fastmodulo(i2, ne12);
|
||||
const uint32_t i13 = fastmodulo(i3, ne13);
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01;
|
||||
const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11;
|
||||
const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1;
|
||||
|
||||
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const int i10 = fastmodulo(i0, ne10);
|
||||
const uint32_t i10 = fastmodulo(i0, ne10);
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
@@ -248,6 +255,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
|
||||
GGML_ASSERT(ne0 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne1 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne2 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne3 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
//GGML_ASSERT(s0 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s1 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s2 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s3 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
GGML_ASSERT(s00 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s01 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s02 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s03 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
GGML_ASSERT(s10 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s11 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s12 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s13 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
GGML_ASSERT(cne1[0] <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(cne1[1] <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(cne1[2] <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(cne1[3] <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
@@ -263,6 +295,8 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(ne2 * ne3 <= std::numeric_limits<unsigned int>::max());
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
|
||||
@@ -281,7 +315,13 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
|
||||
|
||||
if (block_nums.z > 65535 || block_nums.y > 65535) {
|
||||
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
||||
int64_t block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
||||
|
||||
GGML_ASSERT(block_num <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(block_num * block_size <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne0 * ne1 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne0 * ne1 * ne2 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
|
||||
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
|
||||
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
|
||||
@@ -298,6 +338,10 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(int64_t(block_nums.x) * block_dims.x <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(int64_t(block_nums.y) * block_dims.y <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(int64_t(block_nums.z) * block_dims.z <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
|
||||
{
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
|
||||
+35
-29
@@ -53,10 +53,10 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
|
||||
const int64_t nmat = ne / (ne00 * ne01);
|
||||
const int64_t n = ne00 * ne01;
|
||||
|
||||
const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
|
||||
const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
||||
const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
|
||||
const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
||||
const int64_t x = (int64_t) blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
|
||||
const int64_t y = (int64_t) blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
||||
const int64_t tx = (int64_t) blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
|
||||
const int64_t ty = (int64_t) blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
||||
|
||||
__shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
|
||||
int cur_tile_buf = 0;
|
||||
@@ -197,7 +197,7 @@ static void ggml_cpy_scalar_contiguous_cuda(
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar_contiguous<src_t, dst_t>, launch_params, cx, cdst, ne);
|
||||
}
|
||||
@@ -208,6 +208,14 @@ static void ggml_cpy_scalar_cuda(
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
||||
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
||||
|
||||
const auto launch_scalar_generic = [&]() {
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
};
|
||||
|
||||
if (transposed) {
|
||||
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
|
||||
int64_t ne00n, ne01n, ne02n;
|
||||
@@ -224,20 +232,18 @@ static void ggml_cpy_scalar_cuda(
|
||||
int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
|
||||
int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
|
||||
int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
|
||||
GGML_ASSERT(grid_x < UINT_MAX);
|
||||
GGML_ASSERT(grid_y < USHRT_MAX);
|
||||
GGML_ASSERT(grid_z < USHRT_MAX);
|
||||
dim3 dimGrid(grid_x, grid_y, grid_z);
|
||||
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
|
||||
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
GGML_ASSERT(grid_x <= INT_MAX);
|
||||
if (grid_y > USHRT_MAX || grid_z > USHRT_MAX) {
|
||||
launch_scalar_generic();
|
||||
} else {
|
||||
dim3 dimGrid(grid_x, grid_y, grid_z);
|
||||
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
|
||||
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
} else {
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
launch_scalar_generic();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,7 +254,7 @@ static void ggml_cpy_f32_q8_0_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK8_0 == 0);
|
||||
const int64_t num_blocks = ne / QK8_0;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -259,7 +265,7 @@ static void ggml_cpy_q8_0_f32_cuda(
|
||||
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -271,7 +277,7 @@ static void ggml_cpy_f32_q4_0_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK4_0 == 0);
|
||||
const int64_t num_blocks = ne / QK4_0;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -284,7 +290,7 @@ static void ggml_cpy_q4_0_f32_cuda(
|
||||
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
||||
cudaStream_t stream) {
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
@@ -297,7 +303,7 @@ static void ggml_cpy_f32_q4_1_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK4_1 == 0);
|
||||
const int64_t num_blocks = ne / QK4_1;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -310,7 +316,7 @@ static void ggml_cpy_q4_1_f32_cuda(
|
||||
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
||||
cudaStream_t stream) {
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
@@ -323,7 +329,7 @@ static void ggml_cpy_f32_q5_0_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK5_0 == 0);
|
||||
const int64_t num_blocks = ne / QK5_0;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -336,7 +342,7 @@ static void ggml_cpy_q5_0_f32_cuda(
|
||||
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
||||
cudaStream_t stream) {
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
@@ -349,7 +355,7 @@ static void ggml_cpy_f32_q5_1_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK5_1 == 0);
|
||||
const int64_t num_blocks = ne / QK5_1;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -362,7 +368,7 @@ static void ggml_cpy_q5_1_f32_cuda(
|
||||
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
||||
cudaStream_t stream) {
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
@@ -375,7 +381,7 @@ static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK4_NL == 0);
|
||||
const int64_t num_blocks = ne / QK4_NL;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,28 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
static __global__ void k_compute_out_prod_ptrs(
|
||||
const float * src0_d, const float * src1_d, float * dst_d,
|
||||
const float ** ptrs_a, const float ** ptrs_b, float ** ptrs_c,
|
||||
const int64_t ne2, const int64_t ne3,
|
||||
const int64_t dps2, const int64_t dps3,
|
||||
const size_t s02, const size_t s03,
|
||||
const size_t s12, const size_t s13,
|
||||
const size_t s2, const size_t s3) {
|
||||
const int64_t i2 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int64_t i3 = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
|
||||
if (i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t idx = i3*ne2 + i2;
|
||||
|
||||
ptrs_a[idx] = src0_d + (i3/dps3)*s03 + (i2/dps2)*s02;
|
||||
ptrs_b[idx] = src1_d + i3 *s13 + i2 *s12;
|
||||
ptrs_c[idx] = dst_d + i3 *s3 + i2 *s2;
|
||||
}
|
||||
|
||||
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
@@ -67,18 +89,39 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
&beta, dst_d + i3 *s3, ldc, s2,
|
||||
batch_count));
|
||||
}
|
||||
} else if (ne2 > 1 || ne3 > 1) {
|
||||
// dps2 > 1 (src0 broadcast along dim 2 with non-uniform stride) or multiple GEMMs
|
||||
// along dim 3: compute per-GEMM pointers on the device and use a single batched GEMM.
|
||||
GGML_ASSERT(ne3 > 0);
|
||||
GGML_ASSERT(ne2 <= (int64_t) std::numeric_limits<int>::max() / ne3);
|
||||
const int batch_count = (int) (ne2 * ne3);
|
||||
|
||||
ggml_cuda_pool_alloc<const float *> ptrs_a(ctx.pool(), batch_count);
|
||||
ggml_cuda_pool_alloc<const float *> ptrs_b(ctx.pool(), batch_count);
|
||||
ggml_cuda_pool_alloc< float *> ptrs_c(ctx.pool(), batch_count);
|
||||
|
||||
const dim3 block_dims(16, 16);
|
||||
const dim3 grid_dims((ne2 + block_dims.x - 1)/block_dims.x, (ne3 + block_dims.y - 1)/block_dims.y);
|
||||
k_compute_out_prod_ptrs<<<grid_dims, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ptrs_a.get(), ptrs_b.get(), ptrs_c.get(),
|
||||
ne2, ne3, dps2, dps3, s02, s03, s12, s13, s2, s3);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemmBatched(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, ptrs_a.get(), lda,
|
||||
ptrs_b.get(), ldb,
|
||||
&beta, ptrs_c.get(), ldc,
|
||||
batch_count));
|
||||
} else {
|
||||
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
|
||||
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
// ne2 == 1 && ne3 == 1: single GEMM
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d, lda,
|
||||
src1_d, ldb,
|
||||
&beta, dst_d, ldc));
|
||||
}
|
||||
}
|
||||
|
||||
Vendored
+1
@@ -48,6 +48,7 @@
|
||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetStream hipblasSetStream
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasSgemmBatched hipblasSgemmBatched
|
||||
#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
|
||||
Vendored
+1
@@ -32,6 +32,7 @@
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasSgemmBatched mublasSgemmBatched
|
||||
#define cublasSgemmStridedBatched mublasSgemmStridedBatched
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cublasOperation_t mublasOperation_t
|
||||
|
||||
@@ -850,6 +850,7 @@ struct ggml_backend_opencl_context {
|
||||
ref_count--;
|
||||
if (ref_count == 0) {
|
||||
#ifdef GGML_OPENCL_PROFILING
|
||||
flush_profiling_batch();
|
||||
write_profiling_info();
|
||||
profiling_results.clear();
|
||||
#endif
|
||||
@@ -10152,14 +10153,8 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
const int ne00 = src0 ? src0->ne[0] : 0;
|
||||
const int ne01 = src0 ? src0->ne[1] : 0;
|
||||
const int ne02 = src0 ? src0->ne[2] : 0;
|
||||
const int ne03 = src0 ? src0->ne[3] : 0;
|
||||
|
||||
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
|
||||
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
|
||||
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
|
||||
GGML_TENSOR_LOCALS(int, ne0, src0, ne);
|
||||
GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
|
||||
|
||||
const int nth = MIN(64, ne00);
|
||||
|
||||
@@ -10173,11 +10168,12 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float)*nth, NULL));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
@@ -24,6 +24,7 @@ kernel void kernel_norm(
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
@@ -43,7 +44,8 @@ kernel void kernel_norm(
|
||||
// parallel sum
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
sum[get_local_id(0)] += x[i00];
|
||||
// this kernel handles float, nb00/4 translates byte offset to element offset
|
||||
sum[get_local_id(0)] += x[i00*nb00/4];
|
||||
}
|
||||
// reduce
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -60,7 +62,8 @@ kernel void kernel_norm(
|
||||
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
y[i00] = x[i00] - mean;
|
||||
// this kernel handles float, nb00/4 translates byte offset to element offset
|
||||
y[i00] = x[i00*nb00/4] - mean;
|
||||
sum[get_local_id(0)] += y[i00] * y[i00];
|
||||
}
|
||||
|
||||
|
||||
@@ -103,8 +103,8 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
// allocate packed arrays: A_packed (k x m), B_packed (k x n)
|
||||
ggml_sycl_pool_alloc<float> A_packed_alloc(ctx.pool());
|
||||
ggml_sycl_pool_alloc<float> B_packed_alloc(ctx.pool());
|
||||
A_packed_alloc.alloc((size_t) knl_n_total * patch_total * sizeof(float));
|
||||
B_packed_alloc.alloc((size_t) knl_n_total * oc * sizeof(float));
|
||||
A_packed_alloc.alloc((size_t) knl_n_total * patch_total);
|
||||
B_packed_alloc.alloc((size_t) knl_n_total * oc);
|
||||
|
||||
float * A_packed = A_packed_alloc.get();
|
||||
float * B_packed = B_packed_alloc.get();
|
||||
@@ -115,10 +115,16 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
// Combined kernel: im2col -> pack A, and pack B simultaneously
|
||||
const char * src1_base = (const char *) src1->data;
|
||||
const char * src0_base = (const char *) src0->data;
|
||||
const int64_t src1_nb0 = src1->nb[0];
|
||||
const int64_t src1_nb1 = src1->nb[1];
|
||||
const int64_t src1_nb2 = src1->nb[2];
|
||||
const int64_t src1_nb3 = src1->nb[3];
|
||||
const int64_t src1_w = src1->ne[0];
|
||||
const int64_t src1_h = src1->ne[1];
|
||||
const int64_t src1_d = src1->ne[2];
|
||||
|
||||
const bool src0_is_f32 = (src0->type == GGML_TYPE_F32);
|
||||
|
||||
// Compute correct strides for src0 as (knl_n_total, oc) matrix
|
||||
const int64_t src0_packed_nb0 = kernel_type_size;
|
||||
@@ -165,7 +171,7 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const int64_t sz = dst_z * s2 + kz * d2 - p2;
|
||||
|
||||
float val = 0.0f;
|
||||
if (sx >= 0 && sx < src1->ne[0] && sy >= 0 && sy < src1->ne[1] && sz >= 0 && sz < src1->ne[2]) {
|
||||
if (sx >= 0 && sx < src1_w && sy >= 0 && sy < src1_h && sz >= 0 && sz < src1_d) {
|
||||
const int64_t channel_idx = batch_idx * c + ic;
|
||||
const char * ptr = src1_base + sx * src1_nb0 + sy * src1_nb1 + sz * src1_nb2 + channel_idx * src1_nb3;
|
||||
val = *(const float *) ptr;
|
||||
@@ -184,9 +190,9 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
const int64_t row = t % k;
|
||||
const int64_t col = t / k;
|
||||
const char * src_ptr = (const char *) src0->data + row * src0_packed_nb0 + col * src0_packed_nb1;
|
||||
const char * src_ptr = src0_base + row * src0_packed_nb0 + col * src0_packed_nb1;
|
||||
float v;
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
if (src0_is_f32) {
|
||||
v = *(const float *) src_ptr;
|
||||
} else {
|
||||
v = sycl::vec<sycl::half, 1>(*(const sycl::half *) src_ptr).convert<float, sycl::rounding_mode::automatic>()[0];
|
||||
|
||||
@@ -5859,6 +5859,250 @@ static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t re
|
||||
return ctx->devices[index];
|
||||
}
|
||||
|
||||
// ==========================================================================
|
||||
// Tensor parallelism (--split-mode tensor) for the SYCL backend.
|
||||
//
|
||||
// The meta-backend invokes these three entry points via get_proc_address:
|
||||
// * ggml_backend_sycl_comm_init - one-time per-graph setup
|
||||
// * ggml_backend_sycl_comm_allreduce_tensor - per-allreduce step
|
||||
// * ggml_backend_sycl_comm_free - tear-down
|
||||
//
|
||||
// For N=2 (dual-GPU), this is a degenerate ring allreduce with dual paths
|
||||
// chosen by tensor size:
|
||||
//
|
||||
// * Small (nelem < 32K): FP32 direct memcpy + per-device ADD
|
||||
// kernel. The kernel depends_on() its corresponding memcpy event
|
||||
// so it doesn't read partial data. Both devices run in parallel.
|
||||
//
|
||||
// * Large (nelem >= 32K): BF16-compressed. Each device compresses
|
||||
// its FP32 partial to BF16 locally, cross-device memcpys
|
||||
// to the peer (half the PCI bandwidth), where it is decompressed
|
||||
// and added into the local FP32 partial. 6 SYCL submissions per
|
||||
// allreduce (2 compress + 2 memcpy + 2 decompress-add) vs the
|
||||
// 4 for the small path, but the bandwidth saving > 6 GB/s PCIe x 2
|
||||
// dominates for larger tensors.
|
||||
//
|
||||
// Storage: A persistent uint8_t buffer per device, sized to
|
||||
// 4 * nelem bytes. Both paths reinterpret the same bytes (small path
|
||||
// as nelem floats; large path as outbox + inbox = 2*nelem uint16_t
|
||||
// each, using the full 4*nelem byte budget either way). Single
|
||||
// alloc+free per device keeps the SYCL pool's strict-LIFO invariant
|
||||
// trivial.
|
||||
//
|
||||
// For non-(N=2 FP32 contiguous) cases, comm_init or comm_allreduce_tensor
|
||||
// returns null/false, causing the meta-backend to use its generic
|
||||
// butterfly all-reduce fallback.
|
||||
// ==========================================================================
|
||||
|
||||
struct ggml_backend_sycl_comm_context {
|
||||
std::vector<ggml_backend_t> backends;
|
||||
// ONE persistent per-device byte buffer, 4*nelem bytes. Both the
|
||||
// FP32 small-tensor path and the BF16 large-tensor path share it
|
||||
// by reinterpreting.
|
||||
std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>> buf0;
|
||||
std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>> buf1;
|
||||
int64_t buf_nelem = 0;
|
||||
};
|
||||
|
||||
void * ggml_backend_sycl_comm_init(ggml_backend_t * backends, size_t n_backends) try {
|
||||
for (size_t i = 0; i < n_backends; ++i) {
|
||||
if (!ggml_backend_is_sycl(backends[i])) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Initial version: N=2 only. For N!=2, returning null makes the
|
||||
// meta-backend skip this backend-specific allreduce entirely.
|
||||
if (n_backends != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto * ctx = new ggml_backend_sycl_comm_context;
|
||||
ctx->backends.assign(backends, backends + n_backends);
|
||||
auto * sctx0 = (ggml_backend_sycl_context *) backends[0]->context;
|
||||
auto * sctx1 = (ggml_backend_sycl_context *) backends[1]->context;
|
||||
ctx->buf0 = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(sctx0->pool());
|
||||
ctx->buf1 = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(sctx1->pool());
|
||||
return ctx;
|
||||
}
|
||||
catch (const sycl::exception &) { return nullptr; }
|
||||
catch (...) { return nullptr; }
|
||||
|
||||
void ggml_backend_sycl_comm_free(void * comm_ctx_v) {
|
||||
auto * comm_ctx = static_cast<ggml_backend_sycl_comm_context *>(comm_ctx_v);
|
||||
if (comm_ctx == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Sync both per-device queues so the pool_alloc destructors don't
|
||||
// return memory still in use by the last kernel.
|
||||
if (comm_ctx->backends.size() == 2) {
|
||||
auto * sctx0 = (ggml_backend_sycl_context *) comm_ctx->backends[0]->context;
|
||||
auto * sctx1 = (ggml_backend_sycl_context *) comm_ctx->backends[1]->context;
|
||||
try {
|
||||
sctx0->stream()->wait();
|
||||
sctx1->stream()->wait();
|
||||
} catch (...) { /* best effort during shutdown */ }
|
||||
}
|
||||
|
||||
delete comm_ctx;
|
||||
}
|
||||
|
||||
bool ggml_backend_sycl_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) try {
|
||||
if (comm_ctx_v == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto * comm_ctx = static_cast<ggml_backend_sycl_comm_context *>(comm_ctx_v);
|
||||
const size_t n_backends = comm_ctx->backends.size();
|
||||
|
||||
// Fast path: N=2, F32/F16, contiguous, matching shapes.
|
||||
if (n_backends != 2) {
|
||||
return false;
|
||||
}
|
||||
// Accept F32 or F16 inputs natively (types must match). F16 takes the
|
||||
// direct 2-byte memcpy + add path below; other types return false so the
|
||||
// meta-backend uses its generic all-reduce.
|
||||
if (tensors[0]->type != tensors[1]->type) {
|
||||
return false;
|
||||
}
|
||||
if (tensors[0]->type != GGML_TYPE_F32 && tensors[0]->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_is_contiguous(tensors[0]) || !ggml_is_contiguous(tensors[1])) {
|
||||
return false;
|
||||
}
|
||||
if (ggml_nelements(tensors[0]) != ggml_nelements(tensors[1])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int64_t nelem = ggml_nelements(tensors[0]);
|
||||
const size_t nbytes = ggml_nbytes(tensors[0]);
|
||||
if (nelem == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto * ctx0 = (ggml_backend_sycl_context *) comm_ctx->backends[0]->context;
|
||||
auto * ctx1 = (ggml_backend_sycl_context *) comm_ctx->backends[1]->context;
|
||||
queue_ptr q0 = ctx0->stream();
|
||||
queue_ptr q1 = ctx1->stream();
|
||||
|
||||
// Grow per-device byte buffers if needed (4 * nelem bytes each).
|
||||
if (comm_ctx->buf_nelem < nelem) {
|
||||
comm_ctx->buf0->realloc(nelem * 4);
|
||||
comm_ctx->buf1->realloc(nelem * 4);
|
||||
comm_ctx->buf_nelem = nelem;
|
||||
}
|
||||
uint8_t * buf0 = comm_ctx->buf0->get();
|
||||
uint8_t * buf1 = comm_ctx->buf1->get();
|
||||
|
||||
// F16 native path: direct 2-byte cross-device copy + add, skipping the
|
||||
// F32 round-trip the meta-backend fallback would force. Cross-device copies
|
||||
// go through dev2dev_memcpy because the two devices are in separate SYCL
|
||||
// contexts (a raw peer-USM q->memcpy would be a silent no-op).
|
||||
if (tensors[0]->type == GGML_TYPE_F16) {
|
||||
sycl::half * f16_out0 = (sycl::half *) tensors[0]->data;
|
||||
sycl::half * f16_out1 = (sycl::half *) tensors[1]->data;
|
||||
sycl::half * f16_tmp0 = (sycl::half *) buf0;
|
||||
sycl::half * f16_tmp1 = (sycl::half *) buf1;
|
||||
|
||||
q0->wait();
|
||||
q1->wait();
|
||||
dev2dev_memcpy(ctx0->device, *q0, ctx1->device, *q1, f16_tmp0, tensors[1]->data, nbytes);
|
||||
dev2dev_memcpy(ctx1->device, *q1, ctx0->device, *q0, f16_tmp1, tensors[0]->data, nbytes);
|
||||
|
||||
q0->submit([&](sycl::handler & h) {
|
||||
h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) {
|
||||
f16_out0[i] = (sycl::half) ((float) f16_out0[i] + (float) f16_tmp0[i]);
|
||||
});
|
||||
});
|
||||
q1->submit([&](sycl::handler & h) {
|
||||
h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) {
|
||||
f16_out1[i] = (sycl::half) ((float) f16_out1[i] + (float) f16_tmp1[i]);
|
||||
});
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
float * out0 = (float *) tensors[0]->data;
|
||||
float * out1 = (float *) tensors[1]->data;
|
||||
|
||||
// BF16 threshold: above this, the PCIe savings from halving the
|
||||
// cross-device bytes outweigh the 2 extra compress kernels.
|
||||
// Below: stay on the FP32 fast path. Threshold mirrors the CUDA
|
||||
// NCCL allreduce pattern for n_backends=2.
|
||||
static constexpr int64_t BF16_THRESHOLD = 32768;
|
||||
|
||||
if (nelem < BF16_THRESHOLD) {
|
||||
// FP32 small path: 4 SYCL submissions per allreduce.
|
||||
float * tmp0 = (float *) buf0;
|
||||
float * tmp1 = (float *) buf1;
|
||||
|
||||
// COMM-D2D-FIX: the two devices are in SEPARATE SYCL contexts, so a raw
|
||||
// q->memcpy of a peer USM pointer is a silent no-op. Route cross-device
|
||||
// copies through dev2dev_memcpy (L0 direct copy / host staging). It is
|
||||
// synchronous, so wait for the local partials to be produced first.
|
||||
q0->wait();
|
||||
q1->wait();
|
||||
dev2dev_memcpy(ctx0->device, *q0, ctx1->device, *q1, tmp0, tensors[1]->data, nbytes);
|
||||
dev2dev_memcpy(ctx1->device, *q1, ctx0->device, *q0, tmp1, tensors[0]->data, nbytes);
|
||||
|
||||
q0->submit([&](sycl::handler & h) {
|
||||
h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) {
|
||||
out0[i] += tmp0[i];
|
||||
});
|
||||
});
|
||||
q1->submit([&](sycl::handler & h) {
|
||||
h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) {
|
||||
out1[i] += tmp1[i];
|
||||
});
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
// BF16 large path: 6 SYCL submissions per allreduce, but the
|
||||
// cross-device memcpy is HALF the bytes. Pure bit-shift
|
||||
// conversion (no rounding) — matches ggml's truncating fp32->bf16.
|
||||
uint16_t * outbox0 = (uint16_t *) buf0;
|
||||
uint16_t * inbox0 = outbox0 + nelem;
|
||||
uint16_t * outbox1 = (uint16_t *) buf1;
|
||||
uint16_t * inbox1 = outbox1 + nelem;
|
||||
|
||||
// Phase A: compress each device's local partial in parallel.
|
||||
sycl::event c0 = q0->parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) {
|
||||
outbox0[i] = (uint16_t) (sycl::bit_cast<uint32_t>(out0[i]) >> 16);
|
||||
});
|
||||
|
||||
sycl::event c1 = q1->parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) {
|
||||
outbox1[i] = (uint16_t) (sycl::bit_cast<uint32_t>(out1[i]) >> 16);
|
||||
});
|
||||
|
||||
// Phase B: COMM-D2D-FIX-BF16 cross-device copy of compressed bytes via
|
||||
// dev2dev_memcpy (separate SYCL contexts; sync copy after compress).
|
||||
const size_t bf16_bytes = nelem * sizeof(uint16_t);
|
||||
c0.wait();
|
||||
c1.wait();
|
||||
dev2dev_memcpy(ctx0->device, *q0, ctx1->device, *q1, inbox0, outbox1, bf16_bytes);
|
||||
dev2dev_memcpy(ctx1->device, *q1, ctx0->device, *q0, inbox1, outbox0, bf16_bytes);
|
||||
|
||||
// Phase C: decompress + add into local FP32 partial.
|
||||
q0->submit([&](sycl::handler & h) {
|
||||
h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) {
|
||||
out0[i] += sycl::bit_cast<float>(((uint32_t) inbox0[i]) << 16);
|
||||
});
|
||||
});
|
||||
|
||||
q1->submit([&](sycl::handler & h) {
|
||||
h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) {
|
||||
out1[i] += sycl::bit_cast<float>(((uint32_t) inbox1[i]) << 16);
|
||||
});
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
catch (const sycl::exception &) { return false; }
|
||||
catch (...) { return false; }
|
||||
|
||||
static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
|
||||
GGML_UNUSED(reg);
|
||||
|
||||
@@ -5866,6 +6110,17 @@ static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, cons
|
||||
return (void *)ggml_backend_sycl_split_buffer_type;
|
||||
}
|
||||
|
||||
// Tensor parallelism (--split-mode tensor) entry points.
|
||||
if (strcmp(name, "ggml_backend_comm_init") == 0) {
|
||||
return (void *)ggml_backend_sycl_comm_init;
|
||||
}
|
||||
if (strcmp(name, "ggml_backend_comm_free") == 0) {
|
||||
return (void *)ggml_backend_sycl_comm_free;
|
||||
}
|
||||
if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) {
|
||||
return (void *)ggml_backend_sycl_comm_allreduce_tensor;
|
||||
}
|
||||
|
||||
// SYCL doesn't support registering host memory, left here for reference
|
||||
// "ggml_backend_register_host_buffer"
|
||||
// "ggml_backend_unregister_host_buffer"
|
||||
|
||||
@@ -126,7 +126,7 @@ static void soft_max_f32(const float * x,
|
||||
break;
|
||||
}
|
||||
|
||||
const float val = sycl::native::exp(vals[col] - max_val);
|
||||
const float val = sycl::native::exp(sycl::max(vals[col] - max_val, -80.0f));
|
||||
tmp += val;
|
||||
vals[col] = val;
|
||||
}
|
||||
@@ -154,7 +154,7 @@ static void soft_max_f32(const float * x,
|
||||
tmp = warp_reduce_sum<WARP_SIZE>(tmp);
|
||||
}
|
||||
if (sinks) {
|
||||
tmp += sycl::native::exp(sinks[i02] - max_val);
|
||||
tmp += sycl::native::exp(sycl::max(sinks[i02] - max_val, -80.0f));
|
||||
}
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
|
||||
@@ -308,6 +308,7 @@ enum vk_device_architecture {
|
||||
AMD_RDNA1,
|
||||
AMD_RDNA2,
|
||||
AMD_RDNA3,
|
||||
INTEL_XE1,
|
||||
INTEL_XE2,
|
||||
NVIDIA_PRE_TURING,
|
||||
NVIDIA_TURING,
|
||||
@@ -365,21 +366,26 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
||||
|
||||
bool subgroup_size_control = false;
|
||||
bool integer_dot_product = false;
|
||||
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
|
||||
subgroup_size_control = true;
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
|
||||
integer_dot_product = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!subgroup_size_control) {
|
||||
if (!subgroup_size_control || !integer_dot_product) {
|
||||
return vk_device_architecture::OTHER;
|
||||
}
|
||||
|
||||
vk::PhysicalDeviceProperties2 props2;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
|
||||
|
||||
props2.pNext = &subgroup_size_control_props;
|
||||
subgroup_size_control_props.pNext = &integer_dot_props;
|
||||
device.getProperties2(&props2);
|
||||
|
||||
if (subgroup_size_control_props.minSubgroupSize == 16) {
|
||||
@@ -388,6 +394,9 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
||||
// https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
|
||||
// https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
|
||||
return vk_device_architecture::INTEL_XE2;
|
||||
} else if (subgroup_size_control_props.minSubgroupSize == 8 &&
|
||||
integer_dot_product && integer_dot_props.integerDotProduct4x8BitPackedSignedAccelerated) {
|
||||
return vk_device_architecture::INTEL_XE1;
|
||||
}
|
||||
} else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
||||
@@ -3837,7 +3846,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support) {
|
||||
// Xe2/Xe3 with coopmat enabled - warptile performance tuning
|
||||
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
@@ -4710,7 +4719,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
uint32_t rm_iq = 2 * rm_kq;
|
||||
|
||||
const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
|
||||
const bool use_subgroups = device->subgroup_arithmetic;
|
||||
// Ensure a subgroup size >= 16 is available
|
||||
const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
|
||||
|
||||
@@ -6361,9 +6370,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
break;
|
||||
case VK_VENDOR_ID_INTEL: {
|
||||
// Current Windows driver does not expose BF16 support.
|
||||
// We only want to use l_warptile if coopmat is available and is Xe2+
|
||||
const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2;
|
||||
const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat;
|
||||
// We only want to use l_warptile if coopmat is available
|
||||
const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && device->coopmat_support) : device->coopmat_support;
|
||||
device->mul_mat_l[i] = use_l_warptile;
|
||||
device->mul_mat_id_l[i] = use_l_warptile;
|
||||
device->mul_mat_m[i] = true;
|
||||
@@ -17890,9 +17898,9 @@ static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
||||
switch (props.vendorID) {
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
// Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
|
||||
// while some older hardware (ex. Arc A770) has performance regressions
|
||||
return arch == vk_device_architecture::INTEL_XE2;
|
||||
// Only allowing Xe2/Xe3 GPU and integrated Xe GPUs at the moment since older hardware (ex. Arc A770) has performance regressions.
|
||||
return (arch == vk_device_architecture::INTEL_XE2) ||
|
||||
(arch == vk_device_architecture::INTEL_XE1 && props.deviceType == vk::PhysicalDeviceType::eIntegratedGpu && driver_props.driverID == vk::DriverId::eIntelProprietaryWindows);
|
||||
case VK_VENDOR_ID_AMD:
|
||||
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
|
||||
// Workaround for AMD proprietary driver reporting support on all GPUs
|
||||
@@ -17940,6 +17948,8 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
|
||||
case 0xE20B: // B580
|
||||
case 0xE211: // Pro B60
|
||||
return 20;
|
||||
case 0xB080: // PTL Xe3 LPG 2x6 (12 subslices)
|
||||
return 12;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ const uint32_t Csh_stride = BS_NPQ;
|
||||
#ifdef COOPMAT
|
||||
const uint32_t Csh_len = BS_K * Csh_stride;
|
||||
#else
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 8; // 8 to workaround compiler bug
|
||||
#endif
|
||||
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
|
||||
#endif
|
||||
|
||||
@@ -144,7 +144,7 @@ const uint32_t Csh_stride = BS_NPQ;
|
||||
#ifdef COOPMAT
|
||||
const uint32_t Csh_len = BS_K * Csh_stride;
|
||||
#else
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 8; // 8 to workaround compiler bug
|
||||
#endif
|
||||
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
|
||||
#endif
|
||||
|
||||
@@ -28,13 +28,10 @@ vec2 cache_b_ds;
|
||||
|
||||
#include "mul_mat_vecq_funcs.glsl"
|
||||
|
||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
|
||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint col, const uint b_qs_idx) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
|
||||
|
||||
// Preload data_b block
|
||||
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
|
||||
const uint b_qs_idx = tid % (32 / K_PER_ITER);
|
||||
const uint b_block_idx_outer = b_block_idx / 4;
|
||||
const uint b_block_idx_inner = b_block_idx % 4;
|
||||
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
|
||||
@@ -91,35 +88,35 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
}
|
||||
}
|
||||
|
||||
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
|
||||
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
|
||||
const uint col_stride = K_PER_ITER * BLOCK_SIZE;
|
||||
uint num_iters = p.ncols / col_stride;
|
||||
if (num_iters * col_stride + K_PER_ITER * tid < p.ncols) {
|
||||
num_iters++;
|
||||
}
|
||||
int unroll_count = 4;
|
||||
uint unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
uint i = 0;
|
||||
while (i < unrolled_iters) {
|
||||
const uint b_qs_idx = tid % (32 / K_PER_ITER);
|
||||
uint col = tid * K_PER_ITER;
|
||||
while (num_iters >= 4) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
[[unroll]] for (uint k = 0; k < 4; ++k) {
|
||||
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||
col += col_stride;
|
||||
}
|
||||
|
||||
num_iters -= 4;
|
||||
}
|
||||
|
||||
unroll_count = 2;
|
||||
unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
while (i < unrolled_iters) {
|
||||
if (num_iters >= 2) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||
col += col_stride;
|
||||
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||
col += col_stride;
|
||||
num_iters -= 2;
|
||||
}
|
||||
while (i < num_iters) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
|
||||
if (num_iters > 0) {
|
||||
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||
}
|
||||
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
|
||||
@@ -1 +1 @@
|
||||
707321c4cf6d21cb4bc831aa8b687dbf01a521ce
|
||||
eced84c86f8b012c752c016f7fe789adea168e1e
|
||||
|
||||
@@ -700,6 +700,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_160M: return "160M";
|
||||
case LLM_TYPE_190M: return "190M";
|
||||
case LLM_TYPE_220M: return "220M";
|
||||
case LLM_TYPE_230M: return "230M";
|
||||
case LLM_TYPE_250M: return "250M";
|
||||
case LLM_TYPE_256M: return "256M";
|
||||
case LLM_TYPE_270M: return "270M";
|
||||
|
||||
@@ -36,6 +36,7 @@ enum llm_type {
|
||||
LLM_TYPE_160M,
|
||||
LLM_TYPE_190M,
|
||||
LLM_TYPE_220M,
|
||||
LLM_TYPE_230M,
|
||||
LLM_TYPE_250M,
|
||||
LLM_TYPE_256M,
|
||||
LLM_TYPE_270M,
|
||||
|
||||
+1
-1
@@ -847,7 +847,7 @@ static void init_quantize_state_counters(quantize_state_impl & qs, std::vector<t
|
||||
qs.has_tied_embeddings = false;
|
||||
}
|
||||
}
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer();
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer_all;
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -13,6 +13,7 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) {
|
||||
hparams.n_layer_dense_lead = hparams.n_layer();
|
||||
|
||||
switch (hparams.n_ff()) {
|
||||
case 2560: type = LLM_TYPE_230M; break;
|
||||
case 4608: type = LLM_TYPE_350M; break;
|
||||
case 6912: type = LLM_TYPE_700M; break;
|
||||
case 8192: type = LLM_TYPE_1_2B; break;
|
||||
|
||||
@@ -169,7 +169,6 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
|
||||
GGML_ASSERT(ubatch.equal_seqs());
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
GGML_ASSERT(d_inner % n_head == 0);
|
||||
GGML_ASSERT(d_inner % d_state == 0);
|
||||
GGML_ASSERT(d_inner % n_group == 0);
|
||||
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
|
||||
@@ -39,10 +39,11 @@ void llama_model_mamba2::load_arch_tensors(llama_model_loader &) {
|
||||
const int64_t d_inner = hparams.ssm_d_inner;
|
||||
const int64_t d_state = hparams.ssm_d_state;
|
||||
const int64_t n_group = hparams.ssm_n_group;
|
||||
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
|
||||
const int64_t dt_rank = hparams.ssm_dt_rank;
|
||||
|
||||
const int64_t conv_dim = d_inner + 2 * n_group * d_state;
|
||||
const int64_t d_in_proj = d_inner + conv_dim + dt_rank;
|
||||
|
||||
// only an expansion factor of 2 is supported for now
|
||||
GGML_ASSERT(2 * n_embd == d_inner);
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
@@ -68,11 +69,11 @@ void llama_model_mamba2::load_arch_tensors(llama_model_loader &) {
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
|
||||
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0);
|
||||
|
||||
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0);
|
||||
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {dt_rank}, 0);
|
||||
|
||||
// no "weight" suffix for these
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0);
|
||||
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, dt_rank}, 0);
|
||||
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, dt_rank}, 0);
|
||||
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
|
||||
|
||||
|
||||
@@ -7973,6 +7973,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_conv_2d({ 256, 256, 192, 1 }, { 3, 3, 192, 96 }, kernel_type, 1, 1, 1, 1, 1, 1, false));
|
||||
}
|
||||
|
||||
// sycl backend will limit task global_range < MAX_INT
|
||||
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
|
||||
@@ -8176,6 +8179,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {-1,-1,-1,-1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {-1,-1,-1,-1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {2, 2097121, 1, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 524281, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
|
||||
|
||||
// CPY - different src/dst shapes (reshaping via CPY)
|
||||
// Use permutations of {3, 5, 7, 32}. Total elements: 3*5*7*32 = 3360.
|
||||
@@ -8670,6 +8675,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
256, 16, 16, {ne2, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
// nr2 sweep to cover the cublasSgemmBatched pointer-array path (dps2 > 1)
|
||||
for (int64_t nr2 : {8, 16, 32}) {
|
||||
test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32,
|
||||
256, 16, 16, {1, 1}, {nr2, 1}));
|
||||
}
|
||||
|
||||
// add_id
|
||||
for (ggml_type type_a : {GGML_TYPE_F32}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
|
||||
+59
-26
@@ -102,21 +102,34 @@ static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_tr
|
||||
return fabsf(result - dot_ref) / test_size;
|
||||
}
|
||||
|
||||
int main(int argc, char * argv[]) {
|
||||
bool verbose = false;
|
||||
const size_t test_size = 32 * 128;
|
||||
static int test_vec_dot_f32(bool verbose) {
|
||||
const auto * f32 = ggml_get_type_traits_cpu(GGML_TYPE_F32);
|
||||
int num_failed = 0;
|
||||
for (int n : {1, 2, 3, 5, 7, 8, 15, 16, 17, 31, 33, 63, 67, 127, 129, 193, 255, 1023}) {
|
||||
std::vector<float> a(n);
|
||||
std::vector<float> b(n);
|
||||
generate_data(0.0, n, a.data());
|
||||
generate_data(1.0, n, b.data());
|
||||
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
arg = argv[i];
|
||||
float result = 0.0f;
|
||||
f32->vec_dot(n, &result, 0, a.data(), 0, b.data(), 0, 1);
|
||||
const float ref = dot_product(a.data(), b.data(), n);
|
||||
const float error = fabsf(result - ref) / n;
|
||||
|
||||
if (arg == "-v") {
|
||||
verbose = true;
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
return 1;
|
||||
const bool failed = !(error < MAX_QUANTIZATION_REFERENCE_ERROR);
|
||||
num_failed += failed;
|
||||
if (failed || verbose) {
|
||||
printf(" f32 vec_dot n=%4d: %s (ref=%f got=%f err=%f)\n",
|
||||
n, RESULT_STR[failed], ref, result, error);
|
||||
}
|
||||
}
|
||||
return num_failed;
|
||||
}
|
||||
|
||||
static int test_vec_dot_q(bool verbose) {
|
||||
int num_failed = 0;
|
||||
|
||||
const size_t test_size = 32 * 128;
|
||||
|
||||
std::vector<float> test_data(test_size);
|
||||
std::vector<float> test_data2(test_size);
|
||||
@@ -124,11 +137,6 @@ int main(int argc, char * argv[]) {
|
||||
generate_data(0.0, test_data.size(), test_data.data());
|
||||
generate_data(1.0, test_data2.size(), test_data2.data());
|
||||
|
||||
ggml_cpu_init();
|
||||
|
||||
int num_failed = 0;
|
||||
bool failed = false;
|
||||
|
||||
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
|
||||
ggml_type type = (ggml_type) i;
|
||||
const auto * qfns = ggml_get_type_traits(type);
|
||||
@@ -156,7 +164,7 @@ int main(int argc, char * argv[]) {
|
||||
type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
|
||||
type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS :
|
||||
type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : MAX_QUANTIZATION_TOTAL_ERROR;
|
||||
failed = !(total_error < max_quantization_error);
|
||||
bool failed = !(total_error < max_quantization_error);
|
||||
num_failed += failed;
|
||||
if (failed || verbose) {
|
||||
printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
|
||||
@@ -171,15 +179,15 @@ int main(int argc, char * argv[]) {
|
||||
|
||||
const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
|
||||
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
|
||||
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
|
||||
? MAX_DOT_PRODUCT_ERROR_LOWBIT
|
||||
: type == GGML_TYPE_Q1_0
|
||||
? MAX_DOT_PRODUCT_ERROR_BINARY
|
||||
: type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
|
||||
? MAX_DOT_PRODUCT_ERROR_TERNARY
|
||||
: type == GGML_TYPE_NVFP4
|
||||
? MAX_DOT_PRODUCT_ERROR_FP4
|
||||
: MAX_DOT_PRODUCT_ERROR;
|
||||
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
|
||||
? MAX_DOT_PRODUCT_ERROR_LOWBIT
|
||||
: type == GGML_TYPE_Q1_0
|
||||
? MAX_DOT_PRODUCT_ERROR_BINARY
|
||||
: type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
|
||||
? MAX_DOT_PRODUCT_ERROR_TERNARY
|
||||
: type == GGML_TYPE_NVFP4
|
||||
? MAX_DOT_PRODUCT_ERROR_FP4
|
||||
: MAX_DOT_PRODUCT_ERROR;
|
||||
failed = !(vec_dot_error < max_allowed_error);
|
||||
num_failed += failed;
|
||||
if (failed || verbose) {
|
||||
@@ -188,6 +196,31 @@ int main(int argc, char * argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
return num_failed;
|
||||
}
|
||||
|
||||
int main(int argc, char * argv[]) {
|
||||
bool verbose = false;
|
||||
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
arg = argv[i];
|
||||
|
||||
if (arg == "-v") {
|
||||
verbose = true;
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_cpu_init();
|
||||
|
||||
int num_failed = 0;
|
||||
|
||||
num_failed += test_vec_dot_f32(verbose);
|
||||
num_failed += test_vec_dot_q(verbose);
|
||||
|
||||
if (num_failed || verbose) {
|
||||
printf("%d tests failed\n", num_failed);
|
||||
}
|
||||
|
||||
@@ -146,6 +146,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
LOG_INF("Model %d/%d, Context %d/%d: %s\n\n", m + 1, num_models, c + 1, num_contexts, result.c_str());
|
||||
|
||||
llama_synchronize(ctx.get());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1035,25 +1035,23 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
|
||||
if (!params.hf_repo.empty()) {
|
||||
for (size_t i = 0; i < params.hf_repo.size(); i++) {
|
||||
common_params_model model;
|
||||
|
||||
if (params.hf_file.empty() || params.hf_file[i].empty()) {
|
||||
model.hf_repo = params.hf_repo[i];
|
||||
} else {
|
||||
model.hf_repo = params.hf_repo[i];
|
||||
model.hf_file = params.hf_file[i];
|
||||
common_params p;
|
||||
p.hf_token = params.hf_token;
|
||||
p.offline = params.offline;
|
||||
p.model.hf_repo = params.hf_repo[i];
|
||||
if (!params.hf_file.empty() && !params.hf_file[i].empty()) {
|
||||
p.model.hf_file = params.hf_file[i];
|
||||
}
|
||||
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.offline = params.offline;
|
||||
auto download_result = common_download_model(model, opts);
|
||||
if (download_result.model_path.empty()) {
|
||||
// only the text model file is needed
|
||||
common_models_handler models_handler = common_models_handler_init(p, LLAMA_EXAMPLE_BENCH);
|
||||
common_models_handler_apply(models_handler, p);
|
||||
if (p.model.path.empty()) {
|
||||
fprintf(stderr, "error: failed to download model from HuggingFace\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
params.model.push_back(download_result.model_path);
|
||||
params.model.push_back(p.model.path);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+23
-17
@@ -115,22 +115,28 @@ if (TARGET mtmd)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_executable(llama-llava-cli deprecation-warning.cpp)
|
||||
add_executable(llama-gemma3-cli deprecation-warning.cpp)
|
||||
add_executable(llama-minicpmv-cli deprecation-warning.cpp)
|
||||
add_executable(llama-qwen2vl-cli deprecation-warning.cpp)
|
||||
# Gate CLI binaries on LLAMA_BUILD_TOOLS so that standalone library-only
|
||||
# builds (LLAMA_BUILD_MTMD=ON with LLAMA_BUILD_TOOLS=OFF — e.g. Apple
|
||||
# XCFramework packaging) skip the executables entirely. LLAMA_BUILD_COMMON
|
||||
# defaults to ON in standalone builds, so we cannot rely on it for gating.
|
||||
if (LLAMA_BUILD_TOOLS)
|
||||
add_executable(llama-llava-cli deprecation-warning.cpp)
|
||||
add_executable(llama-gemma3-cli deprecation-warning.cpp)
|
||||
add_executable(llama-minicpmv-cli deprecation-warning.cpp)
|
||||
add_executable(llama-qwen2vl-cli deprecation-warning.cpp)
|
||||
|
||||
set(TARGET llama-mtmd-cli)
|
||||
add_executable (${TARGET} mtmd-cli.cpp)
|
||||
set_target_properties (${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
set(TARGET llama-mtmd-cli)
|
||||
add_executable (${TARGET} mtmd-cli.cpp)
|
||||
set_target_properties (${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
target_link_libraries (${TARGET} PRIVATE llama-common mtmd Threads::Threads)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
# mtmd-debug tool
|
||||
add_executable(llama-mtmd-debug debug/mtmd-debug.cpp)
|
||||
set_target_properties(llama-mtmd-debug PROPERTIES OUTPUT_NAME llama-mtmd-debug)
|
||||
target_link_libraries(llama-mtmd-debug PRIVATE llama-common mtmd Threads::Threads)
|
||||
target_compile_features(llama-mtmd-debug PRIVATE cxx_std_17)
|
||||
endif()
|
||||
target_link_libraries (${TARGET} PRIVATE llama-common mtmd Threads::Threads)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
# mtmd-debug tool
|
||||
add_executable(llama-mtmd-debug debug/mtmd-debug.cpp)
|
||||
set_target_properties(llama-mtmd-debug PROPERTIES OUTPUT_NAME llama-mtmd-debug)
|
||||
target_link_libraries(llama-mtmd-debug PRIVATE llama-common mtmd Threads::Threads)
|
||||
target_compile_features(llama-mtmd-debug PRIVATE cxx_std_17)
|
||||
|
||||
@@ -55,8 +55,7 @@ struct clip_hparams {
|
||||
int32_t n_head = 0;
|
||||
int32_t n_head_kv = 0;
|
||||
int32_t n_layer = 0;
|
||||
// idefics3
|
||||
int32_t n_merge = 0; // number of patch merges **per-side**
|
||||
int32_t n_merge = 1; // number of patch merges **per-side**
|
||||
|
||||
// for preprocessor
|
||||
int32_t image_longest_edge = 0;
|
||||
@@ -135,8 +134,7 @@ struct clip_hparams {
|
||||
int32_t custom_image_max_tokens = -1;
|
||||
|
||||
void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
|
||||
const int cur_merge = n_merge == 0 ? 1 : n_merge;
|
||||
const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
|
||||
const int patch_area = patch_size * patch_size * n_merge * n_merge;
|
||||
image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area;
|
||||
image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area;
|
||||
warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
|
||||
@@ -145,8 +143,7 @@ struct clip_hparams {
|
||||
void set_warmup_n_tokens(int n_tokens) {
|
||||
int n_tok_per_side = static_cast<int>(std::sqrt(n_tokens));
|
||||
GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
|
||||
const int cur_merge = n_merge == 0 ? 1 : n_merge;
|
||||
warmup_image_size = n_tok_per_side * patch_size * cur_merge;
|
||||
warmup_image_size = n_tok_per_side * patch_size * n_merge;
|
||||
// TODO: support warmup size for custom token numbers
|
||||
}
|
||||
// sam vit deepseek-ocr
|
||||
|
||||
+52
-15
@@ -1210,6 +1210,9 @@ struct clip_model_loader {
|
||||
{
|
||||
std::vector<int> pinpoints;
|
||||
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, pinpoints, false);
|
||||
if (pinpoints.size() % 2 != 0) {
|
||||
throw std::runtime_error(string_format("%s: image_grid_pinpoints must have an even number of elements, got %zu\n", __func__, pinpoints.size()));
|
||||
}
|
||||
if (!pinpoints.empty()) {
|
||||
for (size_t i = 0; i < pinpoints.size(); i += 2) {
|
||||
hparams.image_res_candidates.push_back({
|
||||
@@ -1252,15 +1255,16 @@ struct clip_model_loader {
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
|
||||
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
|
||||
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
|
||||
GGML_ASSERT(idx_std >= 0 && "image_std not found");
|
||||
const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean);
|
||||
const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std);
|
||||
std::vector<float> image_mean;
|
||||
std::vector<float> image_std;
|
||||
get_arr_f32(KEY_IMAGE_MEAN, image_mean, false);
|
||||
get_arr_f32(KEY_IMAGE_STD , image_std, false);
|
||||
if (image_mean.size() < 3 || image_std.size() < 3) {
|
||||
throw std::runtime_error(string_format("%s: image_mean/image_std arrays must have at least 3 elements, got %zu and %zu\n", __func__, image_mean.size(), image_std.size()));
|
||||
}
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
hparams.image_mean[i] = mean_data[i];
|
||||
hparams.image_std[i] = std_data[i];
|
||||
hparams.image_mean[i] = image_mean[i];
|
||||
hparams.image_std[i] = image_std[i];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1686,8 +1690,8 @@ struct clip_model_loader {
|
||||
if (hparams.image_size > 65536) {
|
||||
throw std::runtime_error(string_format("%s: image_size (%d) is too large (max 65536)\n", __func__, hparams.image_size));
|
||||
}
|
||||
if (hparams.patch_size <= 0) {
|
||||
throw std::runtime_error(string_format("%s: patch_size (%d) must be greater than 0\n", __func__, hparams.patch_size));
|
||||
if (hparams.patch_size <= 0 || hparams.patch_size >= 65536) {
|
||||
throw std::runtime_error(string_format("%s: patch_size (%d) must be positive and less than 65536\n", __func__, hparams.patch_size));
|
||||
}
|
||||
if (hparams.n_embd <= 0) {
|
||||
throw std::runtime_error(string_format("%s: n_embd (%d) must be greater than 0\n", __func__, hparams.n_embd));
|
||||
@@ -1695,6 +1699,9 @@ struct clip_model_loader {
|
||||
if (hparams.image_max_pixels < hparams.image_min_pixels) {
|
||||
throw std::runtime_error(string_format("%s: image_max_pixels (%d) is less than image_min_pixels (%d)\n", __func__, hparams.image_max_pixels, hparams.image_min_pixels));
|
||||
}
|
||||
if (hparams.n_merge < 0 || hparams.n_merge >= 65536) {
|
||||
throw std::runtime_error(string_format("%s: n_merge (%d) must be greater than 0 and less than 65536\n", __func__, hparams.n_merge));
|
||||
}
|
||||
}
|
||||
|
||||
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
|
||||
@@ -3067,6 +3074,29 @@ struct clip_model_loader {
|
||||
output = gguf_get_val_f32(ctx_gguf.get(), i);
|
||||
}
|
||||
|
||||
void get_arr_f32(const std::string & key, std::vector<float> & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
if (required) {
|
||||
throw std::runtime_error("Key not found: " + key);
|
||||
}
|
||||
return;
|
||||
}
|
||||
const auto type = gguf_get_arr_type(ctx_gguf.get(), i);
|
||||
if (type != GGUF_TYPE_FLOAT32) {
|
||||
throw std::runtime_error(string_format("%s: array '%s' has type %d, expected %d (GGUF_TYPE_FLOAT32)\n", __func__, key.c_str(), type, GGUF_TYPE_FLOAT32));
|
||||
}
|
||||
const size_t n = gguf_get_arr_n(ctx_gguf.get(), i);
|
||||
if (n > (size_t) std::numeric_limits<int>::max()) {
|
||||
throw std::runtime_error(string_format("%s: array '%s' is too large (%zu elements)\n", __func__, key.c_str(), n));
|
||||
}
|
||||
output.resize(n);
|
||||
const float * values = (const float *)gguf_get_arr_data(ctx_gguf.get(), i);
|
||||
for (size_t j = 0; j < n; ++j) {
|
||||
output[j] = values[j];
|
||||
}
|
||||
}
|
||||
|
||||
void get_string(const std::string & key, std::string & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
@@ -3086,11 +3116,18 @@ struct clip_model_loader {
|
||||
}
|
||||
return;
|
||||
}
|
||||
int n = gguf_get_arr_n(ctx_gguf.get(), i);
|
||||
const auto type = gguf_get_arr_type(ctx_gguf.get(), i);
|
||||
if (type != GGUF_TYPE_INT32) {
|
||||
throw std::runtime_error(string_format("%s: array '%s' has type %d, expected %d (GGUF_TYPE_INT32)\n", __func__, key.c_str(), type, GGUF_TYPE_INT32));
|
||||
}
|
||||
const size_t n = gguf_get_arr_n(ctx_gguf.get(), i);
|
||||
if (n > (size_t) std::numeric_limits<int>::max()) {
|
||||
throw std::runtime_error(string_format("%s: array '%s' is too large (%zu elements)\n", __func__, key.c_str(), n));
|
||||
}
|
||||
output.resize(n);
|
||||
const int32_t * values = (const int32_t *)gguf_get_arr_data(ctx_gguf.get(), i);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
output[i] = values[i];
|
||||
for (size_t j = 0; j < n; ++j) {
|
||||
output[j] = values[j];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3364,8 +3401,8 @@ int clip_n_output_tokens(const clip_ctx * ctx, const clip_image_f32 * img) {
|
||||
{
|
||||
// dynamic size
|
||||
int n_merge = ctx->model.hparams.n_merge;
|
||||
int n_patches_x = img->nx() / patch_size / (n_merge > 0 ? n_merge : 1);
|
||||
int n_patches_y = img->ny() / patch_size / (n_merge > 0 ? n_merge : 1);
|
||||
int n_patches_x = img->nx() / patch_size / n_merge;
|
||||
int n_patches_y = img->ny() / patch_size / n_merge;
|
||||
if (ctx->model.token_embd_img_break) {
|
||||
n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
|
||||
} else {
|
||||
|
||||
@@ -63,8 +63,8 @@ ggml_cgraph * clip_graph_pixtral::build() {
|
||||
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
|
||||
// after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
|
||||
|
||||
const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
|
||||
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
|
||||
const int p_y = n_patches_y / n_merge;
|
||||
const int p_x = n_patches_x / n_merge;
|
||||
const int p_total = p_x * p_y;
|
||||
const int n_embd_text = cur->ne[0];
|
||||
const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
|
||||
|
||||
@@ -628,7 +628,7 @@ mtmd_image_preproc_out mtmd_image_preprocessor_llava_uhd::preprocess(const clip_
|
||||
mtmd_image_preprocessor_llava_uhd::slice_instructions mtmd_image_preprocessor_llava_uhd::get_slice_instructions(const clip_image_size & original_size) {
|
||||
mtmd_image_preprocessor_llava_uhd::slice_instructions res;
|
||||
// align slices by patch_size * n_merge so an integer number of merger output tokens fits per slice
|
||||
const int n_merge = hparams.n_merge > 0 ? hparams.n_merge : 1;
|
||||
const int n_merge = hparams.n_merge;
|
||||
const int patch_size = hparams.patch_size * n_merge;
|
||||
const int slice_size = hparams.image_size;
|
||||
const int original_width = original_size.width;
|
||||
@@ -894,7 +894,7 @@ mtmd_image_preproc_out mtmd_image_preprocessor_dyn_size::preprocess(const clip_i
|
||||
clip_image_u8 resized_image;
|
||||
const clip_image_size original_size = img.get_size();
|
||||
// the original pixtral model doesn't have n_merge
|
||||
const int cur_merge = hparams.n_merge == 0 ? 1 : hparams.n_merge;
|
||||
const int cur_merge = hparams.n_merge;
|
||||
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
|
||||
original_size,
|
||||
hparams.patch_size * cur_merge,
|
||||
|
||||
@@ -40,6 +40,7 @@ struct debug_options {
|
||||
bool enable_reasoning = true;
|
||||
bool debug_jinja = false;
|
||||
bool force_tool_call = false;
|
||||
bool parallel_tool_calls = true;
|
||||
output_mode mode = output_mode::BOTH;
|
||||
input_message_type input_message = input_message_type::NONE;
|
||||
};
|
||||
@@ -87,6 +88,7 @@ static void print_usage(const char * program_name) {
|
||||
LOG_ERR("\nOptions:\n");
|
||||
LOG_ERR(" --no-tools Disable tool definitions\n");
|
||||
LOG_ERR(" --force-tool-call Set tool calls to forced\n");
|
||||
LOG_ERR(" --parallel-tool-calls=0|1 Set parallel_tool_calls (default: 1)\n");
|
||||
LOG_ERR(" --generation-prompt=0|1 Set add_generation_prompt (default: 1)\n");
|
||||
LOG_ERR(" --enable-reasoning=0|1 Enable reasoning parsing (default: 1)\n");
|
||||
LOG_ERR(" --output=MODE Output mode: analysis, template, both (default: both)\n");
|
||||
@@ -121,6 +123,8 @@ static bool parse_options(int argc, char ** argv, debug_options & opts) {
|
||||
opts.debug_jinja = true;
|
||||
} else if (arg == "--no-tools") {
|
||||
opts.with_tools = false;
|
||||
} else if (arg.rfind("--parallel-tool-calls=", 0) == 0) {
|
||||
opts.parallel_tool_calls = parse_bool_option(arg.substr(22));
|
||||
} else if (arg.rfind("--generation-prompt=", 0) == 0) {
|
||||
opts.generation_prompt = parse_bool_option(arg.substr(20));
|
||||
} else if (arg.rfind("--enable-reasoning=", 0) == 0) {
|
||||
@@ -349,7 +353,7 @@ static autoparser::generation_params prepare_params(const debug_options & opts,
|
||||
params.tools = json();
|
||||
params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
}
|
||||
params.parallel_tool_calls = false;
|
||||
params.parallel_tool_calls = opts.parallel_tool_calls;
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ add_library(${TARGET} STATIC
|
||||
server-common.h
|
||||
server-context.cpp
|
||||
server-context.h
|
||||
server-stream.cpp
|
||||
server-stream.h
|
||||
server-tools.cpp
|
||||
server-tools.h
|
||||
server-schema.cpp
|
||||
|
||||
@@ -57,6 +57,7 @@ The core architecture consists of the following components:
|
||||
- `server_tokens`: Unified representation of token sequences (supports both text and multimodal tokens); used by `server_task` and `server_slot`.
|
||||
- `server_prompt_checkpoint`: For recurrent (e.g., RWKV) and SWA models, stores snapshots of KV cache state. Enables reuse when subsequent requests share the same prompt prefix, saving redundant computation.
|
||||
- `server_models`: Standalone component for managing multiple backend instances (used in router mode). It is completely independent of `server_context`.
|
||||
- `stream_session_manager`: Process wide owner of resumable SSE stream sessions (`g_stream_sessions`), keyed by conversation id. Backs the replay buffer that lets a client reattach to a generation after an HTTP disconnect. See the "Resumable streaming" section below.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
@@ -117,6 +118,58 @@ Here is an example trace of an API request for text completion:
|
||||
- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
|
||||
- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
|
||||
|
||||
### Resumable streaming (SSE replay buffer)
|
||||
|
||||
By default a streaming generation is bound to its HTTP socket: when the socket drops (refresh, tab close, mobile background, transient network) the generation aborts and the live stream is lost. This feature keeps the generation running server side and lets a client reattach.
|
||||
|
||||
It is opt in via the `X-Conversation-Id` header on `POST /v1/chat/completions`. Without the header the OAI strict path is unchanged. The conversation id is the only identity end to end (server map key, client localStorage key, route path), with an optional `::model` suffix for direct routing in router mode.
|
||||
|
||||
The feature lives entirely in `server-stream.{h,cpp}` and rests on three types:
|
||||
|
||||
- `stream_session`: a bounded ring buffer (4 MiB cap, oldest bytes drop first) plus a condvar. `append` pushes raw SSE bytes, `read_from` drains from any offset and blocks for live bytes or finalize, `finalize` wakes readers, `cancel` stops the producer. One conv maps to at most one live session.
|
||||
- `stream_session_manager` (`g_stream_sessions`): owns all sessions keyed by conv id, enforces the one conv one session invariant via `create_or_replace`, and runs a GC thread that drops completed sessions past their TTL.
|
||||
- `stream_pipe_producer` / `stream_pipe_consumer`: the write and read ends. The producer owns the session lifetime and finalizes it on destruction; the consumer is read only and never finalizes, so a reader detaching cannot kill a running generation.
|
||||
|
||||
Producer side: `server_res_generator` attaches a producer pipe when the header is present. The HTTP content provider mirrors every chunk into the ring before writing it to the socket. While a pipe is attached, `stream_aware_should_stop` ignores peer disconnect, so a dropped socket does not stop generation: only an explicit `DELETE` does. When the peer leaves early, `on_complete` calls `close()`, which drains the rest of the generation into the ring on the http worker.
|
||||
|
||||
Lifetime safety: the producer pipe holds a shared `alive` flag also captured by the session cancel hook. `~server_res_generator` calls `cleanup()` to clear that hook while the reader is still alive, so a `cancel` arriving during teardown can never call `stop()` on a freed response. This ordering is the most fragile part of the feature: finalizing or destroying the producer before `cleanup()` runs reintroduces a use after free.
|
||||
|
||||
Consumer side: `GET /v1/stream/<conv_id>?from=N` opens a `text/event-stream` that replays buffered bytes from offset `N` and blocks for live bytes, so the browser reattaches like a fresh EventSource. An offset below the dropped prefix returns 400.
|
||||
|
||||
Routes:
|
||||
|
||||
- `GET /v1/stream/:conv_id?from=N`: replay or live reattach.
|
||||
- `POST /v1/streams/lookup` with `{"conversation_ids": [...]}`: returns session status only for ids the caller already owns. There is no listing route, so live sessions cannot be enumerated (an earlier `GET /v1/streams` was removed for exactly this reason).
|
||||
- `DELETE /v1/stream/:conv_id`: explicit Stop, idempotent (`evict_and_cancel`).
|
||||
|
||||
Router mode binds the same paths to proxy handlers. A `conv_id -> child` map (`conv_models`), populated when a POST is routed, resolves the owning child in one lookup with no polling. The lookup groups ids per child; GET and DELETE proxy straight to the owner. This loopback REST hop is expected to move to a websocket IPC later, swapping only the transport.
|
||||
|
||||
Lifecycle: `g_stream_sessions.start_gc()` runs in main after common init, `stop_gc()` runs first in `clean_up()` and finalizes every live session so no reader hangs. Reader blocking and the post drop drain both run on httplib worker threads, which block on a condvar rather than spin.
|
||||
|
||||
| Constant | Value | Role |
|
||||
| --- | --- | --- |
|
||||
| `STREAM_SESSION_TTL_SECONDS` | 300 | retention of a completed session before GC |
|
||||
| `STREAM_SESSION_MAX_BYTES` | 4 MiB | ring cap per session |
|
||||
| `STREAM_SESSION_GC_INTERVAL_SECONDS` | 60 | GC tick |
|
||||
| `STREAM_READ_WAKE_INTERVAL_MS` | 200 | read_from wake to recheck should_stop |
|
||||
| `STREAM_LOOKUP_TIMEOUT_MS` | 250 | router to child loopback budget |
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
Client -- "POST + X-Conversation-Id" --> RG[server_res_generator]
|
||||
RG -- attach --> Prod[stream_pipe_producer]
|
||||
Prod -- "write, drain on peer drop" --> Sess
|
||||
subgraph g_stream_sessions
|
||||
Sess[stream_session: ring buffer, 4 MiB]
|
||||
GC[GC thread] -- drop after TTL --> Sess
|
||||
end
|
||||
Sess -- read_from offset --> Cons[stream_pipe_consumer]
|
||||
Cons -- "GET /v1/stream/:id?from=N" --> Client
|
||||
DEL[DELETE /v1/stream/:id] -- evict_and_cancel --> Sess
|
||||
```
|
||||
|
||||
The diagram shows the buffer touch points. The live wire (chunks streamed to the original client during a normal generation) is the producer's default output, described under "Producer side" above.
|
||||
|
||||
### Testing
|
||||
|
||||
`llama-server` includes an automated test suite based on `pytest`.
|
||||
@@ -223,6 +276,7 @@ The flow for downloading a new model:
|
||||
- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808
|
||||
- INI presets: https://github.com/ggml-org/llama.cpp/pull/17859 (+ refactoring: https://github.com/ggml-org/llama.cpp/pull/18169)
|
||||
- Sleeping mode: https://github.com/ggml-org/llama.cpp/pull/18228
|
||||
- Resumable streaming (SSE replay buffer): https://github.com/ggml-org/llama.cpp/pull/23226
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
#include "server-schema.h"
|
||||
#include "server-stream.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "common.h"
|
||||
@@ -4022,6 +4023,15 @@ struct server_res_generator : server_http_res {
|
||||
queue_tasks.wait_until_no_sleep();
|
||||
}
|
||||
}
|
||||
~server_res_generator() override {
|
||||
// cleanup() must run while rd is still alive (rd is destroyed after this body returns)
|
||||
if (spipe) {
|
||||
spipe->cleanup();
|
||||
}
|
||||
}
|
||||
void stop() override {
|
||||
rd.stop();
|
||||
}
|
||||
void ok(const json & response_data) {
|
||||
status = 200;
|
||||
data = safe_json_to_str(response_data);
|
||||
@@ -4210,8 +4220,10 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
}
|
||||
};
|
||||
|
||||
auto effective_should_stop = stream_aware_should_stop(res_this, req.should_stop);
|
||||
|
||||
try {
|
||||
if (req.should_stop()) {
|
||||
if (effective_should_stop()) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
@@ -4245,8 +4257,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
// receive subsequent results
|
||||
bool timeout = false;
|
||||
int64_t start_time = ggml_time_ms();
|
||||
auto result = rd.next([&timeout, &req, &start_time, ¶ms]() {
|
||||
if (req.should_stop()) {
|
||||
auto result = rd.next([&timeout, &start_time, ¶ms, &effective_should_stop]() {
|
||||
if (effective_should_stop()) {
|
||||
return true; // should_stop condition met
|
||||
} else if (params.sse_ping_interval > 0 && ggml_time_ms() - start_time > (int64_t)params.sse_ping_interval * 1000) {
|
||||
timeout = true;
|
||||
@@ -4264,7 +4276,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
|
||||
if (result == nullptr) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
GGML_ASSERT(req.should_stop());
|
||||
GGML_ASSERT(effective_should_stop());
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
@@ -4302,6 +4314,10 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
};
|
||||
}
|
||||
|
||||
// attach a producer pipe to the response when X-Conversation-Id is present.
|
||||
// the pipe mirrors SSE chunks into the ring buffer and wires up the cancel hook.
|
||||
stream_session_attach_pipe(*res, req.headers);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "common.h"
|
||||
#include "server-http.h"
|
||||
#include "server-stream.h"
|
||||
#include "server-common.h"
|
||||
#include "ui.h"
|
||||
|
||||
@@ -456,13 +457,40 @@ static void set_headers(httplib::Response & res, const std::map<std::string, std
|
||||
}
|
||||
}
|
||||
|
||||
// percent-decode a path component (%XX). path params arrive raw from httplib, unlike query
|
||||
// params, so a conv id like "conv::model" sent as "conv%3A%3Amodel" must be decoded here to
|
||||
// match the value the client put in the X-Conversation-Id header
|
||||
static std::string decode_path_component(const std::string & in) {
|
||||
std::string out;
|
||||
out.reserve(in.size());
|
||||
for (size_t i = 0; i < in.size(); i++) {
|
||||
if (in[i] == '%' && i + 2 < in.size()) {
|
||||
auto hex = [](char c) -> int {
|
||||
if (c >= '0' && c <= '9') return c - '0';
|
||||
if (c >= 'a' && c <= 'f') return c - 'a' + 10;
|
||||
if (c >= 'A' && c <= 'F') return c - 'A' + 10;
|
||||
return -1;
|
||||
};
|
||||
int hi = hex(in[i + 1]);
|
||||
int lo = hex(in[i + 2]);
|
||||
if (hi >= 0 && lo >= 0) {
|
||||
out.push_back(char((hi << 4) | lo));
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
out.push_back(in[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static std::map<std::string, std::string> get_params(const httplib::Request & req) {
|
||||
std::map<std::string, std::string> params;
|
||||
for (const auto & [key, value] : req.params) {
|
||||
params[key] = value;
|
||||
}
|
||||
for (const auto & [key, value] : req.path_params) {
|
||||
params[key] = value;
|
||||
params[key] = decode_path_component(value);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
@@ -497,26 +525,41 @@ static void process_handler_response(server_http_req_ptr && request, server_http
|
||||
set_headers(res, response->headers);
|
||||
const std::string content_type = response->content_type;
|
||||
// convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
|
||||
std::shared_ptr q_ptr = std::move(request);
|
||||
std::shared_ptr r_ptr = std::move(response);
|
||||
const auto chunked_content_provider = [response = r_ptr](size_t, const httplib::DataSink & sink) -> bool {
|
||||
std::shared_ptr<server_http_req> q_ptr = std::move(request);
|
||||
std::shared_ptr<server_http_res> r_ptr = std::move(response);
|
||||
|
||||
const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
|
||||
std::string chunk;
|
||||
const bool has_next = response->next(chunk);
|
||||
if (!chunk.empty()) {
|
||||
// mirror into the ring buffer first, the session must reflect every SSE chunk
|
||||
// whether or not the wire write below succeeds
|
||||
if (response->spipe) {
|
||||
response->spipe->write(chunk.data(), chunk.size());
|
||||
}
|
||||
if (!sink.write(chunk.data(), chunk.size())) {
|
||||
// peer is gone, stop the wire path here
|
||||
return false;
|
||||
}
|
||||
SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
|
||||
}
|
||||
if (!has_next) {
|
||||
// producer reached its natural end on the wire, a later close() skips the drain
|
||||
if (response->spipe) {
|
||||
response->spipe->done();
|
||||
}
|
||||
sink.done();
|
||||
SRV_DBG("%s", "http: stream ended\n");
|
||||
}
|
||||
return has_next;
|
||||
};
|
||||
const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable {
|
||||
response.reset(); // trigger the destruction of the response object
|
||||
request.reset(); // trigger the destruction of the request object
|
||||
// on a dropped peer, close() drains the rest of the generation into the ring buffer
|
||||
if (response->spipe) {
|
||||
response->spipe->close();
|
||||
}
|
||||
response.reset(); // spipe destructor finalizes the session if attached
|
||||
request.reset();
|
||||
};
|
||||
res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
|
||||
} else {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
@@ -10,6 +11,7 @@
|
||||
#include <unordered_map>
|
||||
|
||||
struct common_params;
|
||||
struct stream_pipe_producer; // defined in server-stream.h
|
||||
|
||||
// generator-like API for HTTP response generation
|
||||
// this object response with one of the 2 modes:
|
||||
@@ -23,12 +25,20 @@ struct server_http_res {
|
||||
std::string data;
|
||||
std::map<std::string, std::string> headers;
|
||||
|
||||
// TODO: move this to a virtual function once we have proper polymorphism support
|
||||
// if set, the stream survives a client disconnect: the producer pipe keeps draining into the
|
||||
// ring buffer and finalizes the session on destruction, so no explicit on_stream_end is needed.
|
||||
// shared_ptr (not unique_ptr) so the forward-declared type is safe to delete here.
|
||||
std::shared_ptr<stream_pipe_producer> spipe;
|
||||
|
||||
std::function<bool(std::string &)> next = nullptr;
|
||||
bool is_stream() const {
|
||||
return next != nullptr;
|
||||
}
|
||||
|
||||
// called when the session is cancelled (e.g. DELETE /v1/stream/<conv_id>).
|
||||
// server_res_generator overrides this to stop its reader; the default is a no-op.
|
||||
virtual void stop() {}
|
||||
|
||||
virtual ~server_http_res() = default;
|
||||
};
|
||||
|
||||
|
||||
+181
-19
@@ -1,12 +1,14 @@
|
||||
#include "server-common.h"
|
||||
#include "server-models.h"
|
||||
#include "server-context.h"
|
||||
#include "server-stream.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "preset.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
|
||||
#include <optional>
|
||||
#include <sheredom/subprocess.h>
|
||||
|
||||
#include <functional>
|
||||
@@ -92,6 +94,9 @@ struct server_subproc {
|
||||
}
|
||||
};
|
||||
|
||||
// short loopback budget for the resumable stream router to child JSON calls (probe, lookup,
|
||||
// delete). distinct from params.timeout_read/write which only applies to the generation proxy
|
||||
static constexpr int STREAM_LOOKUP_TIMEOUT_MS = 250;
|
||||
|
||||
static std::filesystem::path get_server_exec_path() {
|
||||
#if defined(_WIN32)
|
||||
@@ -223,8 +228,8 @@ void server_model_meta::update_caps() {
|
||||
"LLAMA_ARG_HF_REPO_FILE",
|
||||
});
|
||||
params.offline = true;
|
||||
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
|
||||
common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER);
|
||||
common_models_handler_apply(handler, params); // note: this won't download the model because offline=true
|
||||
if (params.mmproj.path.empty()) {
|
||||
multimodal = { false, false };
|
||||
} else {
|
||||
@@ -1393,9 +1398,8 @@ struct server_download_state : public common_download_callback {
|
||||
|
||||
bool run(common_params & params) {
|
||||
try {
|
||||
common_params_handle_models_params p;
|
||||
p.callback = this;
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p);
|
||||
common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER);
|
||||
common_models_handler_apply(handler, params, this);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = params.model.get_name();
|
||||
@@ -1581,6 +1585,45 @@ static bool is_autoload(const common_params & params, const server_http_req & re
|
||||
}
|
||||
}
|
||||
|
||||
// percent encode one query or path component, covers reserved chars without pulling in
|
||||
// httplib::detail. used by the stream routes to forward conversation_id to children safely
|
||||
static std::string encode_qs(const std::string & in) {
|
||||
std::string out;
|
||||
out.reserve(in.size() * 3);
|
||||
for (unsigned char c : in) {
|
||||
bool safe = (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9')
|
||||
|| c == '-' || c == '_' || c == '.' || c == '~';
|
||||
if (safe) {
|
||||
out.push_back(char(c));
|
||||
} else {
|
||||
char buf[4];
|
||||
std::snprintf(buf, sizeof(buf), "%%%02X", c);
|
||||
out.append(buf, 3);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// resolve the child that owns a conversation's stream session via the conv_id -> model map
|
||||
// populated when the POST was routed. single map lookup then a meta lookup, no polling, no
|
||||
// parsing of the conv id. returns nullopt when nothing maps, the caller answers not found and
|
||||
// the client recovers
|
||||
static std::optional<server_model_meta> resolve_child_for_conv(
|
||||
server_models & models, const std::string & conversation_id) {
|
||||
if (conversation_id.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto tracked = models.conv_models.lookup(conversation_id);
|
||||
if (!tracked.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto meta = models.get_meta(*tracked);
|
||||
if (meta.has_value() && meta->is_ready()) {
|
||||
return meta;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void server_models_routes::init_routes() {
|
||||
this->get_router_props = [this](const server_http_req & req) {
|
||||
std::string name = req.get_param("model");
|
||||
@@ -1629,6 +1672,12 @@ void server_models_routes::init_routes() {
|
||||
if (!router_validate_model(name, models, autoload, error_res)) {
|
||||
return error_res;
|
||||
}
|
||||
// remember which child serves this conversation so the stream routes can route straight
|
||||
// to it without polling, keyed on the exact conv id from the header
|
||||
std::string conv_id = stream_conv_id_from_headers(req.headers);
|
||||
if (!conv_id.empty()) {
|
||||
models.conv_models.remember(conv_id, name);
|
||||
}
|
||||
return models.proxy_request(req, method, name, true); // update last usage for POST request only
|
||||
};
|
||||
|
||||
@@ -1768,23 +1817,14 @@ void server_models_routes::init_routes() {
|
||||
throw std::invalid_argument("model must be a non-empty string");
|
||||
}
|
||||
|
||||
common_params_model model;
|
||||
common_download_opts opts;
|
||||
common_params p;
|
||||
p.model.hf_repo = name;
|
||||
p.hf_token = params.hf_token;
|
||||
|
||||
model.hf_repo = name;
|
||||
opts.bearer_token = params.hf_token;
|
||||
// note: we only check main model, no need sidecar here
|
||||
opts.download_mmproj = false;
|
||||
opts.download_mtp = false;
|
||||
|
||||
// first, only check if the model is valid and can be downloaded
|
||||
opts.skip_download = true;
|
||||
// validate by fetching metadata
|
||||
bool ok = false;
|
||||
try {
|
||||
auto validation = common_download_model(model, opts);
|
||||
ok = !validation.model_path.empty();
|
||||
} catch (const common_skip_download_exception &) {
|
||||
// model is valid and will be downloaded
|
||||
common_models_handler_init(p, LLAMA_EXAMPLE_SERVER);
|
||||
ok = true;
|
||||
} catch (...) {
|
||||
SRV_ERR("unknown error while validating model '%s'\n", name.c_str());
|
||||
@@ -1829,6 +1869,128 @@ void server_models_routes::init_routes() {
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
};
|
||||
|
||||
this->router_stream_get = [this](const server_http_req & req) {
|
||||
// GET /v1/stream/<conv_id>?from=N. resolve the owning child from the conv_id -> model
|
||||
// map, 404 when nothing maps
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
std::string conv_id = req.get_param("conv_id");
|
||||
if (conv_id.empty()) {
|
||||
res_err(res, format_error_response("Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
std::optional<server_model_meta> owner = resolve_child_for_conv(models, conv_id);
|
||||
if (!owner.has_value()) {
|
||||
res_err(res, format_error_response("Stream not found or expired", ERROR_TYPE_NOT_FOUND));
|
||||
return res;
|
||||
}
|
||||
std::string from = req.get_param("from");
|
||||
std::string child_path = "/v1/stream/" + encode_qs(conv_id);
|
||||
if (!from.empty()) {
|
||||
child_path += "?from=" + from;
|
||||
}
|
||||
SRV_INF("proxying stream resume to model %s on port %d, path=%s\n",
|
||||
owner->name.c_str(), owner->port, child_path.c_str());
|
||||
auto proxy = std::make_unique<server_http_proxy>(
|
||||
"GET",
|
||||
"http",
|
||||
CHILD_ADDR,
|
||||
owner->port,
|
||||
child_path,
|
||||
req.headers,
|
||||
req.body,
|
||||
req.files,
|
||||
req.should_stop,
|
||||
params.timeout_read,
|
||||
params.timeout_write);
|
||||
return std::unique_ptr<server_http_res>(std::move(proxy));
|
||||
};
|
||||
|
||||
this->router_streams_lookup = [this](const server_http_req & req) {
|
||||
// POST /v1/streams/lookup. resolve each requested conv id to its owning child via the
|
||||
// map, group the ids per child, and query only the children that actually own some of
|
||||
// them instead of fanning out to every ready child. a child only answers for the ids
|
||||
// it owns, never lists anything else
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
std::vector<std::string> requested;
|
||||
try {
|
||||
json body = json::parse(req.body);
|
||||
if (body.contains("conversation_ids") && body["conversation_ids"].is_array()) {
|
||||
for (const auto & v : body["conversation_ids"]) {
|
||||
if (v.is_string() && !v.get<std::string>().empty()) {
|
||||
requested.push_back(v.get<std::string>());
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
res_ok(res, json::array());
|
||||
return res;
|
||||
}
|
||||
|
||||
// group requested ids by the child port that owns them, drop ids that map to nothing
|
||||
std::unordered_map<int, json> per_child;
|
||||
for (const auto & cid : requested) {
|
||||
auto owner = resolve_child_for_conv(models, cid);
|
||||
if (!owner.has_value()) {
|
||||
continue;
|
||||
}
|
||||
per_child[owner->port].push_back(cid);
|
||||
}
|
||||
|
||||
json aggregated = json::array();
|
||||
for (auto & [port, ids] : per_child) {
|
||||
json child_body = {{"conversation_ids", ids}};
|
||||
httplib::Client cli(CHILD_ADDR, port);
|
||||
cli.set_connection_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_read_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_write_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
auto resp = cli.Post("/v1/streams/lookup", child_body.dump(), "application/json");
|
||||
if (!resp || resp->status != 200) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
json child_arr = json::parse(resp->body);
|
||||
if (!child_arr.is_array()) {
|
||||
continue;
|
||||
}
|
||||
for (auto & entry : child_arr) {
|
||||
if (entry.is_object()) {
|
||||
aggregated.push_back(entry);
|
||||
}
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
res_ok(res, aggregated);
|
||||
return res;
|
||||
};
|
||||
|
||||
this->router_stream_delete = [this](const server_http_req & req) {
|
||||
// DELETE /v1/stream/<conv_id>. resolve the owning child via the map and forward only to
|
||||
// it, evict_and_cancel is idempotent on the child
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
std::string conv_id = req.get_param("conv_id");
|
||||
if (conv_id.empty()) {
|
||||
res_err(res, format_error_response("Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
std::string child_path = "/v1/stream/" + encode_qs(conv_id);
|
||||
auto owner = resolve_child_for_conv(models, conv_id);
|
||||
if (owner.has_value()) {
|
||||
httplib::Client cli(CHILD_ADDR, owner->port);
|
||||
cli.set_connection_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_read_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_write_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
auto resp = cli.Delete(child_path.c_str());
|
||||
(void) resp; // best effort, 404 and network errors are equivalent to no op
|
||||
}
|
||||
// drop the tracking entry, the session is being torn down
|
||||
models.conv_models.forget(conv_id);
|
||||
res->status = 204;
|
||||
res->content_type = "application/json";
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,10 @@
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
/**
|
||||
* state diagram:
|
||||
@@ -126,6 +129,44 @@ private:
|
||||
// if true, the next get_meta() will trigger a reload of model list
|
||||
bool need_reload = false;
|
||||
|
||||
// conv_id -> model name that currently serves its stream session, lets the resumable stream
|
||||
// routes go straight to the owning child instead of polling every one. populated when
|
||||
// proxy_request forwards a POST carrying an X-Conversation-Id. best effort: a stale entry just
|
||||
// makes the child answer not found and the client recovers. owns its lock, one mutex per struct
|
||||
struct conv_model_tracker {
|
||||
void remember(const std::string & conv_id, const std::string & model) {
|
||||
if (conv_id.empty() || model.empty()) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
map[conv_id] = model;
|
||||
}
|
||||
|
||||
std::optional<std::string> lookup(const std::string & conv_id) {
|
||||
if (conv_id.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
auto it = map.find(conv_id);
|
||||
if (it == map.end()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void forget(const std::string & conv_id) {
|
||||
if (conv_id.empty()) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
map.erase(conv_id);
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex mu;
|
||||
std::unordered_map<std::string, std::string> map;
|
||||
};
|
||||
|
||||
common_preset_context ctx_preset;
|
||||
|
||||
common_params base_params;
|
||||
@@ -145,6 +186,9 @@ private:
|
||||
void notify_sse(const std::string & event, const std::string & model_id, const json & data = nullptr);
|
||||
|
||||
public:
|
||||
// conv_id -> model tracker for the resumable stream routes, owns its lock
|
||||
conv_model_tracker conv_models;
|
||||
|
||||
server_models(const common_params & params, int argc, char ** argv);
|
||||
|
||||
server_response sse; // for real-time updates via SSE endpoint
|
||||
@@ -268,6 +312,12 @@ struct server_models_routes {
|
||||
server_http_context::handler_t get_router_models_sse;
|
||||
server_http_context::handler_t post_router_models;
|
||||
server_http_context::handler_t del_router_models;
|
||||
|
||||
// router side handlers for the resumable streaming routes. each resolves the child that owns
|
||||
// a conversation through the conv_id -> model map, no probing or fan out
|
||||
server_http_context::handler_t router_stream_get;
|
||||
server_http_context::handler_t router_streams_lookup;
|
||||
server_http_context::handler_t router_stream_delete;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,569 @@
|
||||
#include "server-stream.h"
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
namespace {
|
||||
constexpr int64_t STREAM_SESSION_TTL_SECONDS = 300;
|
||||
constexpr size_t STREAM_SESSION_MAX_BYTES = 4 * 1024 * 1024;
|
||||
constexpr int64_t STREAM_SESSION_GC_INTERVAL_SECONDS = 60;
|
||||
constexpr int64_t STREAM_READ_WAKE_INTERVAL_MS = 200;
|
||||
|
||||
// returns unix time in seconds
|
||||
int64_t now_seconds() {
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch()
|
||||
).count();
|
||||
}
|
||||
}
|
||||
|
||||
stream_session::stream_session(std::string conversation_id_, size_t max_bytes_)
|
||||
: conversation_id(std::move(conversation_id_))
|
||||
, started_ts(now_seconds())
|
||||
, prefix_dropped(0)
|
||||
, cap_bytes(max_bytes_)
|
||||
, done(false)
|
||||
, cancelled(false)
|
||||
, completed_ts(0) {
|
||||
buffer.reserve(64 * 1024);
|
||||
}
|
||||
|
||||
bool stream_session::append(const char * data, size_t len) {
|
||||
if (len == 0) {
|
||||
return true;
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
if (done.load(std::memory_order_relaxed)) {
|
||||
return false;
|
||||
}
|
||||
if (len >= cap_bytes) {
|
||||
// single chunk bigger than the cap, keep only the tail that fits
|
||||
size_t skip = len - cap_bytes;
|
||||
prefix_dropped += buffer.size() + skip;
|
||||
buffer.clear();
|
||||
buffer.insert(buffer.end(), data + skip, data + len);
|
||||
} else {
|
||||
size_t needed = buffer.size() + len;
|
||||
if (needed > cap_bytes) {
|
||||
size_t to_drop = needed - cap_bytes;
|
||||
buffer.erase(buffer.begin(), buffer.begin() + to_drop);
|
||||
prefix_dropped += to_drop;
|
||||
}
|
||||
buffer.insert(buffer.end(), data, data + len);
|
||||
}
|
||||
}
|
||||
cv.notify_all();
|
||||
return true;
|
||||
}
|
||||
|
||||
void stream_session::finalize() {
|
||||
bool was_done = done.exchange(true, std::memory_order_acq_rel);
|
||||
if (was_done) {
|
||||
return;
|
||||
}
|
||||
completed_ts.store(now_seconds(), std::memory_order_release);
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
stream_read_status stream_session::read_from(size_t offset,
|
||||
const std::function<bool(const char *, size_t)> & sink,
|
||||
const std::function<bool()> & should_stop) {
|
||||
std::unique_lock<std::mutex> lock(mu);
|
||||
while (true) {
|
||||
if (should_stop && should_stop()) {
|
||||
return stream_read_status::OK;
|
||||
}
|
||||
if (offset < prefix_dropped) {
|
||||
return stream_read_status::OFFSET_LOST;
|
||||
}
|
||||
size_t logical_end = prefix_dropped + buffer.size();
|
||||
if (offset < logical_end) {
|
||||
size_t local_off = offset - prefix_dropped;
|
||||
size_t n = buffer.size() - local_off;
|
||||
// copy the available chunk under the lock, release before calling the sink
|
||||
std::vector<char> chunk(buffer.begin() + local_off, buffer.begin() + local_off + n);
|
||||
offset += n;
|
||||
lock.unlock();
|
||||
bool keep_going = sink(chunk.data(), chunk.size());
|
||||
if (!keep_going) {
|
||||
return stream_read_status::OK;
|
||||
}
|
||||
lock.lock();
|
||||
continue;
|
||||
}
|
||||
if (done.load(std::memory_order_acquire)) {
|
||||
return stream_read_status::OK;
|
||||
}
|
||||
// wait for new bytes, finalize, or a periodic wake to re check should_stop
|
||||
cv.wait_for(lock, std::chrono::milliseconds(STREAM_READ_WAKE_INTERVAL_MS));
|
||||
}
|
||||
}
|
||||
|
||||
bool stream_session::is_done() const {
|
||||
return done.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
size_t stream_session::total_size() const {
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
return prefix_dropped + buffer.size();
|
||||
}
|
||||
|
||||
size_t stream_session::dropped_prefix() const {
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
return prefix_dropped;
|
||||
}
|
||||
|
||||
int64_t stream_session::completed_at() const {
|
||||
return completed_ts.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
void stream_session::set_stop_producer(std::function<void()> fn) {
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
stop_producer = std::move(fn);
|
||||
}
|
||||
|
||||
void stream_session::cancel() {
|
||||
// flip cancelled first so the producer-side stream_aware_should_stop can break out of the
|
||||
// recv() wait even if remove_waiting_task_ids does not notify the condvar (the cancel task
|
||||
// posted by rd.stop() will eventually notify, but we do not want to depend on that timing)
|
||||
cancelled.store(true, std::memory_order_release);
|
||||
// copy the hook under the lock then invoke outside, the producer side may grab queue locks
|
||||
// and we do not want to hold our mu across that path
|
||||
std::function<void()> fn;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
fn = stop_producer;
|
||||
}
|
||||
if (fn) {
|
||||
fn();
|
||||
}
|
||||
}
|
||||
|
||||
bool stream_session::is_cancelled() const {
|
||||
return cancelled.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
stream_session_manager::stream_session_manager()
|
||||
: running(false) {
|
||||
}
|
||||
|
||||
stream_session_manager::~stream_session_manager() {
|
||||
stop_gc();
|
||||
}
|
||||
|
||||
stream_session_ptr stream_session_manager::create_or_replace(const std::string & conversation_id) {
|
||||
// evict any previous session on the same conv, this guarantees the invariant
|
||||
// "one conv = at most one live session" and propagates cancel to its producer
|
||||
stream_session_ptr previous;
|
||||
auto fresh = std::make_shared<stream_session>(conversation_id, STREAM_SESSION_MAX_BYTES);
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it != sessions.end()) {
|
||||
previous = it->second;
|
||||
it->second = fresh;
|
||||
} else {
|
||||
sessions.emplace(conversation_id, fresh);
|
||||
}
|
||||
}
|
||||
if (previous) {
|
||||
previous->cancel();
|
||||
previous->finalize();
|
||||
}
|
||||
return fresh;
|
||||
}
|
||||
|
||||
stream_session_ptr stream_session_manager::get(const std::string & conversation_id) {
|
||||
std::shared_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it == sessions.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<stream_session_ptr> stream_session_manager::list_all() const {
|
||||
std::vector<stream_session_ptr> out;
|
||||
std::shared_lock<std::shared_mutex> lock(map_mu);
|
||||
out.reserve(sessions.size());
|
||||
for (auto & kv : sessions) {
|
||||
out.push_back(kv.second);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void stream_session_manager::evict(const std::string & conversation_id) {
|
||||
stream_session_ptr s;
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it == sessions.end()) {
|
||||
return;
|
||||
}
|
||||
s = it->second;
|
||||
sessions.erase(it);
|
||||
}
|
||||
// finalize outside the map lock so any pending readers wake up and exit
|
||||
s->finalize();
|
||||
}
|
||||
|
||||
void stream_session_manager::evict_and_cancel(const std::string & conversation_id) {
|
||||
stream_session_ptr s;
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it == sessions.end()) {
|
||||
return;
|
||||
}
|
||||
s = it->second;
|
||||
sessions.erase(it);
|
||||
}
|
||||
// signal the producer side first so the inference is cancelled at the queue level,
|
||||
// then finalize, which wakes any pending HTTP reader and lets the drain exit naturally
|
||||
s->cancel();
|
||||
s->finalize();
|
||||
}
|
||||
|
||||
void stream_session_manager::start_gc() {
|
||||
if (running.exchange(true)) {
|
||||
return;
|
||||
}
|
||||
gc_thread = std::thread([this] { gc_loop(); });
|
||||
}
|
||||
|
||||
void stream_session_manager::stop_gc() {
|
||||
bool was_running = running.exchange(false);
|
||||
if (was_running) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(gc_wake_mu);
|
||||
}
|
||||
gc_wake_cv.notify_all();
|
||||
if (gc_thread.joinable()) {
|
||||
gc_thread.join();
|
||||
}
|
||||
}
|
||||
// finalize all live sessions so no reader ever hangs
|
||||
std::vector<stream_session_ptr> snapshot;
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
snapshot.reserve(sessions.size());
|
||||
for (auto & kv : sessions) {
|
||||
snapshot.push_back(kv.second);
|
||||
}
|
||||
sessions.clear();
|
||||
}
|
||||
for (auto & s : snapshot) {
|
||||
s->finalize();
|
||||
}
|
||||
}
|
||||
|
||||
void stream_session_manager::gc_loop() {
|
||||
while (running.load(std::memory_order_acquire)) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(gc_wake_mu);
|
||||
gc_wake_cv.wait_for(lock,
|
||||
std::chrono::seconds(STREAM_SESSION_GC_INTERVAL_SECONDS),
|
||||
[this] { return !running.load(std::memory_order_acquire); });
|
||||
}
|
||||
if (!running.load(std::memory_order_acquire)) {
|
||||
return;
|
||||
}
|
||||
int64_t cutoff = now_seconds() - STREAM_SESSION_TTL_SECONDS;
|
||||
std::vector<stream_session_ptr> to_drop;
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
for (auto it = sessions.begin(); it != sessions.end(); ) {
|
||||
int64_t completed = it->second->completed_at();
|
||||
if (completed != 0 && completed <= cutoff) {
|
||||
to_drop.push_back(it->second);
|
||||
it = sessions.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
// finalize outside the map lock, idempotent if the session was already done
|
||||
for (auto & s : to_drop) {
|
||||
s->finalize();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process wide manager, lifecycle controlled by llama-server main() via start_gc/stop_gc
|
||||
stream_session_manager g_stream_sessions;
|
||||
|
||||
// stream_pipe ---------------------------------------------------------------------------------
|
||||
|
||||
stream_pipe::stream_pipe(stream_session_ptr session)
|
||||
: session_(std::move(session)) {
|
||||
}
|
||||
|
||||
bool stream_pipe::is_cancelled() const {
|
||||
return session_->is_cancelled();
|
||||
}
|
||||
|
||||
// stream_pipe_producer
|
||||
|
||||
stream_pipe_producer::stream_pipe_producer(stream_session_ptr session)
|
||||
: stream_pipe(std::move(session)) {
|
||||
}
|
||||
|
||||
stream_pipe_producer::~stream_pipe_producer() {
|
||||
cleanup();
|
||||
session_->finalize();
|
||||
}
|
||||
|
||||
void stream_pipe_producer::cleanup() {
|
||||
if (!alive_) {
|
||||
return;
|
||||
}
|
||||
alive_->store(false, std::memory_order_release);
|
||||
session_->set_stop_producer(nullptr);
|
||||
alive_.reset();
|
||||
}
|
||||
|
||||
bool stream_pipe_producer::write(const char * data, size_t len) {
|
||||
return session_->append(data, len);
|
||||
}
|
||||
|
||||
void stream_pipe_producer::done() {
|
||||
done_ = true;
|
||||
}
|
||||
|
||||
void stream_pipe_producer::close() {
|
||||
// httplib bails its content provider the moment is_peer_alive() goes false, so pump the rest
|
||||
// of the generation into the ring buffer here. a DELETE flips is_cancelled and cuts it short
|
||||
if (done_ || session_->is_cancelled()) {
|
||||
SRV_INF("stream_pipe close: skip drain (done=%d cancelled=%d) conv=%s\n",
|
||||
done_ ? 1 : 0, session_->is_cancelled() ? 1 : 0, session_->conversation_id.c_str());
|
||||
return;
|
||||
}
|
||||
SRV_INF("stream_pipe close: draining conv=%s\n", session_->conversation_id.c_str());
|
||||
size_t drained = 0;
|
||||
std::string chunk;
|
||||
while (true) {
|
||||
chunk.clear();
|
||||
bool has_next = res_->next(chunk);
|
||||
if (!chunk.empty()) {
|
||||
write(chunk.data(), chunk.size());
|
||||
drained += chunk.size();
|
||||
}
|
||||
if (!has_next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
SRV_INF("stream_pipe close: drain ended conv=%s bytes=%zu\n", session_->conversation_id.c_str(), drained);
|
||||
}
|
||||
|
||||
std::shared_ptr<stream_pipe_producer> stream_pipe_producer::create(stream_session_ptr session,
|
||||
server_http_res & res) {
|
||||
auto alive = std::make_shared<std::atomic<bool>>(true);
|
||||
auto * res_ptr = &res;
|
||||
session->set_stop_producer([alive, res_ptr]() {
|
||||
if (alive->load(std::memory_order_acquire)) {
|
||||
res_ptr->stop();
|
||||
}
|
||||
});
|
||||
auto pipe = std::shared_ptr<stream_pipe_producer>(new stream_pipe_producer(std::move(session)));
|
||||
pipe->alive_ = std::move(alive);
|
||||
pipe->res_ = res_ptr;
|
||||
return pipe;
|
||||
}
|
||||
|
||||
// stream_pipe_consumer
|
||||
|
||||
stream_pipe_consumer::stream_pipe_consumer(stream_session_ptr session)
|
||||
: stream_pipe(std::move(session)) {
|
||||
}
|
||||
|
||||
stream_read_status stream_pipe_consumer::read(size_t & offset,
|
||||
const std::function<bool(const char *, size_t)> & sink,
|
||||
const std::function<bool()> & should_stop) {
|
||||
return session_->read_from(offset, sink, should_stop);
|
||||
}
|
||||
|
||||
std::shared_ptr<stream_pipe_consumer> stream_pipe_consumer::create(stream_session_ptr session) {
|
||||
return std::shared_ptr<stream_pipe_consumer>(new stream_pipe_consumer(std::move(session)));
|
||||
}
|
||||
|
||||
// helper, builds the standard error response and assigns it to a brand new http_res
|
||||
static server_http_res_ptr make_error_response(int status, const std::string & message, error_type type) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
json err = format_error_response(message, type);
|
||||
res->status = json_value(err, "code", status);
|
||||
res->content_type = "application/json; charset=utf-8";
|
||||
res->data = safe_json_to_str({{"error", err}});
|
||||
return res;
|
||||
}
|
||||
|
||||
server_http_context::handler_t make_stream_get_handler() {
|
||||
return [](const server_http_req & req) -> server_http_res_ptr {
|
||||
// GET /v1/stream/<conv_id>?from=N replays the SSE bytes already buffered for the
|
||||
// session, blocks for more bytes when the session is still running, returns when
|
||||
// the session is finalized. the body is streamed back as text/event-stream so the
|
||||
// browser EventSource can attach to it like a fresh request
|
||||
std::string conv_id = req.get_param("conv_id");
|
||||
if (conv_id.empty()) {
|
||||
return make_error_response(400, "Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
auto session = g_stream_sessions.get(conv_id);
|
||||
if (!session) {
|
||||
return make_error_response(404, "Stream not found or expired", ERROR_TYPE_NOT_FOUND);
|
||||
}
|
||||
size_t from = 0;
|
||||
std::string from_str = req.get_param("from");
|
||||
if (!from_str.empty()) {
|
||||
try {
|
||||
from = static_cast<size_t>(std::stoull(from_str));
|
||||
} catch (const std::exception &) {
|
||||
return make_error_response(400, "Invalid 'from' offset", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
}
|
||||
if (from < session->dropped_prefix()) {
|
||||
return make_error_response(400, "Stream offset lost, please restart", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 200;
|
||||
res->content_type = "text/event-stream";
|
||||
// the next closure reads from the ring buffer at the requested offset, blocks until
|
||||
// bytes arrive or the session finalizes. exit each call after draining the available
|
||||
// chunk so set_chunked_content_provider gets a chance to flush to the socket
|
||||
auto offset_ptr = std::make_shared<size_t>(from);
|
||||
// consumer pipe: read-only, does not finalize the session on destruction
|
||||
auto pipe = stream_pipe_consumer::create(session);
|
||||
res->next = [pipe, offset_ptr, &req](std::string & output) -> bool {
|
||||
bool got_any = false;
|
||||
pipe->read(*offset_ptr,
|
||||
[&](const char * d, size_t n) {
|
||||
output.append(d, n);
|
||||
*offset_ptr += n;
|
||||
got_any = true;
|
||||
return false;
|
||||
},
|
||||
req.should_stop);
|
||||
return got_any;
|
||||
};
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
server_http_context::handler_t make_streams_lookup_handler() {
|
||||
return [](const server_http_req & req) -> server_http_res_ptr {
|
||||
// POST /v1/streams/lookup with body {"conversation_ids": ["X", "Y", ...]} returns the
|
||||
// matching sessions, only for ids the caller already knows. each id matches the exact key
|
||||
// and any "<id>::<model>" variant, so one lookup covers every per model session for a conv
|
||||
std::vector<std::string> requested;
|
||||
try {
|
||||
json body = json::parse(req.body);
|
||||
if (body.contains("conversation_ids") && body["conversation_ids"].is_array()) {
|
||||
for (const auto & v : body["conversation_ids"]) {
|
||||
if (v.is_string()) {
|
||||
std::string id = v.get<std::string>();
|
||||
if (!id.empty()) {
|
||||
requested.push_back(std::move(id));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->content_type = "application/json; charset=utf-8";
|
||||
res->data = safe_json_to_str({{"error", {{"message", std::string("invalid body: ") + e.what()},
|
||||
{"type", "invalid_request_error"}}}});
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<stream_session_ptr> sessions;
|
||||
if (!requested.empty()) {
|
||||
auto all = g_stream_sessions.list_all();
|
||||
for (const auto & rid : requested) {
|
||||
const std::string with_sep = rid + "::";
|
||||
for (auto & s : all) {
|
||||
if (s->conversation_id == rid ||
|
||||
s->conversation_id.compare(0, with_sep.size(), with_sep) == 0) {
|
||||
sessions.push_back(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
json arr = json::array();
|
||||
for (auto & s : sessions) {
|
||||
arr.push_back({
|
||||
{"conversation_id", s->conversation_id},
|
||||
{"is_done", s->is_done()},
|
||||
{"total_bytes", s->total_size()},
|
||||
{"started_at", s->started_ts},
|
||||
{"completed_at", s->completed_at()},
|
||||
});
|
||||
}
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 200;
|
||||
res->content_type = "application/json; charset=utf-8";
|
||||
res->data = safe_json_to_str(arr);
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
server_http_context::handler_t make_stream_delete_handler() {
|
||||
return [](const server_http_req & req) -> server_http_res_ptr {
|
||||
// DELETE /v1/stream/<conv_id> is the explicit user Stop, cancels the producer hook
|
||||
// wired by handle_completions_impl and evicts the buffer. idempotent, a session that
|
||||
// already finalized or was never created returns 204 either way
|
||||
std::string conv_id = req.get_param("conv_id");
|
||||
if (conv_id.empty()) {
|
||||
return make_error_response(400, "Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
SRV_INF("DELETE /v1/stream/%s -> evict_and_cancel\n", conv_id.c_str());
|
||||
g_stream_sessions.evict_and_cancel(conv_id);
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 204;
|
||||
res->content_type = "application/json";
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
std::string stream_conv_id_from_headers(const std::map<std::string, std::string> & headers) {
|
||||
// case-insensitive scan for x-conversation-id
|
||||
static constexpr char target[] = "x-conversation-id";
|
||||
static constexpr size_t target_len = sizeof(target) - 1;
|
||||
for (const auto & [hk, hv] : headers) {
|
||||
if (hk.size() != target_len) continue;
|
||||
bool match = true;
|
||||
for (size_t i = 0; i < target_len; ++i) {
|
||||
char c = hk[i];
|
||||
if (c >= 'A' && c <= 'Z') c = char(c + 32);
|
||||
if (c != target[i]) { match = false; break; }
|
||||
}
|
||||
if (match) {
|
||||
return hv;
|
||||
}
|
||||
}
|
||||
return std::string();
|
||||
}
|
||||
|
||||
void stream_session_attach_pipe(server_http_res & res, const std::map<std::string, std::string> & headers) {
|
||||
std::string conversation_id = stream_conv_id_from_headers(headers);
|
||||
SRV_INF("stream_session_attach_pipe: conv_id=%s (empty=%d)\n",
|
||||
conversation_id.c_str(), conversation_id.empty() ? 1 : 0);
|
||||
if (conversation_id.empty()) {
|
||||
return;
|
||||
}
|
||||
auto session = g_stream_sessions.create_or_replace(conversation_id);
|
||||
res.spipe = stream_pipe_producer::create(session, res);
|
||||
}
|
||||
|
||||
std::function<bool()> stream_aware_should_stop(server_http_res * res, std::function<bool()> fallback) {
|
||||
return [res, fallback = std::move(fallback)]() -> bool {
|
||||
if (res->spipe) {
|
||||
return res->spipe->is_cancelled();
|
||||
}
|
||||
return fallback();
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-http.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
enum class stream_read_status {
|
||||
OK,
|
||||
OFFSET_LOST,
|
||||
};
|
||||
|
||||
// streaming buffer for one generation, survives HTTP disconnect. the producer appends raw SSE
|
||||
// bytes, readers drain from any offset via read_from and block until more bytes or finalize.
|
||||
// keyed by conversation_id: one conv = at most one live session
|
||||
struct stream_session {
|
||||
std::string conversation_id;
|
||||
int64_t started_ts; // unix seconds at construction, used by /v1/streams listing
|
||||
|
||||
stream_session(std::string conversation_id_, size_t max_bytes_);
|
||||
stream_session(const stream_session &) = delete;
|
||||
stream_session & operator=(const stream_session &) = delete;
|
||||
|
||||
// append raw bytes, drops from the front if the cap is reached.
|
||||
// returns false if the session is already finalized
|
||||
bool append(const char * data, size_t len);
|
||||
|
||||
// mark the session as complete, wakes all pending readers
|
||||
void finalize();
|
||||
|
||||
// drain bytes from offset, calling sink for each chunk. blocks until more
|
||||
// bytes arrive or finalize is called. returns OK on clean exit, OFFSET_LOST
|
||||
// if offset falls below the dropped prefix
|
||||
stream_read_status read_from(size_t offset,
|
||||
const std::function<bool(const char *, size_t)> & sink,
|
||||
const std::function<bool()> & should_stop);
|
||||
|
||||
bool is_done() const;
|
||||
bool is_cancelled() const;
|
||||
size_t total_size() const; // bytes that ever entered the session
|
||||
size_t dropped_prefix() const; // bytes evicted from the front due to cap
|
||||
int64_t completed_at() const; // 0 while alive, unix seconds after finalize
|
||||
|
||||
// attach the producer stop hook used to cancel its reader, pass an empty function to detach
|
||||
void set_stop_producer(std::function<void()> fn);
|
||||
|
||||
// signal the producer to abort its inference asap via the stop hook, idempotent
|
||||
void cancel();
|
||||
|
||||
private:
|
||||
mutable std::mutex mu;
|
||||
std::condition_variable cv;
|
||||
std::vector<char> buffer;
|
||||
size_t prefix_dropped;
|
||||
size_t cap_bytes;
|
||||
std::atomic<bool> done;
|
||||
std::atomic<bool> cancelled;
|
||||
std::atomic<int64_t> completed_ts;
|
||||
std::function<void()> stop_producer; // protected by mu
|
||||
};
|
||||
|
||||
using stream_session_ptr = std::shared_ptr<stream_session>;
|
||||
|
||||
// one end of a stream_session pipe. the base holds the session and the shared query, the
|
||||
// producer and consumer ends derive from it. virtual dtor so each end runs its own teardown:
|
||||
// the producer finalizes the session, the consumer leaves it untouched
|
||||
struct stream_pipe {
|
||||
virtual ~stream_pipe() = default;
|
||||
|
||||
// true if the session was cancelled (e.g. via DELETE /v1/stream/<conv_id>)
|
||||
bool is_cancelled() const;
|
||||
|
||||
protected:
|
||||
explicit stream_pipe(stream_session_ptr session);
|
||||
|
||||
stream_session_ptr session_;
|
||||
};
|
||||
|
||||
// producer end: writes chunks into the ring buffer and owns the session lifetime, finalizing it
|
||||
// on destruction.
|
||||
//
|
||||
// lifetime safety: holds a shared_ptr<atomic<bool>> alive also captured by the session's
|
||||
// stop_producer hook. cleanup() sets alive=false and clears the hook; it must run while the
|
||||
// response the hook calls stop() on is still alive. ~server_res_generator() does this explicitly.
|
||||
struct stream_pipe_producer : stream_pipe {
|
||||
~stream_pipe_producer() override;
|
||||
|
||||
// append raw bytes to the session's ring buffer, returns false if already finalized
|
||||
bool write(const char * data, size_t len);
|
||||
|
||||
// mark the natural end on the wire so a later close() is a no-op
|
||||
void done();
|
||||
|
||||
// on a peer drop, pump the response next() into the ring buffer until done. runs on the http
|
||||
// worker from on_complete, no-op after done() or cancel
|
||||
void close();
|
||||
|
||||
// disarm the stop hook and drop the alive guard, must run while the response the hook
|
||||
// references is still alive. idempotent, the destructor calls it too
|
||||
void cleanup();
|
||||
|
||||
// res.stop() is invoked when the session is cancelled, the alive guard ensures stop() is not
|
||||
// called after cleanup() has run
|
||||
static std::shared_ptr<stream_pipe_producer> create(stream_session_ptr session, server_http_res & res);
|
||||
|
||||
private:
|
||||
explicit stream_pipe_producer(stream_session_ptr session);
|
||||
|
||||
bool done_ = false;
|
||||
std::shared_ptr<std::atomic<bool>> alive_;
|
||||
server_http_res * res_ = nullptr;
|
||||
};
|
||||
|
||||
// consumer end: read-only replay of the ring buffer, the destructor does not finalize the session
|
||||
struct stream_pipe_consumer : stream_pipe {
|
||||
// drain bytes from offset, calling sink for each available chunk. blocks until more data
|
||||
// arrives or the session finalizes. should_stop is polled, returns OFFSET_LOST if offset
|
||||
// fell below the dropped prefix
|
||||
stream_read_status read(size_t & offset,
|
||||
const std::function<bool(const char *, size_t)> & sink,
|
||||
const std::function<bool()> & should_stop);
|
||||
|
||||
static std::shared_ptr<stream_pipe_consumer> create(stream_session_ptr session);
|
||||
|
||||
private:
|
||||
explicit stream_pipe_consumer(stream_session_ptr session);
|
||||
};
|
||||
|
||||
// owns all live sessions, runs a periodic GC to evict expired ones.
|
||||
// the map is keyed by conversation_id, so the invariant "one conv = at most one
|
||||
// live session" is enforced at the type level
|
||||
class stream_session_manager {
|
||||
public:
|
||||
stream_session_manager();
|
||||
~stream_session_manager();
|
||||
|
||||
stream_session_manager(const stream_session_manager &) = delete;
|
||||
stream_session_manager & operator=(const stream_session_manager &) = delete;
|
||||
|
||||
// install a new session for this conversation, evicting and cancelling any previous one.
|
||||
// the conversation_id must be non empty, the caller is responsible for that check.
|
||||
// returns the new session
|
||||
stream_session_ptr create_or_replace(const std::string & conversation_id);
|
||||
|
||||
// lookup, returns null if unknown or already evicted
|
||||
stream_session_ptr get(const std::string & conversation_id);
|
||||
|
||||
// list every live or recently completed session, used by GET /v1/streams without filter
|
||||
std::vector<stream_session_ptr> list_all() const;
|
||||
|
||||
// remove from the map and finalize, wakes any pending readers
|
||||
void evict(const std::string & conversation_id);
|
||||
|
||||
// signal the producer to cancel asap then evict, used by the explicit user Stop path
|
||||
void evict_and_cancel(const std::string & conversation_id);
|
||||
|
||||
void start_gc();
|
||||
void stop_gc();
|
||||
|
||||
private:
|
||||
void gc_loop();
|
||||
|
||||
mutable std::shared_mutex map_mu;
|
||||
std::unordered_map<std::string, stream_session_ptr> sessions; // key: conversation_id
|
||||
std::thread gc_thread;
|
||||
std::atomic<bool> running;
|
||||
std::mutex gc_wake_mu;
|
||||
std::condition_variable gc_wake_cv;
|
||||
};
|
||||
|
||||
// process wide manager, linked by both llama-server and llama-cli. llama-server main() drives
|
||||
// start_gc/stop_gc, llama-cli leaves it idle. the dtor calls stop_gc() unconditionally so exit
|
||||
// is safe whether or not the GC thread ran
|
||||
extern stream_session_manager g_stream_sessions;
|
||||
|
||||
// route handler factories operating on g_stream_sessions, wired under /v1/stream/* by server.cpp.
|
||||
// keeps the resumable stream surface confined to server-stream
|
||||
server_http_context::handler_t make_stream_get_handler();
|
||||
server_http_context::handler_t make_streams_lookup_handler();
|
||||
server_http_context::handler_t make_stream_delete_handler();
|
||||
|
||||
// extract the X-Conversation-Id header value (case-insensitive), empty when absent. exposed so
|
||||
// the router can track which child serves a forwarded POST
|
||||
std::string stream_conv_id_from_headers(const std::map<std::string, std::string> & headers);
|
||||
|
||||
// on an X-Conversation-Id header, create or replace the session and attach a producer pipe to
|
||||
// res. no-op when absent, called from the server_res_generator constructor
|
||||
void stream_session_attach_pipe(server_http_res & res, const std::map<std::string, std::string> & headers);
|
||||
|
||||
// should_stop closure that ignores peer disconnect when a pipe is attached, so only an explicit
|
||||
// DELETE stops the producer and generation keeps flowing into the ring buffer. without a pipe it
|
||||
// delegates to fallback, the legacy non-resumable flow
|
||||
std::function<bool()> stream_aware_should_stop(server_http_res * res, std::function<bool()> fallback);
|
||||
+67
-9
@@ -2,6 +2,7 @@
|
||||
#include "server-http.h"
|
||||
#include "server-models.h"
|
||||
#include "server-cors-proxy.h"
|
||||
#include "server-stream.h"
|
||||
#include "server-tools.h"
|
||||
|
||||
#include "arg.h"
|
||||
@@ -82,6 +83,10 @@ int llama_server(int argc, char ** argv) {
|
||||
|
||||
common_init();
|
||||
|
||||
// start the stream session manager GC right after common init, before any HTTP route can
|
||||
// touch it. lifecycle is symmetric, stop_gc() runs in clean_up() before backend free
|
||||
g_stream_sessions.start_gc();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -89,15 +94,16 @@ int llama_server(int argc, char ** argv) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// note: router mode also accepts -hf remote-preset, so we need to check that first
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
try {
|
||||
common_params_handle_models_params handle_params;
|
||||
handle_params.preset_only = true;
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params);
|
||||
} catch (const std::exception & e) {
|
||||
// ignored for now
|
||||
common_models_handler models_handler;
|
||||
try {
|
||||
models_handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER);
|
||||
if (common_models_handler_is_preset_repo(models_handler)) {
|
||||
// apply the preset and start the server in router mode
|
||||
common_models_handler_apply(models_handler, params);
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to fetch model metadata: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// router server never loads a model and must not touch the GPU
|
||||
@@ -238,9 +244,45 @@ int llama_server(int argc, char ** argv) {
|
||||
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
|
||||
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
|
||||
|
||||
// resumable streaming, the conversation_id is the session identity end to end. router and
|
||||
// child wire different handlers under the same paths: a child binds the local g_stream_sessions
|
||||
// backed factories, the router binds proxies that resolve the owning child through the
|
||||
// conv_id -> model map
|
||||
server_http_context::handler_t stream_get_h;
|
||||
server_http_context::handler_t streams_lookup_h;
|
||||
server_http_context::handler_t stream_delete_h;
|
||||
if (is_router_server) {
|
||||
stream_get_h = models_routes->router_stream_get;
|
||||
streams_lookup_h = models_routes->router_streams_lookup;
|
||||
stream_delete_h = models_routes->router_stream_delete;
|
||||
} else {
|
||||
stream_get_h = make_stream_get_handler();
|
||||
streams_lookup_h = make_streams_lookup_handler();
|
||||
stream_delete_h = make_stream_delete_handler();
|
||||
}
|
||||
ctx_http.get ("/v1/stream/:conv_id", ex_wrapper(stream_get_h));
|
||||
// POST /v1/streams/lookup with body {"conversation_ids": [...]}. you can only ask for ids
|
||||
// you already own (the WebUI passes the convs visible in its sidebar). the server never
|
||||
// lists ids it has not been asked about, so a random caller cannot enumerate live sessions
|
||||
ctx_http.post("/v1/streams/lookup", ex_wrapper(streams_lookup_h));
|
||||
ctx_http.del ("/v1/stream/:conv_id", ex_wrapper(stream_delete_h));
|
||||
|
||||
// Google Cloud Platform (Vertex AI) compat
|
||||
ctx_http.register_gcp_compat();
|
||||
|
||||
// return 403 for disabled features
|
||||
server_http_context::handler_t res_403 = [](const server_http_req &) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 403;
|
||||
res->data = safe_json_to_str({
|
||||
{"error", {
|
||||
{"message", "this feature is disabled"},
|
||||
{"type", "feature_disabled"},
|
||||
}}
|
||||
});
|
||||
return res;
|
||||
};
|
||||
|
||||
// CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP)
|
||||
if (params.ui_mcp_proxy) {
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
@@ -249,7 +291,11 @@ int llama_server(int argc, char ** argv) {
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
ctx_http.get ("/cors-proxy", ex_wrapper(proxy_handler_get));
|
||||
ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post));
|
||||
} else {
|
||||
ctx_http.get ("/cors-proxy", ex_wrapper(res_403));
|
||||
ctx_http.post("/cors-proxy", ex_wrapper(res_403));
|
||||
}
|
||||
|
||||
// EXPERIMENTAL built-in tools
|
||||
if (!params.server_tools.empty()) {
|
||||
try {
|
||||
@@ -264,6 +310,9 @@ int llama_server(int argc, char ** argv) {
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
ctx_http.get ("/tools", ex_wrapper(tools.handle_get));
|
||||
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
|
||||
} else {
|
||||
ctx_http.get ("/tools", ex_wrapper(res_403));
|
||||
ctx_http.post("/tools", ex_wrapper(res_403));
|
||||
}
|
||||
|
||||
//
|
||||
@@ -274,7 +323,12 @@ int llama_server(int argc, char ** argv) {
|
||||
return child.run_download(params);
|
||||
} else if (!is_router_server) {
|
||||
// single-model mode (NOT spawned by router)
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
|
||||
try {
|
||||
common_models_handler_apply(models_handler, params);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to download model: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@@ -288,6 +342,8 @@ int llama_server(int argc, char ** argv) {
|
||||
|
||||
clean_up = [&models_routes]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
// stop the session GC first, it finalizes live sessions and wakes pending readers
|
||||
g_stream_sessions.stop_gc();
|
||||
if (models_routes.has_value()) {
|
||||
models_routes->stopping.store(true); // maybe redundant, but just to be safe
|
||||
models_routes->models.unload_all();
|
||||
@@ -314,6 +370,8 @@ int llama_server(int argc, char ** argv) {
|
||||
// setup clean up function, to be called before exit
|
||||
clean_up = [&ctx_http, &ctx_server]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
// stop the session GC first, it finalizes live sessions and wakes pending readers
|
||||
g_stream_sessions.stop_gc();
|
||||
ctx_http.stop();
|
||||
ctx_server.terminate();
|
||||
llama_backend_free();
|
||||
|
||||
@@ -16,7 +16,7 @@ def test_mcp_no_proxy():
|
||||
server.start()
|
||||
|
||||
res = server.make_request("GET", "/cors-proxy")
|
||||
assert res.status_code == 404
|
||||
assert res.status_code == 403
|
||||
|
||||
|
||||
def test_mcp_proxy():
|
||||
|
||||
+1
-1
@@ -33,7 +33,7 @@
|
||||
|
||||
{#if !readonly && onRemove}
|
||||
<div
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.()} />
|
||||
</div>
|
||||
|
||||
+1
-1
@@ -56,7 +56,7 @@
|
||||
<div class="relative flex h-6 items-center justify-between">
|
||||
<div class="right-0 flex items-center gap-2 opacity-100 transition-opacity">
|
||||
<div
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-hover:opacity-100"
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={Edit} tooltip="Edit" onclick={editCtx.handleEdit} />
|
||||
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
ChatMessages,
|
||||
ChatScreenDragOverlay,
|
||||
ChatScreenProcessingInfo,
|
||||
ChatScreenStreamResumeStatus,
|
||||
ServerLoadingSplash,
|
||||
ChatScreenServerError
|
||||
} from '$lib/components/app';
|
||||
@@ -281,6 +282,10 @@
|
||||
|
||||
<ChatScreenServerError />
|
||||
|
||||
{#if page.params.id}
|
||||
<ChatScreenStreamResumeStatus />
|
||||
{/if}
|
||||
|
||||
<div class="pointer-events-none flex flex-col gap-6 items-center w-full">
|
||||
{#if (isMobile.current ? mobileScrollDownHint || isMobileUserScrolledUp : autoScroll.userScrolledUp) && page.url.hash.includes(ROUTES.CHAT) && page.params.id}
|
||||
<ChatScreenActionScrollDown
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
<script lang="ts">
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { StreamConnectionState } from '$lib/enums';
|
||||
import { Loader2 } from '@lucide/svelte';
|
||||
|
||||
let state = $derived(chatStore.streamConnectionState);
|
||||
</script>
|
||||
|
||||
{#if state === StreamConnectionState.RESUMING}
|
||||
<div
|
||||
class="pointer-events-auto mx-auto mt-2 mb-2 flex max-w-[48rem] items-center gap-2 rounded-md border border-blue-400/40 bg-blue-50/60 px-3 py-1.5 text-sm text-blue-700 dark:bg-blue-950/40 dark:text-blue-200"
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
>
|
||||
<Loader2 class="h-3.5 w-3.5 animate-spin" />
|
||||
<span>Reconnecting to the stream...</span>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -683,3 +683,11 @@ export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProc
|
||||
* Rendered inside ChatScreen when `serverError` store has a value.
|
||||
*/
|
||||
export { default as ChatScreenServerError } from './ChatScreen/ChatScreenServerError.svelte';
|
||||
|
||||
/**
|
||||
* Stream resume status indicator. Shows a small "Reconnecting to the stream..."
|
||||
* banner with a spinner while `chatStore.streamConnectionState` is `resuming`,
|
||||
* i.e. after a dropped connection is reattaching to the live SSE replay buffer.
|
||||
* Renders nothing otherwise. Shown inside ChatScreen only on an active conversation route.
|
||||
*/
|
||||
export { default as ChatScreenStreamResumeStatus } from './ChatScreen/ChatScreenStreamResumeStatus.svelte';
|
||||
|
||||
+10
-1
@@ -14,6 +14,7 @@
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { RouterService } from '$lib/services/router.service';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
@@ -34,6 +35,14 @@
|
||||
|
||||
const isStripExpanded = $derived(isExpandedMode || hoveredTooltip !== null);
|
||||
const isOnMobile = $derived(isMobile.current);
|
||||
const alwaysShowOnDesktop = $derived(config().alwaysShowSidebarOnDesktop as boolean);
|
||||
|
||||
// Keep the sidebar expanded on desktop when the user pins it open
|
||||
$effect(() => {
|
||||
if (alwaysShowOnDesktop && !isOnMobile) {
|
||||
isExpandedMode = true;
|
||||
}
|
||||
});
|
||||
|
||||
function toggleExpandedMode() {
|
||||
isExpandedMode = !isExpandedMode;
|
||||
@@ -183,7 +192,7 @@
|
||||
/>
|
||||
</div>
|
||||
|
||||
{#if isExpandedMode || isOnMobile}
|
||||
{#if isOnMobile || (isExpandedMode && !alwaysShowOnDesktop)}
|
||||
<div
|
||||
class="flex items-center transition-all duration-150 ease-out {isMobile.current &&
|
||||
!isExpandedMode
|
||||
|
||||
+56
-81
@@ -39,7 +39,6 @@
|
||||
depth = 0
|
||||
}: Props = $props();
|
||||
|
||||
let renderActionsDropdown = $state(false);
|
||||
let dropdownOpen = $state(false);
|
||||
|
||||
let isLoading = $derived(getAllLoadingChats().includes(conversation.id));
|
||||
@@ -71,26 +70,10 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseLeave() {
|
||||
if (!dropdownOpen) {
|
||||
renderActionsDropdown = false;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseOver() {
|
||||
renderActionsDropdown = true;
|
||||
}
|
||||
|
||||
function handleSelect() {
|
||||
onSelect?.(conversation.id);
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!dropdownOpen) {
|
||||
renderActionsDropdown = false;
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
document.addEventListener('edit-active-conversation', handleGlobalEditEvent as EventListener);
|
||||
|
||||
@@ -103,23 +86,19 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<!-- svelte-ignore a11y_mouse_events_have_key_events -->
|
||||
<button
|
||||
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
|
||||
<div
|
||||
class="conversation-item group relative flex min-h-9 w-full items-center justify-between space-x-3 rounded-lg py-1.5 transition-colors hover:bg-foreground/10 {isActive
|
||||
? 'bg-foreground/5 text-accent-foreground'
|
||||
: ''} px-3"
|
||||
onclick={handleSelect}
|
||||
onmouseover={handleMouseOver}
|
||||
onmouseleave={handleMouseLeave}
|
||||
onfocusin={handleMouseOver}
|
||||
onfocusout={(e) => {
|
||||
if (!e.currentTarget.contains(e.relatedTarget as Node | null)) {
|
||||
handleMouseLeave();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<button
|
||||
class="absolute inset-0 z-0 cursor-pointer rounded-lg focus:outline-none focus-visible:ring-2 focus-visible:ring-ring"
|
||||
onclick={handleSelect}
|
||||
aria-label={conversation.name}
|
||||
>
|
||||
</button>
|
||||
<div
|
||||
class="flex min-w-0 flex-1 items-center gap-2"
|
||||
class="pointer-events-none relative z-10 flex min-w-0 flex-1 items-center gap-2"
|
||||
style:padding-left="{depth * FORK_TREE_DEPTH_PADDING}px"
|
||||
>
|
||||
{#if depth > 0}
|
||||
@@ -130,7 +109,7 @@
|
||||
<a
|
||||
{...props}
|
||||
href={RouterService.chat(conversation.forkedFromConversationId)}
|
||||
class="flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
class="pointer-events-auto flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
>
|
||||
<GitBranch class="h-3.5 w-3.5" />
|
||||
</a>
|
||||
@@ -146,18 +125,15 @@
|
||||
{#if isLoading}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<div
|
||||
class="stop-button flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
<button
|
||||
class="stop-button pointer-events-auto flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
onclick={handleStop}
|
||||
onkeydown={(e) => e.key === 'Enter' && handleStop(e)}
|
||||
role="button"
|
||||
tabindex="0"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<Loader2 class="loading-icon h-3.5 w-3.5 animate-spin" />
|
||||
|
||||
<Square class="stop-icon hidden h-3 w-3 fill-current text-destructive" />
|
||||
</div>
|
||||
</button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
@@ -169,52 +145,50 @@
|
||||
<TruncatedText text={conversation.name} class="text-sm font-medium" showTooltip={false} />
|
||||
</div>
|
||||
|
||||
{#if renderActionsDropdown}
|
||||
<div class="actions flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
},
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
<div class="actions pointer-events-auto relative z-20 flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
},
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
button {
|
||||
.conversation-item {
|
||||
:global([data-slot='dropdown-menu-trigger']:not([data-state='open'])) {
|
||||
opacity: 0;
|
||||
}
|
||||
@@ -239,7 +213,8 @@
|
||||
}
|
||||
}
|
||||
|
||||
&:is(:hover) .stop-button {
|
||||
&:is(:hover) .stop-button,
|
||||
&:focus-within .stop-button {
|
||||
:global(.stop-icon) {
|
||||
display: block;
|
||||
}
|
||||
|
||||
@@ -21,5 +21,11 @@ export const API_TOOLS = {
|
||||
EXECUTE: '/tools'
|
||||
};
|
||||
|
||||
// resumable stream routes, the conv::model identity is appended as a path segment
|
||||
export const API_STREAM = {
|
||||
BASE: './v1/stream',
|
||||
LOOKUP: './v1/streams/lookup'
|
||||
};
|
||||
|
||||
/** CORS proxy endpoint path */
|
||||
export const CORS_PROXY_ENDPOINT = '/cors-proxy';
|
||||
|
||||
@@ -46,6 +46,7 @@ export * from './routes';
|
||||
export * from './sandbox';
|
||||
export * from './settings-keys';
|
||||
export * from './settings-registry';
|
||||
export * from './stream';
|
||||
export * from './supported-file-types';
|
||||
export * from './table-html-restorer';
|
||||
export * from './title-generation';
|
||||
|
||||
@@ -26,6 +26,9 @@ export const THINKING_ENABLED_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.th
|
||||
export const REASONING_EFFORT_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.reasoningEffortDefault`;
|
||||
export const USER_OVERRIDES_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.userOverrides`;
|
||||
|
||||
/** Key prefix for per-conversation resumable stream state, conversationId is appended */
|
||||
export const STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX = `${STORAGE_APP_NAME}.streamResume.`;
|
||||
|
||||
// Deprecated old key names (kept for backward compat while users migrate)
|
||||
/** @deprecated Use {@link ALWAYS_ALLOWED_TOOLS_LOCALSTORAGE_KEY} instead */
|
||||
export const DEPRECATED_ALWAYS_ALLOWED_TOOLS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.alwaysAllowedTools`;
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
// grace window after a visibilitychange before we kick a reader whose socket likely died
|
||||
// while the tab was hidden. covers brief background pauses without thrashing live streams
|
||||
export const STREAM_VISIBILITY_KICK_MS = 1000;
|
||||
@@ -5,6 +5,15 @@ export enum ChatMessageStatsView {
|
||||
SUMMARY = 'summary'
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection state of a streamed completion, drives the resume status indicator.
|
||||
*/
|
||||
export enum StreamConnectionState {
|
||||
STREAMING = 'streaming',
|
||||
RESUMING = 'resuming',
|
||||
LOST = 'lost'
|
||||
}
|
||||
|
||||
/**
|
||||
* Reasoning format options for API requests.
|
||||
*/
|
||||
|
||||
@@ -10,6 +10,7 @@ export { AgenticSectionType, ContinueIntentKind, ToolCallType } from './agentic.
|
||||
|
||||
export {
|
||||
ChatMessageStatsView,
|
||||
StreamConnectionState,
|
||||
ContentPartType,
|
||||
ConversationSelectionMode,
|
||||
ErrorDialogType,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { getJsonHeaders } from '$lib/utils/api-headers';
|
||||
import { getAuthHeaders, getJsonHeaders } from '$lib/utils/api-headers';
|
||||
import { formatAttachmentText } from '$lib/utils/formatters';
|
||||
import { isAbortError } from '$lib/utils/abort';
|
||||
import { streamIdentity } from '$lib/utils/stream-identity';
|
||||
import {
|
||||
ATTACHMENT_LABEL_PDF_FILE,
|
||||
ATTACHMENT_LABEL_MCP_PROMPT,
|
||||
@@ -13,7 +14,10 @@ import {
|
||||
CONTROL_ACTION,
|
||||
SSE_LINE_SEPARATOR,
|
||||
SSE_DATA_PREFIX,
|
||||
SSE_DONE_MARKER
|
||||
SSE_DONE_MARKER,
|
||||
STREAM_VISIBILITY_KICK_MS,
|
||||
STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX,
|
||||
API_STREAM
|
||||
} from '$lib/constants';
|
||||
import {
|
||||
AttachmentType,
|
||||
@@ -21,12 +25,14 @@ import {
|
||||
FileTypeAudio,
|
||||
MessageRole,
|
||||
MimeTypeAudio,
|
||||
ReasoningFormat
|
||||
ReasoningFormat,
|
||||
StreamConnectionState
|
||||
} from '$lib/enums';
|
||||
import type {
|
||||
ApiChatMessageContentPart,
|
||||
ApiChatMessageData,
|
||||
ApiChatCompletionToolCall
|
||||
ApiChatCompletionToolCall,
|
||||
ApiStreamSession
|
||||
} from '$lib/types/api';
|
||||
import type {
|
||||
AudioInputFormat,
|
||||
@@ -54,6 +60,19 @@ function getAudioInputFormat(mimeType: string): AudioInputFormat {
|
||||
return FileTypeAudio.MP3;
|
||||
}
|
||||
|
||||
interface ResumableStreamState {
|
||||
bytesReceived: number;
|
||||
updatedAt: number;
|
||||
|
||||
// model frozen at POST time, lets a reload rebuild the exact conv::model identity the
|
||||
// server keyed the session under. null when the POST carried no explicit model
|
||||
model?: string | null;
|
||||
}
|
||||
|
||||
function streamStorageKey(conversationId: string): string {
|
||||
return STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX + conversationId;
|
||||
}
|
||||
|
||||
export class ChatService {
|
||||
/**
|
||||
*
|
||||
@@ -128,6 +147,7 @@ export class ChatService {
|
||||
onChunk,
|
||||
onComplete,
|
||||
onError,
|
||||
onConnectionState,
|
||||
onReasoningChunk,
|
||||
onToolCallChunk,
|
||||
onModel,
|
||||
@@ -312,9 +332,16 @@ export class ChatService {
|
||||
}
|
||||
|
||||
try {
|
||||
const headers: Record<string, string> = { ...getJsonHeaders() };
|
||||
// tag streaming requests with the conversation id, this single header is the opt in for the
|
||||
// server side replay buffer and powers discoverActiveStream on tab reopen. with an explicit
|
||||
// model the ::model suffix keeps the per model session distinct
|
||||
if (stream && conversationId) {
|
||||
headers['X-Conversation-Id'] = streamIdentity(conversationId, options.model);
|
||||
}
|
||||
const response = await fetch(API_CHAT.COMPLETIONS, {
|
||||
method: 'POST',
|
||||
headers: getJsonHeaders(),
|
||||
headers,
|
||||
body: JSON.stringify(requestBody),
|
||||
signal
|
||||
});
|
||||
@@ -341,7 +368,9 @@ export class ChatService {
|
||||
onCompletionId,
|
||||
onTimings,
|
||||
conversationId,
|
||||
signal
|
||||
signal,
|
||||
onConnectionState,
|
||||
options.model
|
||||
);
|
||||
|
||||
return;
|
||||
@@ -473,6 +502,116 @@ export class ChatService {
|
||||
* @param excludeReasoning - Whether to strip reasoning content (should match excludeReasoningFromContext setting)
|
||||
* @param signal - Optional AbortSignal to cancel the pre-encode request
|
||||
*/
|
||||
static async cancelServerStream(conversationId: string, model?: string | null): Promise<void> {
|
||||
if (!conversationId) return;
|
||||
try {
|
||||
const id = streamIdentity(conversationId, model);
|
||||
await fetch(`${API_STREAM.BASE}/${encodeURIComponent(id)}`, {
|
||||
method: 'DELETE',
|
||||
headers: getAuthHeaders()
|
||||
});
|
||||
} catch (e) {
|
||||
console.warn('cancelServerStream failed:', e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pick the running session to splice into when discoverActiveStream lists candidates for a
|
||||
* conversation. Finalized sessions are not candidates: their final content was already written
|
||||
* to the DB by the original onComplete handler, so attaching to them would replay a buffer that
|
||||
* may not match what the DB holds. A continue session's buffer holds only the appended deltas,
|
||||
* not the pre continue prefix, so replaying it as a fresh generation would erase the original.
|
||||
*
|
||||
* Among running sessions we tie break on the most recent started_at, which covers the case of
|
||||
* multiple inferences left running on the same conversation.
|
||||
*/
|
||||
static selectActiveStream(
|
||||
sessions: ApiStreamSession[] | null | undefined
|
||||
): ApiStreamSession | null {
|
||||
if (!Array.isArray(sessions) || sessions.length === 0) {
|
||||
return null;
|
||||
}
|
||||
const running = sessions.filter((s) => !s.is_done);
|
||||
if (running.length === 0) {
|
||||
return null;
|
||||
}
|
||||
return running.reduce((best, cur) => (cur.started_at > best.started_at ? cur : best));
|
||||
}
|
||||
|
||||
// persist the running byte count and the frozen model for a conversation, a later visit
|
||||
// resumes the SSE replay at the right offset under the same conv::model identity
|
||||
static saveStreamState(
|
||||
conversationId: string,
|
||||
bytesReceived: number,
|
||||
model?: string | null
|
||||
): void {
|
||||
if (!conversationId) return;
|
||||
try {
|
||||
const state: ResumableStreamState = {
|
||||
bytesReceived,
|
||||
updatedAt: Date.now(),
|
||||
model: model ?? null
|
||||
};
|
||||
localStorage.setItem(streamStorageKey(conversationId), JSON.stringify(state));
|
||||
} catch {
|
||||
// localStorage may be full or disabled, silently ignore
|
||||
}
|
||||
}
|
||||
|
||||
static getStreamState(conversationId: string): ResumableStreamState | null {
|
||||
if (!conversationId) return null;
|
||||
try {
|
||||
const raw = localStorage.getItem(streamStorageKey(conversationId));
|
||||
if (!raw) return null;
|
||||
const parsed = JSON.parse(raw) as ResumableStreamState;
|
||||
if (!parsed || typeof parsed.bytesReceived !== 'number') return null;
|
||||
return parsed;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
static clearStreamState(conversationId: string): void {
|
||||
if (!conversationId) return;
|
||||
try {
|
||||
localStorage.removeItem(streamStorageKey(conversationId));
|
||||
} catch {
|
||||
// nothing to do
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Rebuild the stream identity for a resume. The model persisted at POST time wins, including a
|
||||
* stored null which means the POST carried no explicit model so the identity stays the bare conv
|
||||
* id. Only fall back to the caller supplied current model when nothing was persisted.
|
||||
*/
|
||||
static resumeStreamIdentity(
|
||||
conversationId: string,
|
||||
state: ResumableStreamState | null,
|
||||
fallbackModel: string | null
|
||||
): string {
|
||||
const model = state && state.model !== undefined ? state.model : fallbackModel;
|
||||
return streamIdentity(conversationId, model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconnect to an interrupted stream for this conversation. Returns the fetch Response so the
|
||||
* existing SSE parser drains it like a fresh stream. The server returns 200 on success, 404 if
|
||||
* no session exists for the conv_id, and 400 if the offset is below the dropped prefix.
|
||||
*/
|
||||
static async resumeStream(
|
||||
conversationId: string,
|
||||
signal?: AbortSignal,
|
||||
model?: string | null
|
||||
): Promise<Response | null> {
|
||||
if (!conversationId) return null;
|
||||
const state = ChatService.getStreamState(conversationId);
|
||||
const from = state?.bytesReceived ?? 0;
|
||||
const id = streamIdentity(conversationId, model);
|
||||
const url = `${API_STREAM.BASE}/${encodeURIComponent(id)}?from=${from}`;
|
||||
return await fetch(url, { method: 'GET', signal, headers: getAuthHeaders() });
|
||||
}
|
||||
|
||||
static async preEncode(
|
||||
messages: ApiChatMessageData[] | (DatabaseMessage & { extra?: DatabaseMessageExtra[] })[],
|
||||
model?: string | null,
|
||||
@@ -557,7 +696,7 @@ export class ChatService {
|
||||
* @returns {Promise<void>} Promise that resolves when streaming is complete
|
||||
* @throws {Error} if the stream cannot be read or parsed
|
||||
*/
|
||||
private static async handleStreamResponse(
|
||||
static async handleStreamResponse(
|
||||
response: Response,
|
||||
onChunk?: (chunk: string) => void,
|
||||
onComplete?: (
|
||||
@@ -573,15 +712,34 @@ export class ChatService {
|
||||
onCompletionId?: (id: string) => void,
|
||||
onTimings?: (timings?: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void,
|
||||
conversationId?: string,
|
||||
abortSignal?: AbortSignal
|
||||
abortSignal?: AbortSignal,
|
||||
onConnectionState?: (state: StreamConnectionState) => void,
|
||||
streamModel?: string | null
|
||||
): Promise<void> {
|
||||
const reader = response.body?.getReader();
|
||||
let reader = response.body?.getReader();
|
||||
|
||||
if (!reader) {
|
||||
throw new Error('No response body');
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
// bytesParsed is the absolute server side buffer offset of the next byte to parse
|
||||
// segmentStartOffset is the absolute offset where the current reader started, reset on resume
|
||||
// segmentBytesRead is wire bytes read by the current reader
|
||||
let bytesParsed = 0;
|
||||
let segmentStartOffset = 0;
|
||||
let segmentBytesRead = 0;
|
||||
let lastByteAt = Date.now();
|
||||
// each resume must produce at least one byte to be retried again
|
||||
// if a resume returns 200 but yields nothing, we abandon
|
||||
// since the session has a bounded size, the total number of retries is bounded by construction
|
||||
let madeProgress = true;
|
||||
const encoder = new TextEncoder();
|
||||
if (conversationId) {
|
||||
ChatService.saveStreamState(conversationId, 0, streamModel);
|
||||
}
|
||||
onConnectionState?.(StreamConnectionState.STREAMING);
|
||||
|
||||
let decoder = new TextDecoder();
|
||||
let aggregatedContent = '';
|
||||
let fullReasoningContent = '';
|
||||
let aggregatedToolCalls: ApiChatCompletionToolCall[] = [];
|
||||
@@ -633,84 +791,180 @@ export class ChatService {
|
||||
}
|
||||
};
|
||||
|
||||
const onVisibilityChange = () => {
|
||||
if (typeof document === 'undefined') return;
|
||||
if (document.visibilityState !== 'visible') return;
|
||||
if (streamFinished) return;
|
||||
if (!conversationId) return;
|
||||
// the bytes have been quiet for too long, the OS likely killed the socket
|
||||
// kicking the reader unblocks reader.read with done=true so the outer loop can resume
|
||||
if (Date.now() - lastByteAt > STREAM_VISIBILITY_KICK_MS) {
|
||||
reader!.cancel().catch(() => {});
|
||||
}
|
||||
};
|
||||
if (typeof document !== 'undefined') {
|
||||
document.addEventListener('visibilitychange', onVisibilityChange);
|
||||
}
|
||||
|
||||
try {
|
||||
let chunk = '';
|
||||
// outer loop drives the resume cycle, swaps reader on premature end of stream
|
||||
while (true) {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
chunk += decoder.decode(value, { stream: true });
|
||||
const lines = chunk.split(SSE_LINE_SEPARATOR);
|
||||
chunk = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
while (true) {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
if (line.startsWith(SSE_DATA_PREFIX)) {
|
||||
const data = line.slice(SSE_DATA_PREFIX.length).trim();
|
||||
if (data === SSE_DONE_MARKER) {
|
||||
streamFinished = true;
|
||||
|
||||
continue;
|
||||
let done: boolean;
|
||||
let value: Uint8Array | undefined;
|
||||
try {
|
||||
const r = await reader.read();
|
||||
done = r.done;
|
||||
value = r.value;
|
||||
} catch (readErr) {
|
||||
// reader.read() rejects with TypeError when the underlying connection drops
|
||||
// instead of just resolving with done=true. treat it like done so the outer
|
||||
// loop swaps reader via the resume path
|
||||
if (isAbortError(readErr)) {
|
||||
throw readErr;
|
||||
}
|
||||
console.warn('reader.read() rejected, treating as premature end:', readErr);
|
||||
done = true;
|
||||
value = undefined;
|
||||
}
|
||||
if (done) break;
|
||||
|
||||
try {
|
||||
const parsed: ApiChatCompletionStreamChunk = JSON.parse(data);
|
||||
const choice = parsed.choices?.[0];
|
||||
const content = choice?.delta?.content;
|
||||
const reasoningContent = choice?.delta?.reasoning_content;
|
||||
const toolCalls = choice?.delta?.tool_calls;
|
||||
const timings = parsed.timings;
|
||||
const promptProgress = parsed.prompt_progress;
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
const chunkModel = ChatService.extractModelName(parsed);
|
||||
if (chunkModel && !modelEmitted) {
|
||||
modelEmitted = true;
|
||||
onModel?.(chunkModel);
|
||||
}
|
||||
|
||||
if (parsed.id && !idEmitted) {
|
||||
idEmitted = true;
|
||||
onCompletionId?.(parsed.id);
|
||||
}
|
||||
|
||||
if (promptProgress) {
|
||||
ChatService.notifyTimings(undefined, promptProgress, onTimings);
|
||||
}
|
||||
|
||||
if (timings) {
|
||||
ChatService.notifyTimings(timings, promptProgress, onTimings);
|
||||
lastTimings = timings;
|
||||
}
|
||||
|
||||
if (content) {
|
||||
finalizeOpenToolCallBatch();
|
||||
aggregatedContent += content;
|
||||
if (!abortSignal?.aborted) {
|
||||
onChunk?.(content);
|
||||
}
|
||||
}
|
||||
|
||||
if (reasoningContent) {
|
||||
finalizeOpenToolCallBatch();
|
||||
fullReasoningContent += reasoningContent;
|
||||
if (!abortSignal?.aborted) {
|
||||
onReasoningChunk?.(reasoningContent);
|
||||
}
|
||||
}
|
||||
|
||||
processToolCallDelta(toolCalls);
|
||||
} catch (e) {
|
||||
console.error('Error parsing JSON chunk:', e);
|
||||
if (value && value.byteLength > 0) {
|
||||
segmentBytesRead += value.byteLength;
|
||||
lastByteAt = Date.now();
|
||||
if (!madeProgress) {
|
||||
madeProgress = true;
|
||||
onConnectionState?.(StreamConnectionState.STREAMING);
|
||||
}
|
||||
}
|
||||
|
||||
chunk += decoder.decode(value, { stream: true });
|
||||
const lines = chunk.split(SSE_LINE_SEPARATOR);
|
||||
chunk = lines.pop() || '';
|
||||
|
||||
// the persisted offset must point right after the last fully parsed line,
|
||||
// the trailing `chunk` is partial bytes still waiting for a newline
|
||||
if (conversationId) {
|
||||
const tailBytes = encoder.encode(chunk).byteLength;
|
||||
bytesParsed = segmentStartOffset + segmentBytesRead - tailBytes;
|
||||
ChatService.saveStreamState(conversationId, bytesParsed, streamModel);
|
||||
}
|
||||
|
||||
for (const line of lines) {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
if (line.startsWith(SSE_DATA_PREFIX)) {
|
||||
const data = line.slice(SSE_DATA_PREFIX.length).trim();
|
||||
if (data === SSE_DONE_MARKER) {
|
||||
streamFinished = true;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed: ApiChatCompletionStreamChunk = JSON.parse(data);
|
||||
const choice = parsed.choices?.[0];
|
||||
const content = choice?.delta?.content;
|
||||
const reasoningContent = choice?.delta?.reasoning_content;
|
||||
const toolCalls = choice?.delta?.tool_calls;
|
||||
const timings = parsed.timings;
|
||||
const promptProgress = parsed.prompt_progress;
|
||||
|
||||
const chunkModel = ChatService.extractModelName(parsed);
|
||||
if (chunkModel && !modelEmitted) {
|
||||
modelEmitted = true;
|
||||
onModel?.(chunkModel);
|
||||
}
|
||||
|
||||
if (parsed.id && !idEmitted) {
|
||||
idEmitted = true;
|
||||
onCompletionId?.(parsed.id);
|
||||
}
|
||||
|
||||
if (promptProgress) {
|
||||
ChatService.notifyTimings(undefined, promptProgress, onTimings);
|
||||
}
|
||||
|
||||
if (timings) {
|
||||
ChatService.notifyTimings(timings, promptProgress, onTimings);
|
||||
lastTimings = timings;
|
||||
}
|
||||
|
||||
if (content) {
|
||||
finalizeOpenToolCallBatch();
|
||||
aggregatedContent += content;
|
||||
if (!abortSignal?.aborted) {
|
||||
onChunk?.(content);
|
||||
}
|
||||
}
|
||||
|
||||
if (reasoningContent) {
|
||||
finalizeOpenToolCallBatch();
|
||||
fullReasoningContent += reasoningContent;
|
||||
if (!abortSignal?.aborted) {
|
||||
onReasoningChunk?.(reasoningContent);
|
||||
}
|
||||
}
|
||||
|
||||
processToolCallDelta(toolCalls);
|
||||
} catch (e) {
|
||||
console.error('Error parsing JSON chunk:', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (abortSignal?.aborted) break;
|
||||
if (streamFinished) break;
|
||||
}
|
||||
|
||||
// inner reader done, decide whether to try a resume
|
||||
if (abortSignal?.aborted) break;
|
||||
if (streamFinished) break;
|
||||
if (!conversationId) break;
|
||||
|
||||
if (!madeProgress) {
|
||||
onConnectionState?.(StreamConnectionState.LOST);
|
||||
onError?.(new Error('Stream resume produced no new bytes, giving up'));
|
||||
break;
|
||||
}
|
||||
|
||||
onConnectionState?.(StreamConnectionState.RESUMING);
|
||||
madeProgress = false;
|
||||
|
||||
// the server resends starting at bytesParsed, discard any partial line we held, it
|
||||
// will be retransmitted from a clean line boundary. reuse the frozen model, not the
|
||||
// live dropdown
|
||||
const resumeResp = await ChatService.resumeStream(
|
||||
conversationId,
|
||||
abortSignal,
|
||||
streamModel
|
||||
).catch(() => null);
|
||||
// an abort landing during the resume request is intentional, not a lost connection
|
||||
if (abortSignal?.aborted) break;
|
||||
if (!resumeResp || resumeResp.status !== 200) {
|
||||
onConnectionState?.(StreamConnectionState.LOST);
|
||||
onError?.(new Error('Stream connection lost and could not be resumed'));
|
||||
break;
|
||||
}
|
||||
const newReader = resumeResp.body?.getReader();
|
||||
if (!newReader) break;
|
||||
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
reader = newReader;
|
||||
decoder = new TextDecoder();
|
||||
chunk = '';
|
||||
segmentStartOffset = bytesParsed;
|
||||
segmentBytesRead = 0;
|
||||
lastByteAt = Date.now();
|
||||
}
|
||||
|
||||
if (abortSignal?.aborted) return;
|
||||
@@ -718,6 +972,10 @@ export class ChatService {
|
||||
if (streamFinished) {
|
||||
finalizeOpenToolCallBatch();
|
||||
|
||||
if (conversationId) {
|
||||
ChatService.clearStreamState(conversationId);
|
||||
}
|
||||
|
||||
const finalToolCalls =
|
||||
aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined;
|
||||
|
||||
@@ -735,7 +993,14 @@ export class ChatService {
|
||||
|
||||
throw err;
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
if (typeof document !== 'undefined') {
|
||||
document.removeEventListener('visibilitychange', onVisibilityChange);
|
||||
}
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -628,19 +628,20 @@ export class MCPService {
|
||||
);
|
||||
|
||||
const runtimeErrorHandler = (error: Error) => {
|
||||
// Ignore errors that are expected when the SDK's transport is closed,
|
||||
// or when connecting to servers that don't support SSE (stateless-only
|
||||
// endpoints returning 405). The SDK wraps the original AbortError in
|
||||
// a new Error with the message "SSE stream disconnected: AbortError",
|
||||
// and also produces "Cannot cancel a stream locked by a reader".
|
||||
// DOMException is thrown by the browser when aborting fetch requests.
|
||||
const msg = error.message || String(error);
|
||||
// the SDK reports any post initialize error here, including the abort we trigger
|
||||
// ourselves on the next health check cycle, on tab unload, or on server teardown.
|
||||
// these are lifecycle aborts, not actionable errors, so we keep them out of the red console.
|
||||
// the SDK wraps the original AbortError in a generic Error like
|
||||
// "SSE stream disconnected: AbortError: The operation was aborted."
|
||||
// which isAbortError cannot recognize by name alone, so we also pattern match on the message
|
||||
if (isAbortError(error)) {
|
||||
return;
|
||||
}
|
||||
const msg = error?.message ?? '';
|
||||
if (
|
||||
error.name === 'AbortError' ||
|
||||
error instanceof DOMException ||
|
||||
msg.includes('SSE stream disconnected') ||
|
||||
msg.includes('stream locked by a reader') ||
|
||||
msg.includes('The operation was aborted')
|
||||
/SSE stream disconnected:.*AbortError/i.test(msg) ||
|
||||
/AbortError: .*aborted/i.test(msg) ||
|
||||
/stream locked by a reader/i.test(msg)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -614,7 +614,7 @@ class AgenticStore {
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
undefined,
|
||||
conversationId,
|
||||
signal
|
||||
);
|
||||
|
||||
|
||||
@@ -11,9 +11,11 @@
|
||||
* @see ChatService in services/chat.service.ts for API operations
|
||||
*/
|
||||
|
||||
import { SvelteMap } from 'svelte/reactivity';
|
||||
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { ChatService } from '$lib/services/chat.service';
|
||||
import { streamIdentity } from '$lib/utils/stream-identity';
|
||||
import { getAuthHeaders } from '$lib/utils/api-headers';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { agenticStore } from '$lib/stores/agentic.svelte';
|
||||
@@ -49,10 +51,17 @@ import type {
|
||||
import type {
|
||||
ApiChatMessageData,
|
||||
ApiProcessingState,
|
||||
ApiStreamSession,
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra
|
||||
} from '$lib/types';
|
||||
import { ContinueIntentKind, ErrorDialogType, MessageRole, MessageType } from '$lib/enums';
|
||||
import {
|
||||
ContinueIntentKind,
|
||||
ErrorDialogType,
|
||||
MessageRole,
|
||||
MessageType,
|
||||
StreamConnectionState
|
||||
} from '$lib/enums';
|
||||
|
||||
interface ConversationStateEntry {
|
||||
lastAccessed: number;
|
||||
@@ -65,9 +74,25 @@ class ChatStore {
|
||||
isLoading = $state(false);
|
||||
// true while the active conversation streams reasoning content but no visible content yet
|
||||
isReasoning = $state(false);
|
||||
// resumable stream connection state for the active conversation
|
||||
// streaming -> bytes flowing normally, resuming -> waiting on /v1/stream/:id reconnect, lost -> unrecoverable
|
||||
streamConnectionState = $state<StreamConnectionState>(StreamConnectionState.STREAMING);
|
||||
chatLoadingStates = new SvelteMap<string, boolean>();
|
||||
chatReasoningStates = new SvelteMap<string, boolean>();
|
||||
chatStreamingStates = new SvelteMap<string, { response: string; messageId: string }>();
|
||||
chatStreamingStates = new SvelteMap<
|
||||
string,
|
||||
{ response: string; messageId: string; model?: string | null }
|
||||
>();
|
||||
// convs that the backend reports as having a running session, populated by the global sync
|
||||
// at app mount and on visibilitychange. it does not overlap with chatLoadingStates which
|
||||
// tracks inferences driven by this browser, both are unioned to feed the sidebar spinners
|
||||
private remoteRunningConvs = new SvelteSet<string>();
|
||||
// per conv attach lifecycle, used to derive the global streaming flag without flipping it
|
||||
// off when one conv finishes while another is still streaming. mirrors chatLoadingStates
|
||||
// in scope but tracks the attach + tee replay path specifically
|
||||
private attachingConvs = new SvelteSet<string>();
|
||||
// in-flight discoverActiveStream guard, keyed by conv id
|
||||
private discoveringConvs = new SvelteSet<string>();
|
||||
private abortControllers = new SvelteMap<string, AbortController>();
|
||||
private preEncodeAbortController: AbortController | null = null;
|
||||
private processingStates = new SvelteMap<string, ApiProcessingState | null>();
|
||||
@@ -98,6 +123,11 @@ class ChatStore {
|
||||
this.chatLoadingStates.delete(convId);
|
||||
if (convId === conversationsStore.activeConversation?.id) this.isLoading = false;
|
||||
this.setChatReasoning(convId, false);
|
||||
// the local pipe is the authoritative observer of session end: when it finishes (clean
|
||||
// onComplete or explicit Stop), the backend session is finalized too, so we drop the
|
||||
// sidebar hint for this conv right away instead of waiting for the next visibilitychange
|
||||
// snapshot. without this the spinner ghosts until the user toggles the tab
|
||||
this.remoteRunningConvs.delete(convId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,9 +140,18 @@ class ChatStore {
|
||||
if (convId === conversationsStore.activeConversation?.id) this.isReasoning = false;
|
||||
}
|
||||
}
|
||||
private setChatStreaming(convId: string, response: string, messageId: string): void {
|
||||
private setChatStreaming(
|
||||
convId: string,
|
||||
response: string,
|
||||
messageId: string,
|
||||
model?: string | null
|
||||
): void {
|
||||
this.touchConversationState(convId);
|
||||
this.chatStreamingStates.set(convId, { response, messageId });
|
||||
this.chatStreamingStates.set(convId, {
|
||||
response,
|
||||
messageId,
|
||||
model: model ?? this.chatStreamingStates.get(convId)?.model
|
||||
});
|
||||
if (convId === conversationsStore.activeConversation?.id) this.currentResponse = response;
|
||||
}
|
||||
private clearChatStreaming(convId: string): void {
|
||||
@@ -137,6 +176,314 @@ class ChatStore {
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Server side stream discovery, split in three pieces:
|
||||
*
|
||||
* probeServerStream(convId) -> hits POST /v1/streams/lookup with the conv id, returns the session to attach
|
||||
* to or null. Pure read, no side effect, no UI lock. Safe to fire in parallel with anything.
|
||||
*
|
||||
* attachServerStream(convId) -> flips the spinner immediately, fetches the replay stream
|
||||
* from byte 0, finds the assistant slot to splice into (creates a placeholder if the conv has
|
||||
* no assistant message yet, for cross device or fresh local DB cases), and pipes the SSE bytes
|
||||
* into the message via handleStreamResponse.
|
||||
*
|
||||
* discoverActiveStream(convId) -> probe + attach in one call. Used by callers that do not need
|
||||
* to overlap the probe with other async work.
|
||||
*
|
||||
* The mount of the chat page in +page.svelte calls probeServerStream in parallel with
|
||||
* loadConversation, then attachServerStream once both have settled. This gives the earliest
|
||||
* possible time to spinner and avoids racing against an empty activeMessages array.
|
||||
*/
|
||||
async probeServerStream(convId: string): Promise<ApiStreamSession | null> {
|
||||
if (!convId) return null;
|
||||
let listResp: Response;
|
||||
try {
|
||||
// POST the one conv id we are probing
|
||||
listResp = await fetch(`./v1/streams/lookup`, {
|
||||
method: 'POST',
|
||||
headers: { ...getAuthHeaders(), 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ conversation_ids: [convId] })
|
||||
});
|
||||
} catch (e) {
|
||||
console.warn('probeServerStream fetch failed:', e);
|
||||
return null;
|
||||
}
|
||||
if (!listResp.ok) {
|
||||
console.warn(`probeServerStream got HTTP ${listResp.status} for conv ${convId}`);
|
||||
return null;
|
||||
}
|
||||
let sessions: ApiStreamSession[];
|
||||
try {
|
||||
sessions = (await listResp.json()) as ApiStreamSession[];
|
||||
} catch (e) {
|
||||
console.warn('probeServerStream JSON parse failed:', e);
|
||||
return null;
|
||||
}
|
||||
return ChatService.selectActiveStream(sessions);
|
||||
}
|
||||
|
||||
async attachServerStream(convId: string, streamId?: string): Promise<void> {
|
||||
if (!convId) return;
|
||||
if (this.chatStreamingStates.has(convId)) return;
|
||||
|
||||
// flip the spinner immediately, the user sees activity as soon as the conv becomes active.
|
||||
// the global isStreamingActive flag is derived from attachingConvs.size, so adding here
|
||||
// turns it on, and removing in unlock only turns it off when this is the last attach
|
||||
this.setChatLoading(convId, true);
|
||||
this.attachingConvs.add(convId);
|
||||
this.setStreamingActive(true);
|
||||
// only set the active processing conv if we are looking at it, otherwise a background
|
||||
// attach would steal the indicator from the conv the user is currently viewing
|
||||
if (convId === conversationsStore.activeConversation?.id) {
|
||||
this.setActiveProcessingConversation(convId);
|
||||
}
|
||||
|
||||
const unlock = () => {
|
||||
this.attachingConvs.delete(convId);
|
||||
// flip the global flag off only when no other conv is still attaching
|
||||
if (this.attachingConvs.size === 0) {
|
||||
this.setStreamingActive(false);
|
||||
}
|
||||
this.setChatLoading(convId, false);
|
||||
this.clearChatStreaming(convId);
|
||||
};
|
||||
|
||||
// fetch the replay stream from byte 0, rebuild the assistant message from scratch.
|
||||
// resolve the server side identity, fall back to streamIdentity when the caller does not
|
||||
// pass a streamId. probeServerStream returns the full id (with ::model suffix when present)
|
||||
const id = streamId || streamIdentity(convId, selectedModelName());
|
||||
let response: Response;
|
||||
try {
|
||||
response = await fetch(`./v1/stream/${encodeURIComponent(id)}?from=0`, {
|
||||
headers: getAuthHeaders()
|
||||
});
|
||||
} catch (e) {
|
||||
console.error('attachServerStream replay fetch failed:', e);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
if (!response.ok) {
|
||||
console.warn(`attachServerStream replay got HTTP ${response.status} for conv ${convId}`);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
|
||||
// load the target conversation messages by id, not via the active store. when multiple
|
||||
// attaches run in parallel the active store may reflect another conv and writing through
|
||||
// its index mixes content across convs (CoT flicker, message bleed). by going through the
|
||||
// DB we stay isolated, and only mirror into the active store when the attached conv is
|
||||
// the one currently displayed
|
||||
let messages: DatabaseMessage[];
|
||||
try {
|
||||
messages = await DatabaseService.getConversationMessages(convId);
|
||||
} catch (e) {
|
||||
console.error('attachServerStream load messages failed:', e);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
|
||||
// locate the slot to splice into, create a placeholder assistant message if there is none.
|
||||
// we use the conv-scoped findLastAssistantIdx helpers, they only depend on the array
|
||||
let targetIdx = this.findLastAssistantIdx(messages);
|
||||
if (targetIdx === -1) {
|
||||
const lastUserIdx = this.findLastUserIdx(messages);
|
||||
if (lastUserIdx === -1) {
|
||||
console.warn(
|
||||
`attachServerStream: conv ${convId} has no user or assistant message, cannot splice`
|
||||
);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const placeholder = await DatabaseService.createMessageBranch(
|
||||
{
|
||||
convId,
|
||||
role: MessageRole.ASSISTANT,
|
||||
content: '',
|
||||
type: MessageType.TEXT,
|
||||
timestamp: Date.now(),
|
||||
parent: messages[lastUserIdx].id,
|
||||
children: [],
|
||||
toolCalls: ''
|
||||
} as Omit<DatabaseMessage, 'id'>,
|
||||
messages[lastUserIdx].id
|
||||
);
|
||||
messages = [...messages, placeholder];
|
||||
targetIdx = messages.length - 1;
|
||||
// only push into the active store when this conv is the one displayed right now
|
||||
if (convId === conversationsStore.activeConversation?.id) {
|
||||
conversationsStore.addMessageToActive(placeholder);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('attachServerStream placeholder creation failed:', e);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (targetIdx === -1) {
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
const targetMessage = messages[targetIdx];
|
||||
const targetMessageId = targetMessage.id;
|
||||
// when the assistant slot already has content, the running session is a continue or
|
||||
// another append flow and its buffer holds only the appended deltas. preserve the prefix
|
||||
// and let the replay add to it. when the slot is empty the session buffer holds the whole
|
||||
// message so we wipe and rebuild from byte 0
|
||||
const existingContent = targetMessage.content ?? '';
|
||||
const existingReasoning = targetMessage.reasoningContent ?? '';
|
||||
const isAppendMode = existingContent.length > 0;
|
||||
|
||||
// helper: write to the active store only when the attached conv is currently displayed.
|
||||
// the lookup by message id is robust to reordering of activeMessages, two parallel attaches
|
||||
// can no longer step on each other's indices
|
||||
const writeActive = (updates: Partial<DatabaseMessage>) => {
|
||||
if (convId !== conversationsStore.activeConversation?.id) {
|
||||
return;
|
||||
}
|
||||
const liveIdx = conversationsStore.findMessageIndex(targetMessageId);
|
||||
if (liveIdx === -1) return;
|
||||
conversationsStore.updateMessageAtIndex(liveIdx, updates);
|
||||
};
|
||||
|
||||
if (!isAppendMode) {
|
||||
writeActive({ content: '', reasoningContent: undefined });
|
||||
}
|
||||
|
||||
// extract the model suffix, the resume calls in handleStreamResponse must reuse the model
|
||||
// the session was tagged with, not the live dropdown
|
||||
const sepIdx = id.indexOf('::');
|
||||
const attachedModel: string | null = sepIdx === -1 ? null : id.slice(sepIdx + 2);
|
||||
this.setChatStreaming(convId, existingContent, targetMessageId, attachedModel);
|
||||
const abortController = this.getOrCreateAbortController(convId);
|
||||
|
||||
let streamedContent = '';
|
||||
let streamedReasoningContent = '';
|
||||
|
||||
const cleanup = () => {
|
||||
unlock();
|
||||
this.setProcessingState(convId, null);
|
||||
};
|
||||
|
||||
try {
|
||||
await ChatService.handleStreamResponse(
|
||||
response,
|
||||
(chunk: string) => {
|
||||
streamedContent += chunk;
|
||||
const displayed = isAppendMode ? existingContent + streamedContent : streamedContent;
|
||||
writeActive({ content: displayed });
|
||||
this.setChatStreaming(convId, displayed, targetMessageId);
|
||||
},
|
||||
async (
|
||||
finalContent?: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings,
|
||||
toolCalls?: string
|
||||
) => {
|
||||
const streamed = streamedContent || finalContent || '';
|
||||
const streamedR = streamedReasoningContent || reasoningContent || '';
|
||||
const content = isAppendMode ? existingContent + streamed : streamed;
|
||||
const reasoning = isAppendMode ? existingReasoning + streamedR : streamedR;
|
||||
// the DB write is the source of truth, mirror to the active store only when
|
||||
// the conv is currently displayed
|
||||
await DatabaseService.updateMessage(targetMessageId, {
|
||||
content,
|
||||
reasoningContent: reasoning || undefined,
|
||||
toolCalls: toolCalls || '',
|
||||
timings
|
||||
});
|
||||
writeActive({
|
||||
content,
|
||||
reasoningContent: reasoning || undefined,
|
||||
timings
|
||||
});
|
||||
cleanup();
|
||||
},
|
||||
(err: Error) => {
|
||||
console.error('attachServerStream pipe error:', err);
|
||||
cleanup();
|
||||
},
|
||||
(chunk: string) => {
|
||||
streamedReasoningContent += chunk;
|
||||
const displayed = isAppendMode
|
||||
? existingReasoning + streamedReasoningContent
|
||||
: streamedReasoningContent;
|
||||
writeActive({ reasoningContent: displayed });
|
||||
},
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
convId,
|
||||
abortController.signal,
|
||||
(connState: StreamConnectionState) => {
|
||||
if (convId === conversationsStore.activeConversation?.id) {
|
||||
this.streamConnectionState = connState;
|
||||
}
|
||||
},
|
||||
attachedModel
|
||||
);
|
||||
} catch (e) {
|
||||
console.error('attachServerStream pipe crashed:', e);
|
||||
cleanup();
|
||||
}
|
||||
}
|
||||
|
||||
async discoverActiveStream(convId: string): Promise<void> {
|
||||
if (!convId) return;
|
||||
if (this.chatStreamingStates.has(convId)) return;
|
||||
if (this.chatLoadingStates.get(convId)) return;
|
||||
// concurrency guard: another discover may already be running for this conv (typical race
|
||||
// between mount and visibilitychange on tab switch). a second concurrent fetch on the same
|
||||
// /v1/stream/<id> would duplicate every byte into the DB message, this guard bounces it
|
||||
if (this.discoveringConvs.has(convId)) return;
|
||||
this.discoveringConvs.add(convId);
|
||||
|
||||
try {
|
||||
// the model is frozen at POST time, rebuild the exact conv::model identity from the
|
||||
// persisted state so the lookup key matches what the server stored. null means a single
|
||||
// model conv with no ::suffix, only guess from the dropdown with no persisted state
|
||||
const localState = ChatService.getStreamState(convId);
|
||||
const streamId = ChatService.resumeStreamIdentity(convId, localState, selectedModelName());
|
||||
|
||||
// primary path: ask the server which sessions exist for this identity
|
||||
const serverTarget = await this.probeServerStream(streamId);
|
||||
if (serverTarget) {
|
||||
// pass the full server side identity (may carry a ::model suffix) so the GET routes
|
||||
// straight to the owning session, no probe or fan out
|
||||
await this.attachServerStream(convId, serverTarget.conversation_id);
|
||||
return;
|
||||
}
|
||||
|
||||
// fallback: local state remembers an interrupted byte offset for this conv, the server may
|
||||
// still have a live session matching that identity (we just lost the bytes mid stream). retry
|
||||
// with the frozen identity, the server probe inside attachServerStream tells us if it exists
|
||||
if (!localState) {
|
||||
return;
|
||||
}
|
||||
await this.attachServerStream(convId, streamId);
|
||||
// if attachServerStream failed (session gone, TTL expired), clear the local state to avoid retrying forever
|
||||
if (!this.chatStreamingStates.has(convId) && !this.chatLoadingStates.get(convId)) {
|
||||
ChatService.clearStreamState(convId);
|
||||
}
|
||||
} finally {
|
||||
this.discoveringConvs.delete(convId);
|
||||
}
|
||||
}
|
||||
|
||||
private findLastAssistantIdx(messages: DatabaseMessage[]): number {
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].role === MessageRole.ASSISTANT) return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
private findLastUserIdx(messages: DatabaseMessage[]): number {
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].role === MessageRole.USER) return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
clearUIState(): void {
|
||||
this.isLoading = false;
|
||||
@@ -265,13 +612,83 @@ class ChatStore {
|
||||
}
|
||||
|
||||
getAllLoadingChats(): string[] {
|
||||
return Array.from(this.chatLoadingStates.keys());
|
||||
// union of local (this browser is piping) and remote (backend reports a running session
|
||||
// for this conv but no local pipe yet) sources. the sidebar shows one spinner per entry
|
||||
const out = new SvelteSet<string>(this.chatLoadingStates.keys());
|
||||
for (const id of this.remoteRunningConvs) {
|
||||
out.add(id);
|
||||
}
|
||||
return Array.from(out);
|
||||
}
|
||||
|
||||
getAllStreamingChats(): string[] {
|
||||
return Array.from(this.chatStreamingStates.keys());
|
||||
}
|
||||
|
||||
/**
|
||||
* Resync the remote running convs set from the backend. Called by the layout at mount and on
|
||||
* visibilitychange, no polling. A snapshot semantic: the set is replaced wholesale, stale entries
|
||||
* for sessions that finalized while the browser was elsewhere are dropped naturally.
|
||||
*/
|
||||
async syncRemoteRunningStreams(): Promise<void> {
|
||||
// the conversations store loads from IndexedDB asynchronously, the +layout onMount caller
|
||||
// fires before that finishes. read ids straight from the DB so the result does not depend
|
||||
// on the store init race, and the sidebar spinners light up at first paint for every conv
|
||||
// the user owns even if it has not been hydrated into the store yet
|
||||
let ids: string[];
|
||||
try {
|
||||
const all = await DatabaseService.getAllConversations();
|
||||
ids = all.map((c) => c.id).filter((id) => !!id);
|
||||
} catch (e) {
|
||||
console.warn('syncRemoteRunningStreams DB read failed:', e);
|
||||
return;
|
||||
}
|
||||
// only ask about conv ids the user already owns
|
||||
if (ids.length === 0) {
|
||||
for (const id of Array.from(this.remoteRunningConvs)) {
|
||||
this.remoteRunningConvs.delete(id);
|
||||
}
|
||||
return;
|
||||
}
|
||||
// rebuild the frozen conv::model identity per conv so a session started with a model still
|
||||
// matches. the server response is mapped back to the bare id below for the sidebar set
|
||||
const lookupIds = ids.map((id) =>
|
||||
ChatService.resumeStreamIdentity(id, ChatService.getStreamState(id), null)
|
||||
);
|
||||
let sessions: ApiStreamSession[];
|
||||
try {
|
||||
const resp = await fetch('./v1/streams/lookup', {
|
||||
method: 'POST',
|
||||
headers: { ...getAuthHeaders(), 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ conversation_ids: lookupIds })
|
||||
});
|
||||
if (!resp.ok) return;
|
||||
const body = (await resp.json()) as unknown;
|
||||
if (!Array.isArray(body)) return;
|
||||
sessions = body as ApiStreamSession[];
|
||||
} catch (e) {
|
||||
console.warn('syncRemoteRunningStreams fetch failed:', e);
|
||||
return;
|
||||
}
|
||||
const running = new SvelteSet<string>();
|
||||
for (const s of sessions) {
|
||||
if (s && !s.is_done && typeof s.conversation_id === 'string' && s.conversation_id) {
|
||||
// strip the optional ::model suffix, the sidebar set is keyed by the bare conv id
|
||||
const sepIdx = s.conversation_id.indexOf('::');
|
||||
const bareId = sepIdx === -1 ? s.conversation_id : s.conversation_id.slice(0, sepIdx);
|
||||
running.add(bareId);
|
||||
}
|
||||
}
|
||||
for (const id of Array.from(this.remoteRunningConvs)) {
|
||||
if (!running.has(id)) {
|
||||
this.remoteRunningConvs.delete(id);
|
||||
}
|
||||
}
|
||||
for (const id of running) {
|
||||
this.remoteRunningConvs.add(id);
|
||||
}
|
||||
}
|
||||
|
||||
getChatStreamingPublic(convId: string): { response: string; messageId: string } | undefined {
|
||||
return this.getChatStreaming(convId);
|
||||
}
|
||||
@@ -922,6 +1339,11 @@ class ChatStore {
|
||||
onModel: streamCallbacks.onModel,
|
||||
onCompletionId: streamCallbacks.onCompletionId,
|
||||
onTimings: streamCallbacks.onTimings,
|
||||
onConnectionState: (state: StreamConnectionState) => {
|
||||
if (convId === conversationsStore.activeConversation?.id) {
|
||||
this.streamConnectionState = state;
|
||||
}
|
||||
},
|
||||
onComplete: async (
|
||||
finalContent?: string,
|
||||
reasoningContent?: string,
|
||||
@@ -979,6 +1401,12 @@ class ChatStore {
|
||||
async stopGenerationForChat(convId: string): Promise<void> {
|
||||
await this.savePartialResponseIfNeeded(convId);
|
||||
this.setStreamingActive(false);
|
||||
// tell the server to stop the generation, not just drop the HTTP socket. without this the
|
||||
// detached drain keeps producing tokens until eos or max_tokens. use the frozen identity
|
||||
// captured when the session started, not the live dropdown
|
||||
const streamStateForStop = this.chatStreamingStates.get(convId);
|
||||
const modelForStop = streamStateForStop?.model ?? selectedModelName();
|
||||
void ChatService.cancelServerStream(convId, modelForStop);
|
||||
this.abortRequest(convId);
|
||||
this.setChatLoading(convId, false);
|
||||
this.clearChatStreaming(convId);
|
||||
@@ -1393,7 +1821,11 @@ class ChatStore {
|
||||
|
||||
const updateStreamingContent = (fullContent: string) => {
|
||||
this.setChatStreaming(msg.convId, fullContent, msg.id);
|
||||
conversationsStore.updateMessageAtIndex(idx, { content: fullContent });
|
||||
// resolve the row by id on every write, switching to another conv mid continue makes
|
||||
// this a no op instead of writing positionally into the now displayed conversation
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
content: fullContent
|
||||
});
|
||||
};
|
||||
|
||||
const abortController = this.getOrCreateAbortController(msg.convId);
|
||||
@@ -1403,6 +1835,11 @@ class ChatStore {
|
||||
{
|
||||
...this.getApiOptions(),
|
||||
continueFinalMessage: true,
|
||||
onConnectionState: (state: StreamConnectionState) => {
|
||||
if (msg.convId === conversationsStore.activeConversation?.id) {
|
||||
this.streamConnectionState = state;
|
||||
}
|
||||
},
|
||||
onChunk: (chunk: string) => {
|
||||
appendedContent += chunk;
|
||||
hasReceivedContent = true;
|
||||
@@ -1414,7 +1851,7 @@ class ChatStore {
|
||||
hasReceivedContent = true;
|
||||
// mark streaming state so a stop mid-thinking can persist the partial reasoning
|
||||
this.setChatStreaming(msg.convId, originalContent + appendedContent, msg.id);
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
reasoningContent: originalReasoning + appendedReasoning
|
||||
});
|
||||
this.setChatReasoning(msg.convId, true);
|
||||
@@ -1455,7 +1892,7 @@ class ChatStore {
|
||||
timings
|
||||
});
|
||||
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
content: fullContent,
|
||||
reasoningContent: fullReasoning,
|
||||
timestamp: Date.now(),
|
||||
@@ -1477,11 +1914,14 @@ class ChatStore {
|
||||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
content: originalContent + appendedContent,
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
conversationsStore.updateMessageAtIndex(
|
||||
conversationsStore.findMessageIndex(msg.id),
|
||||
{
|
||||
content: originalContent + appendedContent,
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
this.setChatLoading(msg.convId, false);
|
||||
@@ -1498,7 +1938,7 @@ class ChatStore {
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
content: originalContent + appendedContent,
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
|
||||
@@ -392,11 +392,14 @@ class ToolsStore {
|
||||
} catch (err) {
|
||||
const errorMessage = err instanceof Error ? err.message : String(err);
|
||||
this._error = errorMessage;
|
||||
// 404 from /tools means the server was started without --tools
|
||||
if (errorMessage.includes('404') || errorMessage.toLowerCase().includes('not found')) {
|
||||
// 403 from /tools means the server was started without --tools
|
||||
// TODO: check status code instead of relying on message
|
||||
if (errorMessage.includes('this feature is disabled')) {
|
||||
this._toolsEndpointUnreachable = true;
|
||||
console.info('[ToolsStore] Built-in tools are disabled on the server');
|
||||
} else {
|
||||
console.error('[ToolsStore] Failed to fetch built-in tools:', err);
|
||||
}
|
||||
console.error('[ToolsStore] Failed to fetch built-in tools:', err);
|
||||
} finally {
|
||||
this._loading = false;
|
||||
}
|
||||
|
||||
Vendored
+15
@@ -512,3 +512,18 @@ export interface ApiRouterModelsUnloadResponse {
|
||||
success: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Entry returned by POST /v1/streams/lookup. The client passes the conv ids it owns in the body
|
||||
* and the server returns one entry per matching live or recently completed background streaming
|
||||
* session, keyed by conversation_id. The WebUI uses this at mount and on visibilitychange to
|
||||
* populate sidebar spinners and to reattach to an ongoing inference for the active conversation.
|
||||
* The server never lists ids the client did not ask about, so foreign random UUIDs stay private.
|
||||
*/
|
||||
export interface ApiStreamSession {
|
||||
conversation_id: string;
|
||||
is_done: boolean;
|
||||
total_bytes: number;
|
||||
started_at: number;
|
||||
completed_at: number;
|
||||
}
|
||||
|
||||
@@ -34,7 +34,8 @@ export type {
|
||||
ApiRouterModelsListResponse,
|
||||
ApiRouterModelsUnloadRequest,
|
||||
ApiRouterModelsUnloadResponse,
|
||||
AudioInputFormat
|
||||
AudioInputFormat,
|
||||
ApiStreamSession
|
||||
} from './api';
|
||||
|
||||
// Chat types
|
||||
|
||||
Vendored
+4
-2
@@ -4,9 +4,10 @@ import type { OpenAIToolDefinition } from './mcp';
|
||||
import type { DatabaseMessageExtra } from './database';
|
||||
import type {
|
||||
ParameterSource,
|
||||
ReasoningEffort,
|
||||
SyncableParameterType,
|
||||
SettingsFieldType
|
||||
SettingsFieldType,
|
||||
StreamConnectionState,
|
||||
ReasoningEffort
|
||||
} from '$lib/enums';
|
||||
import type { Icon } from '@lucide/svelte';
|
||||
import type { Component } from 'svelte';
|
||||
@@ -119,6 +120,7 @@ export interface SettingsChatServiceOptions {
|
||||
toolCalls?: string
|
||||
) => void;
|
||||
onError?: (error: Error) => void;
|
||||
onConnectionState?: (state: StreamConnectionState) => void;
|
||||
}
|
||||
|
||||
export type SettingsConfigType = typeof SETTING_CONFIG_DEFAULT & {
|
||||
|
||||
@@ -6,6 +6,17 @@
|
||||
* when needed (e.g., user stops generation, navigates away, etc.).
|
||||
*/
|
||||
|
||||
// the standard DOMException name for a cancelled operation
|
||||
const ABORT_ERROR_NAME = 'AbortError';
|
||||
|
||||
// browser specific TypeError messages emitted when a fetch reader is cut by page unload,
|
||||
// navigation, or a transient network drop. functionally aborts, not actionable errors
|
||||
const ABORT_LIKE_MESSAGE_PATTERNS = [
|
||||
/input stream/i, // Firefox: stream cut at unload
|
||||
/network connection was lost/i, // Safari: transient network drop
|
||||
/load failed/i // Safari: page navigation during fetch
|
||||
];
|
||||
|
||||
/**
|
||||
* Throws an AbortError if the signal is aborted.
|
||||
* Use this at the start of async operations to fail fast.
|
||||
@@ -23,7 +34,7 @@
|
||||
*/
|
||||
export function throwIfAborted(signal?: AbortSignal): void {
|
||||
if (signal?.aborted) {
|
||||
throw new DOMException('Operation was aborted', 'AbortError');
|
||||
throw new DOMException('Operation was aborted', ABORT_ERROR_NAME);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,11 +59,18 @@ export function throwIfAborted(signal?: AbortSignal): void {
|
||||
* ```
|
||||
*/
|
||||
export function isAbortError(error: unknown): boolean {
|
||||
if (error instanceof DOMException && error.name === 'AbortError') {
|
||||
if (error instanceof DOMException && error.name === ABORT_ERROR_NAME) {
|
||||
return true;
|
||||
}
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
return true;
|
||||
if (error instanceof Error) {
|
||||
if (error.name === ABORT_ERROR_NAME) {
|
||||
return true;
|
||||
}
|
||||
// these patterns are functionally aborts, keep them out of the red console
|
||||
if (error instanceof TypeError) {
|
||||
const msg = error.message ?? '';
|
||||
if (ABORT_LIKE_MESSAGE_PATTERNS.some((re) => re.test(msg))) return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -133,7 +151,7 @@ export async function withAbortSignal<T>(promise: Promise<T>, signal?: AbortSign
|
||||
|
||||
return new Promise<T>((resolve, reject) => {
|
||||
const abortHandler = () => {
|
||||
reject(new DOMException('Operation was aborted', 'AbortError'));
|
||||
reject(new DOMException('Operation was aborted', ABORT_ERROR_NAME));
|
||||
};
|
||||
|
||||
signal.addEventListener('abort', abortHandler, { once: true });
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
/**
|
||||
* Build the conversation identity used by the server side replay buffer.
|
||||
*
|
||||
* The server identifies a stream session by a conversation id sent in the
|
||||
* X-Conversation-Id header. When the user has explicitly picked a model the
|
||||
* client appends ::modelName, so a per model session stays distinct and the
|
||||
* router resolves the owning child through its conv_id -> model map.
|
||||
*/
|
||||
export function streamIdentity(conversationId: string, model?: string | null): string {
|
||||
if (!conversationId) return '';
|
||||
if (!model) return conversationId;
|
||||
return `${conversationId}::${model}`;
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import { DialogModelNotAvailable } from '$lib/components/app';
|
||||
import { APP_NAME, ROUTES } from '$lib/constants';
|
||||
import { chatStore, isLoading } from '$lib/stores/chat.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { modelsStore, modelOptions } from '$lib/stores/models.svelte';
|
||||
|
||||
@@ -83,7 +83,7 @@
|
||||
|
||||
// Skip loading if this conversation is already active (e.g., just created)
|
||||
if (activeConversation()?.id === chatId) {
|
||||
// Still handle URL params even if conversation is active
|
||||
void chatStore.discoverActiveStream(chatId);
|
||||
if ((qParam !== null || modelParam !== null) && !urlParamsProcessed) {
|
||||
handleUrlParams();
|
||||
}
|
||||
@@ -92,35 +92,33 @@
|
||||
|
||||
(async () => {
|
||||
const success = await conversationsStore.loadConversation(chatId);
|
||||
if (success) {
|
||||
chatStore.syncLoadingStateForChat(chatId);
|
||||
|
||||
// Handle URL params after conversation is loaded
|
||||
if ((qParam !== null || modelParam !== null) && !urlParamsProcessed) {
|
||||
await handleUrlParams();
|
||||
}
|
||||
} else {
|
||||
if (!success) {
|
||||
await goto(ROUTES.START);
|
||||
return;
|
||||
}
|
||||
chatStore.syncLoadingStateForChat(chatId);
|
||||
// server probe (with localStorage fallback) and attach
|
||||
await chatStore.discoverActiveStream(chatId);
|
||||
|
||||
if ((qParam !== null || modelParam !== null) && !urlParamsProcessed) {
|
||||
await handleUrlParams();
|
||||
}
|
||||
})();
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (typeof window !== 'undefined') {
|
||||
const handleBeforeUnload = () => {
|
||||
if (isLoading()) {
|
||||
console.log('Page unload detected while streaming - aborting stream');
|
||||
chatStore.stopGeneration();
|
||||
}
|
||||
};
|
||||
if (typeof window === 'undefined' || typeof document === 'undefined') return;
|
||||
|
||||
window.addEventListener('beforeunload', handleBeforeUnload);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('beforeunload', handleBeforeUnload);
|
||||
};
|
||||
}
|
||||
// when the tab comes back to the foreground, re-run discovery to catch any race
|
||||
// where the initial mount probe missed an active session
|
||||
const onVisibility = () => {
|
||||
if (document.visibilityState !== 'visible') return;
|
||||
if (!chatId) return;
|
||||
void chatStore.discoverActiveStream(chatId);
|
||||
};
|
||||
document.addEventListener('visibilitychange', onVisibility);
|
||||
return () => document.removeEventListener('visibilitychange', onVisibility);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
import { PwaMetaTags, PwaRefreshAlert } from '$lib/components/pwa';
|
||||
import { pwaAssetsHead } from 'virtual:pwa-assets/head';
|
||||
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { isRouterMode, serverStore } from '$lib/stores/server.svelte';
|
||||
@@ -33,8 +34,6 @@
|
||||
import { SETTINGS_KEYS } from '$lib/constants';
|
||||
|
||||
let { children } = $props();
|
||||
let alwaysShowSidebarOnDesktop = $derived(config().alwaysShowSidebarOnDesktop);
|
||||
let isDesktop = $derived(!isMobile.current);
|
||||
let innerHeight = $state<number | undefined>();
|
||||
let innerWidth = $state(browser ? window.innerWidth : 0);
|
||||
|
||||
@@ -156,20 +155,24 @@
|
||||
|
||||
onMount(() => {
|
||||
updateFavicon();
|
||||
// snapshot of every backend running stream on first load, populates the sidebar spinners
|
||||
// so the user sees each conv that has a live inference, even ones not opened yet
|
||||
void chatStore.syncRemoteRunningStreams();
|
||||
});
|
||||
|
||||
// refresh that snapshot when the tab returns to the foreground, a stream may have advanced
|
||||
// or ended while it was hidden. snapshot only, no polling
|
||||
function handleVisibilityChange() {
|
||||
if (document.visibilityState !== 'visible') return;
|
||||
void chatStore.syncRemoteRunningStreams();
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
void theme.isSystemDark;
|
||||
|
||||
updateFavicon();
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (alwaysShowSidebarOnDesktop && isDesktop) {
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
// Initialize server properties on app load (run once)
|
||||
$effect(() => {
|
||||
// Only fetch if we don't already have props
|
||||
@@ -288,6 +291,7 @@
|
||||
</svelte:head>
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} bind:innerHeight bind:innerWidth />
|
||||
<svelte:document onvisibilitychange={handleVisibilityChange} />
|
||||
|
||||
<Tooltip.Provider delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<div class="flex flex-col md:flex-row">
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { isAbortError } from '$lib/utils/abort';
|
||||
|
||||
describe('isAbortError', () => {
|
||||
it('returns false for null, undefined and non-error values', () => {
|
||||
expect(isAbortError(null)).toBe(false);
|
||||
expect(isAbortError(undefined)).toBe(false);
|
||||
expect(isAbortError('string error')).toBe(false);
|
||||
expect(isAbortError({ name: 'AbortError' })).toBe(false);
|
||||
expect(isAbortError(42)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns true for DOMException with AbortError name', () => {
|
||||
const err = new DOMException('Operation was aborted', 'AbortError');
|
||||
expect(isAbortError(err)).toBe(true);
|
||||
});
|
||||
|
||||
it('returns true for plain Error with AbortError name', () => {
|
||||
const err = new Error('aborted');
|
||||
err.name = 'AbortError';
|
||||
expect(isAbortError(err)).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false for unrelated Error instances', () => {
|
||||
expect(isAbortError(new Error('something failed'))).toBe(false);
|
||||
expect(isAbortError(new TypeError('not related'))).toBe(false);
|
||||
expect(isAbortError(new RangeError('out of range'))).toBe(false);
|
||||
});
|
||||
|
||||
it('recognizes Firefox TypeError "Error in input stream" emitted at page unload', () => {
|
||||
expect(isAbortError(new TypeError('Error in input stream'))).toBe(true);
|
||||
expect(isAbortError(new TypeError('TypeError: Error in input stream'))).toBe(true);
|
||||
});
|
||||
|
||||
it('recognizes Safari "The network connection was lost" during transient drop', () => {
|
||||
expect(isAbortError(new TypeError('The network connection was lost.'))).toBe(true);
|
||||
});
|
||||
|
||||
it('recognizes Safari "Load failed" during page navigation', () => {
|
||||
expect(isAbortError(new TypeError('Load failed'))).toBe(true);
|
||||
});
|
||||
|
||||
it('does NOT recognize generic TypeError messages as aborts', () => {
|
||||
// matching too broadly would hide real bugs, the predicate must stay conservative
|
||||
expect(isAbortError(new TypeError('Failed to fetch'))).toBe(false);
|
||||
expect(isAbortError(new TypeError('Cannot read property of undefined'))).toBe(false);
|
||||
expect(isAbortError(new TypeError('NetworkError when attempting to fetch resource'))).toBe(
|
||||
false
|
||||
);
|
||||
});
|
||||
|
||||
it('is case insensitive on the matched substrings', () => {
|
||||
expect(isAbortError(new TypeError('error in INPUT STREAM'))).toBe(true);
|
||||
expect(isAbortError(new TypeError('the network connection WAS LOST'))).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,74 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { ChatService } from '$lib/services/chat.service';
|
||||
import type { ApiStreamSession } from '$lib/types';
|
||||
|
||||
function makeSession(overrides: Partial<ApiStreamSession>): ApiStreamSession {
|
||||
return {
|
||||
conversation_id: 'conv',
|
||||
is_done: true,
|
||||
total_bytes: 0,
|
||||
started_at: 0,
|
||||
completed_at: 0,
|
||||
...overrides
|
||||
};
|
||||
}
|
||||
|
||||
describe('selectActiveStream', () => {
|
||||
it('returns null on empty input', () => {
|
||||
expect(ChatService.selectActiveStream([])).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null on null or undefined input', () => {
|
||||
expect(ChatService.selectActiveStream(null)).toBeNull();
|
||||
expect(ChatService.selectActiveStream(undefined)).toBeNull();
|
||||
});
|
||||
|
||||
it('returns the single session when it is running', () => {
|
||||
const s = makeSession({ conversation_id: 'only', is_done: false, started_at: 42 });
|
||||
expect(ChatService.selectActiveStream([s])).toBe(s);
|
||||
});
|
||||
|
||||
it('returns null when the single session is finalized', () => {
|
||||
const s = makeSession({ conversation_id: 'only', is_done: true, started_at: 42 });
|
||||
expect(ChatService.selectActiveStream([s])).toBeNull();
|
||||
});
|
||||
|
||||
it('prefers a still running session over a finalized one regardless of started_at', () => {
|
||||
const finalized = makeSession({ conversation_id: 'old', is_done: true, started_at: 1000 });
|
||||
const running = makeSession({ conversation_id: 'new', is_done: false, started_at: 10 });
|
||||
expect(ChatService.selectActiveStream([finalized, running])?.conversation_id).toBe('new');
|
||||
expect(ChatService.selectActiveStream([running, finalized])?.conversation_id).toBe('new');
|
||||
});
|
||||
|
||||
it('among running sessions, picks the most recently started one', () => {
|
||||
const a = makeSession({ conversation_id: 'a', is_done: false, started_at: 100 });
|
||||
const b = makeSession({ conversation_id: 'b', is_done: false, started_at: 200 });
|
||||
const c = makeSession({ conversation_id: 'c', is_done: false, started_at: 150 });
|
||||
expect(ChatService.selectActiveStream([a, b, c])?.conversation_id).toBe('b');
|
||||
expect(ChatService.selectActiveStream([c, a, b])?.conversation_id).toBe('b');
|
||||
});
|
||||
|
||||
it('returns null when all sessions are finalized, the DB already holds the content', () => {
|
||||
const a = makeSession({ conversation_id: 'a', is_done: true, started_at: 10 });
|
||||
const b = makeSession({ conversation_id: 'b', is_done: true, started_at: 30 });
|
||||
const c = makeSession({ conversation_id: 'c', is_done: true, started_at: 20 });
|
||||
expect(ChatService.selectActiveStream([a, b, c])).toBeNull();
|
||||
});
|
||||
|
||||
it('keeps the first match on ties when both are running with identical started_at', () => {
|
||||
// reduce visits left to right, the initial accumulator stays unless a strictly greater value appears
|
||||
const a = makeSession({ conversation_id: 'first', is_done: false, started_at: 50 });
|
||||
const b = makeSession({ conversation_id: 'second', is_done: false, started_at: 50 });
|
||||
expect(ChatService.selectActiveStream([a, b])?.conversation_id).toBe('first');
|
||||
});
|
||||
|
||||
it('handles a typical realistic mix: two finalized old, one freshly running, one freshly finalized', () => {
|
||||
const old1 = makeSession({ conversation_id: 'old1', is_done: true, started_at: 100 });
|
||||
const old2 = makeSession({ conversation_id: 'old2', is_done: true, started_at: 200 });
|
||||
const freshFin = makeSession({ conversation_id: 'freshFin', is_done: true, started_at: 500 });
|
||||
const running = makeSession({ conversation_id: 'running', is_done: false, started_at: 400 });
|
||||
expect(ChatService.selectActiveStream([old1, old2, freshFin, running])?.conversation_id).toBe(
|
||||
'running'
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,128 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it } from 'vitest';
|
||||
|
||||
// node env unit project has no DOM, install a minimal localStorage backed by a Map
|
||||
beforeAll(() => {
|
||||
const store = new Map<string, string>();
|
||||
const polyfill: Storage = {
|
||||
get length() {
|
||||
return store.size;
|
||||
},
|
||||
clear: () => store.clear(),
|
||||
getItem: (k) => (store.has(k) ? store.get(k)! : null),
|
||||
key: (i) => Array.from(store.keys())[i] ?? null,
|
||||
removeItem: (k) => {
|
||||
store.delete(k);
|
||||
},
|
||||
setItem: (k, v) => {
|
||||
store.set(k, String(v));
|
||||
}
|
||||
};
|
||||
(globalThis as unknown as { localStorage: Storage }).localStorage = polyfill;
|
||||
});
|
||||
|
||||
import { ChatService } from '$lib/services/chat.service';
|
||||
import { STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX } from '$lib/constants';
|
||||
|
||||
describe('ChatService stream resume', () => {
|
||||
beforeEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
afterEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
it('returns null when no state exists for the conversation', () => {
|
||||
expect(ChatService.getStreamState('conv-a')).toBeNull();
|
||||
});
|
||||
|
||||
it('saves and reads back the byte count', () => {
|
||||
ChatService.saveStreamState('conv-a', 4242);
|
||||
const got = ChatService.getStreamState('conv-a');
|
||||
expect(got).not.toBeNull();
|
||||
expect(got!.bytesReceived).toBe(4242);
|
||||
expect(typeof got!.updatedAt).toBe('number');
|
||||
});
|
||||
|
||||
it('overwrites the previous byte count on a new save for the same conversation', () => {
|
||||
ChatService.saveStreamState('conv-a', 100);
|
||||
ChatService.saveStreamState('conv-a', 200);
|
||||
const got = ChatService.getStreamState('conv-a');
|
||||
expect(got!.bytesReceived).toBe(200);
|
||||
});
|
||||
|
||||
it('keeps states for distinct conversations isolated', () => {
|
||||
ChatService.saveStreamState('conv-a', 10);
|
||||
ChatService.saveStreamState('conv-b', 20);
|
||||
expect(ChatService.getStreamState('conv-a')!.bytesReceived).toBe(10);
|
||||
expect(ChatService.getStreamState('conv-b')!.bytesReceived).toBe(20);
|
||||
});
|
||||
|
||||
it('clears the state for a given conversation', () => {
|
||||
ChatService.saveStreamState('conv-a', 10);
|
||||
ChatService.clearStreamState('conv-a');
|
||||
expect(ChatService.getStreamState('conv-a')).toBeNull();
|
||||
});
|
||||
|
||||
it('ignores empty conversation id on save', () => {
|
||||
ChatService.saveStreamState('', 1);
|
||||
expect(ChatService.getStreamState('')).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null on corrupted storage payload', () => {
|
||||
localStorage.setItem(`${STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX}conv-a`, '{not-json');
|
||||
expect(ChatService.getStreamState('conv-a')).toBeNull();
|
||||
});
|
||||
|
||||
it('persists the model alongside the byte count', () => {
|
||||
ChatService.saveStreamState('conv-a', 10, 'model-x');
|
||||
expect(ChatService.getStreamState('conv-a')!.model).toBe('model-x');
|
||||
});
|
||||
|
||||
it('stores a null model when none is provided', () => {
|
||||
ChatService.saveStreamState('conv-a', 10);
|
||||
expect(ChatService.getStreamState('conv-a')!.model).toBeNull();
|
||||
});
|
||||
|
||||
it('overwrites the model on a new save for the same conversation', () => {
|
||||
ChatService.saveStreamState('conv-a', 10, 'model-x');
|
||||
ChatService.saveStreamState('conv-a', 20, 'model-y');
|
||||
expect(ChatService.getStreamState('conv-a')!.model).toBe('model-y');
|
||||
});
|
||||
|
||||
describe('resumeStreamIdentity', () => {
|
||||
it('appends the persisted model so the resume key matches the frozen POST identity', () => {
|
||||
ChatService.saveStreamState('conv-a', 10, 'model-x');
|
||||
expect(
|
||||
ChatService.resumeStreamIdentity('conv-a', ChatService.getStreamState('conv-a'), 'dropdown')
|
||||
).toBe('conv-a::model-x');
|
||||
});
|
||||
|
||||
it('keeps the bare conv id when the persisted model is null', () => {
|
||||
ChatService.saveStreamState('conv-a', 10);
|
||||
expect(
|
||||
ChatService.resumeStreamIdentity('conv-a', ChatService.getStreamState('conv-a'), 'dropdown')
|
||||
).toBe('conv-a');
|
||||
});
|
||||
|
||||
it('falls back to the current model only when no state is persisted', () => {
|
||||
expect(ChatService.resumeStreamIdentity('conv-a', null, 'dropdown')).toBe('conv-a::dropdown');
|
||||
});
|
||||
|
||||
it('ignores the fallback when a state exists, the persisted value is authoritative', () => {
|
||||
ChatService.saveStreamState('conv-a', 10, 'model-x');
|
||||
expect(
|
||||
ChatService.resumeStreamIdentity('conv-a', ChatService.getStreamState('conv-a'), 'dropdown')
|
||||
).toBe('conv-a::model-x');
|
||||
});
|
||||
|
||||
it('falls back when a legacy state has no model field', () => {
|
||||
localStorage.setItem(
|
||||
`${STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX}conv-a`,
|
||||
JSON.stringify({ bytesReceived: 10, updatedAt: 1 })
|
||||
);
|
||||
expect(
|
||||
ChatService.resumeStreamIdentity('conv-a', ChatService.getStreamState('conv-a'), 'dropdown')
|
||||
).toBe('conv-a::dropdown');
|
||||
});
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user