mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-26 23:57:40 +02:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5a6a0dd7e1 | |||
| ded1561b42 | |||
| 9df06805ee | |||
| 2f18fe13c5 | |||
| c16c35b814 | |||
| 1a87dcdc45 | |||
| e7e3f35090 | |||
| b11f7c16bc | |||
| f818065d75 | |||
| 960d628f46 |
+3
-4
@@ -114,7 +114,8 @@ class Mamba2Model(TextModel):
|
||||
hparams["text_config"] = hparams["llm_config"]
|
||||
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
|
||||
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
|
||||
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
|
||||
self.expand = self.find_hparam(["mamba_expand", "expand"], optional=True) or 2
|
||||
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or self.expand * self.d_model
|
||||
self.n_group = self.find_hparam(["n_groups"], optional=True) or 1
|
||||
|
||||
def set_vocab(self):
|
||||
@@ -144,11 +145,9 @@ class Mamba2Model(TextModel):
|
||||
|
||||
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
|
||||
|
||||
# Fail early for models which don't have a block expansion factor of 2
|
||||
# TODO: does this really matter?
|
||||
# skip the assertion for FalconH1 Model
|
||||
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
|
||||
assert self.d_inner == 2 * self.d_model
|
||||
assert self.d_inner == self.expand * self.d_model
|
||||
assert self.d_inner % head_dim == 0
|
||||
|
||||
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
|
||||
|
||||
@@ -75,12 +75,12 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
|
||||
ay1 = GGML_F32_VEC_LOAD(y + i);
|
||||
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
|
||||
}
|
||||
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
|
||||
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmla on available elements only
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b32(np2, n);
|
||||
ax1 = svld1_f32(pg, x + np2);
|
||||
ay1 = svld1_f32(pg, y + np2);
|
||||
sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
|
||||
sum1 = svmla_f32_m(pg, sum1, ax1, ay1);
|
||||
}
|
||||
// reduce sum1,sum2 to sum1
|
||||
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
|
||||
|
||||
@@ -2,6 +2,28 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
static __global__ void k_compute_out_prod_ptrs(
|
||||
const float * src0_d, const float * src1_d, float * dst_d,
|
||||
const float ** ptrs_a, const float ** ptrs_b, float ** ptrs_c,
|
||||
const int64_t ne2, const int64_t ne3,
|
||||
const int64_t dps2, const int64_t dps3,
|
||||
const size_t s02, const size_t s03,
|
||||
const size_t s12, const size_t s13,
|
||||
const size_t s2, const size_t s3) {
|
||||
const int64_t i2 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int64_t i3 = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
|
||||
if (i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t idx = i3*ne2 + i2;
|
||||
|
||||
ptrs_a[idx] = src0_d + (i3/dps3)*s03 + (i2/dps2)*s02;
|
||||
ptrs_b[idx] = src1_d + i3 *s13 + i2 *s12;
|
||||
ptrs_c[idx] = dst_d + i3 *s3 + i2 *s2;
|
||||
}
|
||||
|
||||
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
@@ -67,18 +89,39 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
&beta, dst_d + i3 *s3, ldc, s2,
|
||||
batch_count));
|
||||
}
|
||||
} else if (ne2 > 1 || ne3 > 1) {
|
||||
// dps2 > 1 (src0 broadcast along dim 2 with non-uniform stride) or multiple GEMMs
|
||||
// along dim 3: compute per-GEMM pointers on the device and use a single batched GEMM.
|
||||
GGML_ASSERT(ne3 > 0);
|
||||
GGML_ASSERT(ne2 <= (int64_t) std::numeric_limits<int>::max() / ne3);
|
||||
const int batch_count = (int) (ne2 * ne3);
|
||||
|
||||
ggml_cuda_pool_alloc<const float *> ptrs_a(ctx.pool(), batch_count);
|
||||
ggml_cuda_pool_alloc<const float *> ptrs_b(ctx.pool(), batch_count);
|
||||
ggml_cuda_pool_alloc< float *> ptrs_c(ctx.pool(), batch_count);
|
||||
|
||||
const dim3 block_dims(16, 16);
|
||||
const dim3 grid_dims((ne2 + block_dims.x - 1)/block_dims.x, (ne3 + block_dims.y - 1)/block_dims.y);
|
||||
k_compute_out_prod_ptrs<<<grid_dims, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ptrs_a.get(), ptrs_b.get(), ptrs_c.get(),
|
||||
ne2, ne3, dps2, dps3, s02, s03, s12, s13, s2, s3);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemmBatched(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, ptrs_a.get(), lda,
|
||||
ptrs_b.get(), ldb,
|
||||
&beta, ptrs_c.get(), ldc,
|
||||
batch_count));
|
||||
} else {
|
||||
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
|
||||
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
// ne2 == 1 && ne3 == 1: single GEMM
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d, lda,
|
||||
src1_d, ldb,
|
||||
&beta, dst_d, ldc));
|
||||
}
|
||||
}
|
||||
|
||||
Vendored
+1
@@ -48,6 +48,7 @@
|
||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetStream hipblasSetStream
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasSgemmBatched hipblasSgemmBatched
|
||||
#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
|
||||
Vendored
+1
@@ -32,6 +32,7 @@
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasSgemmBatched mublasSgemmBatched
|
||||
#define cublasSgemmStridedBatched mublasSgemmStridedBatched
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cublasOperation_t mublasOperation_t
|
||||
|
||||
@@ -126,7 +126,7 @@ static void soft_max_f32(const float * x,
|
||||
break;
|
||||
}
|
||||
|
||||
const float val = sycl::native::exp(vals[col] - max_val);
|
||||
const float val = sycl::native::exp(sycl::max(vals[col] - max_val, -80.0f));
|
||||
tmp += val;
|
||||
vals[col] = val;
|
||||
}
|
||||
@@ -154,7 +154,7 @@ static void soft_max_f32(const float * x,
|
||||
tmp = warp_reduce_sum<WARP_SIZE>(tmp);
|
||||
}
|
||||
if (sinks) {
|
||||
tmp += sycl::native::exp(sinks[i02] - max_val);
|
||||
tmp += sycl::native::exp(sycl::max(sinks[i02] - max_val, -80.0f));
|
||||
}
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
|
||||
@@ -308,6 +308,7 @@ enum vk_device_architecture {
|
||||
AMD_RDNA1,
|
||||
AMD_RDNA2,
|
||||
AMD_RDNA3,
|
||||
INTEL_XE1,
|
||||
INTEL_XE2,
|
||||
NVIDIA_PRE_TURING,
|
||||
NVIDIA_TURING,
|
||||
@@ -365,21 +366,26 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
||||
|
||||
bool subgroup_size_control = false;
|
||||
bool integer_dot_product = false;
|
||||
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
|
||||
subgroup_size_control = true;
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
|
||||
integer_dot_product = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!subgroup_size_control) {
|
||||
if (!subgroup_size_control || !integer_dot_product) {
|
||||
return vk_device_architecture::OTHER;
|
||||
}
|
||||
|
||||
vk::PhysicalDeviceProperties2 props2;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
|
||||
|
||||
props2.pNext = &subgroup_size_control_props;
|
||||
subgroup_size_control_props.pNext = &integer_dot_props;
|
||||
device.getProperties2(&props2);
|
||||
|
||||
if (subgroup_size_control_props.minSubgroupSize == 16) {
|
||||
@@ -388,6 +394,9 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
||||
// https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
|
||||
// https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
|
||||
return vk_device_architecture::INTEL_XE2;
|
||||
} else if (subgroup_size_control_props.minSubgroupSize == 8 &&
|
||||
integer_dot_product && integer_dot_props.integerDotProduct4x8BitPackedSignedAccelerated) {
|
||||
return vk_device_architecture::INTEL_XE1;
|
||||
}
|
||||
} else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
||||
@@ -3837,7 +3846,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support) {
|
||||
// Xe2/Xe3 with coopmat enabled - warptile performance tuning
|
||||
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
@@ -6361,9 +6370,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
break;
|
||||
case VK_VENDOR_ID_INTEL: {
|
||||
// Current Windows driver does not expose BF16 support.
|
||||
// We only want to use l_warptile if coopmat is available and is Xe2+
|
||||
const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2;
|
||||
const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat;
|
||||
// We only want to use l_warptile if coopmat is available
|
||||
const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && device->coopmat_support) : device->coopmat_support;
|
||||
device->mul_mat_l[i] = use_l_warptile;
|
||||
device->mul_mat_id_l[i] = use_l_warptile;
|
||||
device->mul_mat_m[i] = true;
|
||||
@@ -17890,9 +17898,9 @@ static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
||||
switch (props.vendorID) {
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
// Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
|
||||
// while some older hardware (ex. Arc A770) has performance regressions
|
||||
return arch == vk_device_architecture::INTEL_XE2;
|
||||
// Only allowing Xe2/Xe3 GPU and integrated Xe GPUs at the moment since older hardware (ex. Arc A770) has performance regressions.
|
||||
return (arch == vk_device_architecture::INTEL_XE2) ||
|
||||
(arch == vk_device_architecture::INTEL_XE1 && props.deviceType == vk::PhysicalDeviceType::eIntegratedGpu && driver_props.driverID == vk::DriverId::eIntelProprietaryWindows);
|
||||
case VK_VENDOR_ID_AMD:
|
||||
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
|
||||
// Workaround for AMD proprietary driver reporting support on all GPUs
|
||||
@@ -17940,6 +17948,8 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
|
||||
case 0xE20B: // B580
|
||||
case 0xE211: // Pro B60
|
||||
return 20;
|
||||
case 0xB080: // PTL Xe3 LPG 2x6 (12 subslices)
|
||||
return 12;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ const uint32_t Csh_stride = BS_NPQ;
|
||||
#ifdef COOPMAT
|
||||
const uint32_t Csh_len = BS_K * Csh_stride;
|
||||
#else
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 8; // 8 to workaround compiler bug
|
||||
#endif
|
||||
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
|
||||
#endif
|
||||
|
||||
@@ -144,7 +144,7 @@ const uint32_t Csh_stride = BS_NPQ;
|
||||
#ifdef COOPMAT
|
||||
const uint32_t Csh_len = BS_K * Csh_stride;
|
||||
#else
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 8; // 8 to workaround compiler bug
|
||||
#endif
|
||||
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
|
||||
#endif
|
||||
|
||||
@@ -169,7 +169,6 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
|
||||
GGML_ASSERT(ubatch.equal_seqs());
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
GGML_ASSERT(d_inner % n_head == 0);
|
||||
GGML_ASSERT(d_inner % d_state == 0);
|
||||
GGML_ASSERT(d_inner % n_group == 0);
|
||||
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
|
||||
@@ -39,10 +39,11 @@ void llama_model_mamba2::load_arch_tensors(llama_model_loader &) {
|
||||
const int64_t d_inner = hparams.ssm_d_inner;
|
||||
const int64_t d_state = hparams.ssm_d_state;
|
||||
const int64_t n_group = hparams.ssm_n_group;
|
||||
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
|
||||
const int64_t dt_rank = hparams.ssm_dt_rank;
|
||||
|
||||
const int64_t conv_dim = d_inner + 2 * n_group * d_state;
|
||||
const int64_t d_in_proj = d_inner + conv_dim + dt_rank;
|
||||
|
||||
// only an expansion factor of 2 is supported for now
|
||||
GGML_ASSERT(2 * n_embd == d_inner);
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
@@ -68,11 +69,11 @@ void llama_model_mamba2::load_arch_tensors(llama_model_loader &) {
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
|
||||
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0);
|
||||
|
||||
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0);
|
||||
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {dt_rank}, 0);
|
||||
|
||||
// no "weight" suffix for these
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0);
|
||||
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, dt_rank}, 0);
|
||||
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, dt_rank}, 0);
|
||||
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
|
||||
|
||||
|
||||
@@ -7973,6 +7973,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_conv_2d({ 256, 256, 192, 1 }, { 3, 3, 192, 96 }, kernel_type, 1, 1, 1, 1, 1, 1, false));
|
||||
}
|
||||
|
||||
// sycl backend will limit task global_range < MAX_INT
|
||||
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
|
||||
@@ -8672,6 +8675,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
256, 16, 16, {ne2, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
// nr2 sweep to cover the cublasSgemmBatched pointer-array path (dps2 > 1)
|
||||
for (int64_t nr2 : {8, 16, 32}) {
|
||||
test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32,
|
||||
256, 16, 16, {1, 1}, {nr2, 1}));
|
||||
}
|
||||
|
||||
// add_id
|
||||
for (ggml_type type_a : {GGML_TYPE_F32}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
|
||||
+59
-26
@@ -102,21 +102,34 @@ static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_tr
|
||||
return fabsf(result - dot_ref) / test_size;
|
||||
}
|
||||
|
||||
int main(int argc, char * argv[]) {
|
||||
bool verbose = false;
|
||||
const size_t test_size = 32 * 128;
|
||||
static int test_vec_dot_f32(bool verbose) {
|
||||
const auto * f32 = ggml_get_type_traits_cpu(GGML_TYPE_F32);
|
||||
int num_failed = 0;
|
||||
for (int n : {1, 2, 3, 5, 7, 8, 15, 16, 17, 31, 33, 63, 67, 127, 129, 193, 255, 1023}) {
|
||||
std::vector<float> a(n);
|
||||
std::vector<float> b(n);
|
||||
generate_data(0.0, n, a.data());
|
||||
generate_data(1.0, n, b.data());
|
||||
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
arg = argv[i];
|
||||
float result = 0.0f;
|
||||
f32->vec_dot(n, &result, 0, a.data(), 0, b.data(), 0, 1);
|
||||
const float ref = dot_product(a.data(), b.data(), n);
|
||||
const float error = fabsf(result - ref) / n;
|
||||
|
||||
if (arg == "-v") {
|
||||
verbose = true;
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
return 1;
|
||||
const bool failed = !(error < MAX_QUANTIZATION_REFERENCE_ERROR);
|
||||
num_failed += failed;
|
||||
if (failed || verbose) {
|
||||
printf(" f32 vec_dot n=%4d: %s (ref=%f got=%f err=%f)\n",
|
||||
n, RESULT_STR[failed], ref, result, error);
|
||||
}
|
||||
}
|
||||
return num_failed;
|
||||
}
|
||||
|
||||
static int test_vec_dot_q(bool verbose) {
|
||||
int num_failed = 0;
|
||||
|
||||
const size_t test_size = 32 * 128;
|
||||
|
||||
std::vector<float> test_data(test_size);
|
||||
std::vector<float> test_data2(test_size);
|
||||
@@ -124,11 +137,6 @@ int main(int argc, char * argv[]) {
|
||||
generate_data(0.0, test_data.size(), test_data.data());
|
||||
generate_data(1.0, test_data2.size(), test_data2.data());
|
||||
|
||||
ggml_cpu_init();
|
||||
|
||||
int num_failed = 0;
|
||||
bool failed = false;
|
||||
|
||||
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
|
||||
ggml_type type = (ggml_type) i;
|
||||
const auto * qfns = ggml_get_type_traits(type);
|
||||
@@ -156,7 +164,7 @@ int main(int argc, char * argv[]) {
|
||||
type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
|
||||
type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS :
|
||||
type == GGML_TYPE_NVFP4 ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : MAX_QUANTIZATION_TOTAL_ERROR;
|
||||
failed = !(total_error < max_quantization_error);
|
||||
bool failed = !(total_error < max_quantization_error);
|
||||
num_failed += failed;
|
||||
if (failed || verbose) {
|
||||
printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
|
||||
@@ -171,15 +179,15 @@ int main(int argc, char * argv[]) {
|
||||
|
||||
const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
|
||||
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
|
||||
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
|
||||
? MAX_DOT_PRODUCT_ERROR_LOWBIT
|
||||
: type == GGML_TYPE_Q1_0
|
||||
? MAX_DOT_PRODUCT_ERROR_BINARY
|
||||
: type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
|
||||
? MAX_DOT_PRODUCT_ERROR_TERNARY
|
||||
: type == GGML_TYPE_NVFP4
|
||||
? MAX_DOT_PRODUCT_ERROR_FP4
|
||||
: MAX_DOT_PRODUCT_ERROR;
|
||||
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
|
||||
? MAX_DOT_PRODUCT_ERROR_LOWBIT
|
||||
: type == GGML_TYPE_Q1_0
|
||||
? MAX_DOT_PRODUCT_ERROR_BINARY
|
||||
: type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
|
||||
? MAX_DOT_PRODUCT_ERROR_TERNARY
|
||||
: type == GGML_TYPE_NVFP4
|
||||
? MAX_DOT_PRODUCT_ERROR_FP4
|
||||
: MAX_DOT_PRODUCT_ERROR;
|
||||
failed = !(vec_dot_error < max_allowed_error);
|
||||
num_failed += failed;
|
||||
if (failed || verbose) {
|
||||
@@ -188,6 +196,31 @@ int main(int argc, char * argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
return num_failed;
|
||||
}
|
||||
|
||||
int main(int argc, char * argv[]) {
|
||||
bool verbose = false;
|
||||
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
arg = argv[i];
|
||||
|
||||
if (arg == "-v") {
|
||||
verbose = true;
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_cpu_init();
|
||||
|
||||
int num_failed = 0;
|
||||
|
||||
num_failed += test_vec_dot_f32(verbose);
|
||||
num_failed += test_vec_dot_q(verbose);
|
||||
|
||||
if (num_failed || verbose) {
|
||||
printf("%d tests failed\n", num_failed);
|
||||
}
|
||||
|
||||
@@ -55,8 +55,7 @@ struct clip_hparams {
|
||||
int32_t n_head = 0;
|
||||
int32_t n_head_kv = 0;
|
||||
int32_t n_layer = 0;
|
||||
// idefics3
|
||||
int32_t n_merge = 0; // number of patch merges **per-side**
|
||||
int32_t n_merge = 1; // number of patch merges **per-side**
|
||||
|
||||
// for preprocessor
|
||||
int32_t image_longest_edge = 0;
|
||||
@@ -135,8 +134,7 @@ struct clip_hparams {
|
||||
int32_t custom_image_max_tokens = -1;
|
||||
|
||||
void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
|
||||
const int cur_merge = n_merge == 0 ? 1 : n_merge;
|
||||
const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
|
||||
const int patch_area = patch_size * patch_size * n_merge * n_merge;
|
||||
image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area;
|
||||
image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area;
|
||||
warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
|
||||
@@ -145,8 +143,7 @@ struct clip_hparams {
|
||||
void set_warmup_n_tokens(int n_tokens) {
|
||||
int n_tok_per_side = static_cast<int>(std::sqrt(n_tokens));
|
||||
GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
|
||||
const int cur_merge = n_merge == 0 ? 1 : n_merge;
|
||||
warmup_image_size = n_tok_per_side * patch_size * cur_merge;
|
||||
warmup_image_size = n_tok_per_side * patch_size * n_merge;
|
||||
// TODO: support warmup size for custom token numbers
|
||||
}
|
||||
// sam vit deepseek-ocr
|
||||
|
||||
+52
-15
@@ -1210,6 +1210,9 @@ struct clip_model_loader {
|
||||
{
|
||||
std::vector<int> pinpoints;
|
||||
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, pinpoints, false);
|
||||
if (pinpoints.size() % 2 != 0) {
|
||||
throw std::runtime_error(string_format("%s: image_grid_pinpoints must have an even number of elements, got %zu\n", __func__, pinpoints.size()));
|
||||
}
|
||||
if (!pinpoints.empty()) {
|
||||
for (size_t i = 0; i < pinpoints.size(); i += 2) {
|
||||
hparams.image_res_candidates.push_back({
|
||||
@@ -1252,15 +1255,16 @@ struct clip_model_loader {
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
|
||||
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
|
||||
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
|
||||
GGML_ASSERT(idx_std >= 0 && "image_std not found");
|
||||
const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean);
|
||||
const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std);
|
||||
std::vector<float> image_mean;
|
||||
std::vector<float> image_std;
|
||||
get_arr_f32(KEY_IMAGE_MEAN, image_mean, false);
|
||||
get_arr_f32(KEY_IMAGE_STD , image_std, false);
|
||||
if (image_mean.size() < 3 || image_std.size() < 3) {
|
||||
throw std::runtime_error(string_format("%s: image_mean/image_std arrays must have at least 3 elements, got %zu and %zu\n", __func__, image_mean.size(), image_std.size()));
|
||||
}
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
hparams.image_mean[i] = mean_data[i];
|
||||
hparams.image_std[i] = std_data[i];
|
||||
hparams.image_mean[i] = image_mean[i];
|
||||
hparams.image_std[i] = image_std[i];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1686,8 +1690,8 @@ struct clip_model_loader {
|
||||
if (hparams.image_size > 65536) {
|
||||
throw std::runtime_error(string_format("%s: image_size (%d) is too large (max 65536)\n", __func__, hparams.image_size));
|
||||
}
|
||||
if (hparams.patch_size <= 0) {
|
||||
throw std::runtime_error(string_format("%s: patch_size (%d) must be greater than 0\n", __func__, hparams.patch_size));
|
||||
if (hparams.patch_size <= 0 || hparams.patch_size >= 65536) {
|
||||
throw std::runtime_error(string_format("%s: patch_size (%d) must be positive and less than 65536\n", __func__, hparams.patch_size));
|
||||
}
|
||||
if (hparams.n_embd <= 0) {
|
||||
throw std::runtime_error(string_format("%s: n_embd (%d) must be greater than 0\n", __func__, hparams.n_embd));
|
||||
@@ -1695,6 +1699,9 @@ struct clip_model_loader {
|
||||
if (hparams.image_max_pixels < hparams.image_min_pixels) {
|
||||
throw std::runtime_error(string_format("%s: image_max_pixels (%d) is less than image_min_pixels (%d)\n", __func__, hparams.image_max_pixels, hparams.image_min_pixels));
|
||||
}
|
||||
if (hparams.n_merge < 0 || hparams.n_merge >= 65536) {
|
||||
throw std::runtime_error(string_format("%s: n_merge (%d) must be greater than 0 and less than 65536\n", __func__, hparams.n_merge));
|
||||
}
|
||||
}
|
||||
|
||||
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
|
||||
@@ -3067,6 +3074,29 @@ struct clip_model_loader {
|
||||
output = gguf_get_val_f32(ctx_gguf.get(), i);
|
||||
}
|
||||
|
||||
void get_arr_f32(const std::string & key, std::vector<float> & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
if (required) {
|
||||
throw std::runtime_error("Key not found: " + key);
|
||||
}
|
||||
return;
|
||||
}
|
||||
const auto type = gguf_get_arr_type(ctx_gguf.get(), i);
|
||||
if (type != GGUF_TYPE_FLOAT32) {
|
||||
throw std::runtime_error(string_format("%s: array '%s' has type %d, expected %d (GGUF_TYPE_FLOAT32)\n", __func__, key.c_str(), type, GGUF_TYPE_FLOAT32));
|
||||
}
|
||||
const size_t n = gguf_get_arr_n(ctx_gguf.get(), i);
|
||||
if (n > (size_t) std::numeric_limits<int>::max()) {
|
||||
throw std::runtime_error(string_format("%s: array '%s' is too large (%zu elements)\n", __func__, key.c_str(), n));
|
||||
}
|
||||
output.resize(n);
|
||||
const float * values = (const float *)gguf_get_arr_data(ctx_gguf.get(), i);
|
||||
for (size_t j = 0; j < n; ++j) {
|
||||
output[j] = values[j];
|
||||
}
|
||||
}
|
||||
|
||||
void get_string(const std::string & key, std::string & output, bool required = true) const {
|
||||
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
if (i < 0) {
|
||||
@@ -3086,11 +3116,18 @@ struct clip_model_loader {
|
||||
}
|
||||
return;
|
||||
}
|
||||
int n = gguf_get_arr_n(ctx_gguf.get(), i);
|
||||
const auto type = gguf_get_arr_type(ctx_gguf.get(), i);
|
||||
if (type != GGUF_TYPE_INT32) {
|
||||
throw std::runtime_error(string_format("%s: array '%s' has type %d, expected %d (GGUF_TYPE_INT32)\n", __func__, key.c_str(), type, GGUF_TYPE_INT32));
|
||||
}
|
||||
const size_t n = gguf_get_arr_n(ctx_gguf.get(), i);
|
||||
if (n > (size_t) std::numeric_limits<int>::max()) {
|
||||
throw std::runtime_error(string_format("%s: array '%s' is too large (%zu elements)\n", __func__, key.c_str(), n));
|
||||
}
|
||||
output.resize(n);
|
||||
const int32_t * values = (const int32_t *)gguf_get_arr_data(ctx_gguf.get(), i);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
output[i] = values[i];
|
||||
for (size_t j = 0; j < n; ++j) {
|
||||
output[j] = values[j];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3364,8 +3401,8 @@ int clip_n_output_tokens(const clip_ctx * ctx, const clip_image_f32 * img) {
|
||||
{
|
||||
// dynamic size
|
||||
int n_merge = ctx->model.hparams.n_merge;
|
||||
int n_patches_x = img->nx() / patch_size / (n_merge > 0 ? n_merge : 1);
|
||||
int n_patches_y = img->ny() / patch_size / (n_merge > 0 ? n_merge : 1);
|
||||
int n_patches_x = img->nx() / patch_size / n_merge;
|
||||
int n_patches_y = img->ny() / patch_size / n_merge;
|
||||
if (ctx->model.token_embd_img_break) {
|
||||
n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
|
||||
} else {
|
||||
|
||||
@@ -63,8 +63,8 @@ ggml_cgraph * clip_graph_pixtral::build() {
|
||||
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
|
||||
// after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
|
||||
|
||||
const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
|
||||
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
|
||||
const int p_y = n_patches_y / n_merge;
|
||||
const int p_x = n_patches_x / n_merge;
|
||||
const int p_total = p_x * p_y;
|
||||
const int n_embd_text = cur->ne[0];
|
||||
const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
|
||||
|
||||
@@ -628,7 +628,7 @@ mtmd_image_preproc_out mtmd_image_preprocessor_llava_uhd::preprocess(const clip_
|
||||
mtmd_image_preprocessor_llava_uhd::slice_instructions mtmd_image_preprocessor_llava_uhd::get_slice_instructions(const clip_image_size & original_size) {
|
||||
mtmd_image_preprocessor_llava_uhd::slice_instructions res;
|
||||
// align slices by patch_size * n_merge so an integer number of merger output tokens fits per slice
|
||||
const int n_merge = hparams.n_merge > 0 ? hparams.n_merge : 1;
|
||||
const int n_merge = hparams.n_merge;
|
||||
const int patch_size = hparams.patch_size * n_merge;
|
||||
const int slice_size = hparams.image_size;
|
||||
const int original_width = original_size.width;
|
||||
@@ -894,7 +894,7 @@ mtmd_image_preproc_out mtmd_image_preprocessor_dyn_size::preprocess(const clip_i
|
||||
clip_image_u8 resized_image;
|
||||
const clip_image_size original_size = img.get_size();
|
||||
// the original pixtral model doesn't have n_merge
|
||||
const int cur_merge = hparams.n_merge == 0 ? 1 : hparams.n_merge;
|
||||
const int cur_merge = hparams.n_merge;
|
||||
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
|
||||
original_size,
|
||||
hparams.patch_size * cur_merge,
|
||||
|
||||
@@ -15,6 +15,8 @@ add_library(${TARGET} STATIC
|
||||
server-common.h
|
||||
server-context.cpp
|
||||
server-context.h
|
||||
server-stream.cpp
|
||||
server-stream.h
|
||||
server-tools.cpp
|
||||
server-tools.h
|
||||
server-schema.cpp
|
||||
|
||||
@@ -57,6 +57,7 @@ The core architecture consists of the following components:
|
||||
- `server_tokens`: Unified representation of token sequences (supports both text and multimodal tokens); used by `server_task` and `server_slot`.
|
||||
- `server_prompt_checkpoint`: For recurrent (e.g., RWKV) and SWA models, stores snapshots of KV cache state. Enables reuse when subsequent requests share the same prompt prefix, saving redundant computation.
|
||||
- `server_models`: Standalone component for managing multiple backend instances (used in router mode). It is completely independent of `server_context`.
|
||||
- `stream_session_manager`: Process wide owner of resumable SSE stream sessions (`g_stream_sessions`), keyed by conversation id. Backs the replay buffer that lets a client reattach to a generation after an HTTP disconnect. See the "Resumable streaming" section below.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
@@ -117,6 +118,58 @@ Here is an example trace of an API request for text completion:
|
||||
- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
|
||||
- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
|
||||
|
||||
### Resumable streaming (SSE replay buffer)
|
||||
|
||||
By default a streaming generation is bound to its HTTP socket: when the socket drops (refresh, tab close, mobile background, transient network) the generation aborts and the live stream is lost. This feature keeps the generation running server side and lets a client reattach.
|
||||
|
||||
It is opt in via the `X-Conversation-Id` header on `POST /v1/chat/completions`. Without the header the OAI strict path is unchanged. The conversation id is the only identity end to end (server map key, client localStorage key, route path), with an optional `::model` suffix for direct routing in router mode.
|
||||
|
||||
The feature lives entirely in `server-stream.{h,cpp}` and rests on three types:
|
||||
|
||||
- `stream_session`: a bounded ring buffer (4 MiB cap, oldest bytes drop first) plus a condvar. `append` pushes raw SSE bytes, `read_from` drains from any offset and blocks for live bytes or finalize, `finalize` wakes readers, `cancel` stops the producer. One conv maps to at most one live session.
|
||||
- `stream_session_manager` (`g_stream_sessions`): owns all sessions keyed by conv id, enforces the one conv one session invariant via `create_or_replace`, and runs a GC thread that drops completed sessions past their TTL.
|
||||
- `stream_pipe_producer` / `stream_pipe_consumer`: the write and read ends. The producer owns the session lifetime and finalizes it on destruction; the consumer is read only and never finalizes, so a reader detaching cannot kill a running generation.
|
||||
|
||||
Producer side: `server_res_generator` attaches a producer pipe when the header is present. The HTTP content provider mirrors every chunk into the ring before writing it to the socket. While a pipe is attached, `stream_aware_should_stop` ignores peer disconnect, so a dropped socket does not stop generation: only an explicit `DELETE` does. When the peer leaves early, `on_complete` calls `close()`, which drains the rest of the generation into the ring on the http worker.
|
||||
|
||||
Lifetime safety: the producer pipe holds a shared `alive` flag also captured by the session cancel hook. `~server_res_generator` calls `cleanup()` to clear that hook while the reader is still alive, so a `cancel` arriving during teardown can never call `stop()` on a freed response. This ordering is the most fragile part of the feature: finalizing or destroying the producer before `cleanup()` runs reintroduces a use after free.
|
||||
|
||||
Consumer side: `GET /v1/stream/<conv_id>?from=N` opens a `text/event-stream` that replays buffered bytes from offset `N` and blocks for live bytes, so the browser reattaches like a fresh EventSource. An offset below the dropped prefix returns 400.
|
||||
|
||||
Routes:
|
||||
|
||||
- `GET /v1/stream/:conv_id?from=N`: replay or live reattach.
|
||||
- `POST /v1/streams/lookup` with `{"conversation_ids": [...]}`: returns session status only for ids the caller already owns. There is no listing route, so live sessions cannot be enumerated (an earlier `GET /v1/streams` was removed for exactly this reason).
|
||||
- `DELETE /v1/stream/:conv_id`: explicit Stop, idempotent (`evict_and_cancel`).
|
||||
|
||||
Router mode binds the same paths to proxy handlers. A `conv_id -> child` map (`conv_models`), populated when a POST is routed, resolves the owning child in one lookup with no polling. The lookup groups ids per child; GET and DELETE proxy straight to the owner. This loopback REST hop is expected to move to a websocket IPC later, swapping only the transport.
|
||||
|
||||
Lifecycle: `g_stream_sessions.start_gc()` runs in main after common init, `stop_gc()` runs first in `clean_up()` and finalizes every live session so no reader hangs. Reader blocking and the post drop drain both run on httplib worker threads, which block on a condvar rather than spin.
|
||||
|
||||
| Constant | Value | Role |
|
||||
| --- | --- | --- |
|
||||
| `STREAM_SESSION_TTL_SECONDS` | 300 | retention of a completed session before GC |
|
||||
| `STREAM_SESSION_MAX_BYTES` | 4 MiB | ring cap per session |
|
||||
| `STREAM_SESSION_GC_INTERVAL_SECONDS` | 60 | GC tick |
|
||||
| `STREAM_READ_WAKE_INTERVAL_MS` | 200 | read_from wake to recheck should_stop |
|
||||
| `STREAM_LOOKUP_TIMEOUT_MS` | 250 | router to child loopback budget |
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
Client -- "POST + X-Conversation-Id" --> RG[server_res_generator]
|
||||
RG -- attach --> Prod[stream_pipe_producer]
|
||||
Prod -- "write, drain on peer drop" --> Sess
|
||||
subgraph g_stream_sessions
|
||||
Sess[stream_session: ring buffer, 4 MiB]
|
||||
GC[GC thread] -- drop after TTL --> Sess
|
||||
end
|
||||
Sess -- read_from offset --> Cons[stream_pipe_consumer]
|
||||
Cons -- "GET /v1/stream/:id?from=N" --> Client
|
||||
DEL[DELETE /v1/stream/:id] -- evict_and_cancel --> Sess
|
||||
```
|
||||
|
||||
The diagram shows the buffer touch points. The live wire (chunks streamed to the original client during a normal generation) is the producer's default output, described under "Producer side" above.
|
||||
|
||||
### Testing
|
||||
|
||||
`llama-server` includes an automated test suite based on `pytest`.
|
||||
@@ -223,6 +276,7 @@ The flow for downloading a new model:
|
||||
- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808
|
||||
- INI presets: https://github.com/ggml-org/llama.cpp/pull/17859 (+ refactoring: https://github.com/ggml-org/llama.cpp/pull/18169)
|
||||
- Sleeping mode: https://github.com/ggml-org/llama.cpp/pull/18228
|
||||
- Resumable streaming (SSE replay buffer): https://github.com/ggml-org/llama.cpp/pull/23226
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
#include "server-schema.h"
|
||||
#include "server-stream.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "common.h"
|
||||
@@ -4022,6 +4023,15 @@ struct server_res_generator : server_http_res {
|
||||
queue_tasks.wait_until_no_sleep();
|
||||
}
|
||||
}
|
||||
~server_res_generator() override {
|
||||
// cleanup() must run while rd is still alive (rd is destroyed after this body returns)
|
||||
if (spipe) {
|
||||
spipe->cleanup();
|
||||
}
|
||||
}
|
||||
void stop() override {
|
||||
rd.stop();
|
||||
}
|
||||
void ok(const json & response_data) {
|
||||
status = 200;
|
||||
data = safe_json_to_str(response_data);
|
||||
@@ -4210,8 +4220,10 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
}
|
||||
};
|
||||
|
||||
auto effective_should_stop = stream_aware_should_stop(res_this, req.should_stop);
|
||||
|
||||
try {
|
||||
if (req.should_stop()) {
|
||||
if (effective_should_stop()) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
@@ -4245,8 +4257,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
// receive subsequent results
|
||||
bool timeout = false;
|
||||
int64_t start_time = ggml_time_ms();
|
||||
auto result = rd.next([&timeout, &req, &start_time, ¶ms]() {
|
||||
if (req.should_stop()) {
|
||||
auto result = rd.next([&timeout, &start_time, ¶ms, &effective_should_stop]() {
|
||||
if (effective_should_stop()) {
|
||||
return true; // should_stop condition met
|
||||
} else if (params.sse_ping_interval > 0 && ggml_time_ms() - start_time > (int64_t)params.sse_ping_interval * 1000) {
|
||||
timeout = true;
|
||||
@@ -4264,7 +4276,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
|
||||
if (result == nullptr) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
GGML_ASSERT(req.should_stop());
|
||||
GGML_ASSERT(effective_should_stop());
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
@@ -4302,6 +4314,10 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
};
|
||||
}
|
||||
|
||||
// attach a producer pipe to the response when X-Conversation-Id is present.
|
||||
// the pipe mirrors SSE chunks into the ring buffer and wires up the cancel hook.
|
||||
stream_session_attach_pipe(*res, req.headers);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "common.h"
|
||||
#include "server-http.h"
|
||||
#include "server-stream.h"
|
||||
#include "server-common.h"
|
||||
#include "ui.h"
|
||||
|
||||
@@ -456,13 +457,40 @@ static void set_headers(httplib::Response & res, const std::map<std::string, std
|
||||
}
|
||||
}
|
||||
|
||||
// percent-decode a path component (%XX). path params arrive raw from httplib, unlike query
|
||||
// params, so a conv id like "conv::model" sent as "conv%3A%3Amodel" must be decoded here to
|
||||
// match the value the client put in the X-Conversation-Id header
|
||||
static std::string decode_path_component(const std::string & in) {
|
||||
std::string out;
|
||||
out.reserve(in.size());
|
||||
for (size_t i = 0; i < in.size(); i++) {
|
||||
if (in[i] == '%' && i + 2 < in.size()) {
|
||||
auto hex = [](char c) -> int {
|
||||
if (c >= '0' && c <= '9') return c - '0';
|
||||
if (c >= 'a' && c <= 'f') return c - 'a' + 10;
|
||||
if (c >= 'A' && c <= 'F') return c - 'A' + 10;
|
||||
return -1;
|
||||
};
|
||||
int hi = hex(in[i + 1]);
|
||||
int lo = hex(in[i + 2]);
|
||||
if (hi >= 0 && lo >= 0) {
|
||||
out.push_back(char((hi << 4) | lo));
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
out.push_back(in[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static std::map<std::string, std::string> get_params(const httplib::Request & req) {
|
||||
std::map<std::string, std::string> params;
|
||||
for (const auto & [key, value] : req.params) {
|
||||
params[key] = value;
|
||||
}
|
||||
for (const auto & [key, value] : req.path_params) {
|
||||
params[key] = value;
|
||||
params[key] = decode_path_component(value);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
@@ -497,26 +525,41 @@ static void process_handler_response(server_http_req_ptr && request, server_http
|
||||
set_headers(res, response->headers);
|
||||
const std::string content_type = response->content_type;
|
||||
// convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
|
||||
std::shared_ptr q_ptr = std::move(request);
|
||||
std::shared_ptr r_ptr = std::move(response);
|
||||
const auto chunked_content_provider = [response = r_ptr](size_t, const httplib::DataSink & sink) -> bool {
|
||||
std::shared_ptr<server_http_req> q_ptr = std::move(request);
|
||||
std::shared_ptr<server_http_res> r_ptr = std::move(response);
|
||||
|
||||
const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
|
||||
std::string chunk;
|
||||
const bool has_next = response->next(chunk);
|
||||
if (!chunk.empty()) {
|
||||
// mirror into the ring buffer first, the session must reflect every SSE chunk
|
||||
// whether or not the wire write below succeeds
|
||||
if (response->spipe) {
|
||||
response->spipe->write(chunk.data(), chunk.size());
|
||||
}
|
||||
if (!sink.write(chunk.data(), chunk.size())) {
|
||||
// peer is gone, stop the wire path here
|
||||
return false;
|
||||
}
|
||||
SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
|
||||
}
|
||||
if (!has_next) {
|
||||
// producer reached its natural end on the wire, a later close() skips the drain
|
||||
if (response->spipe) {
|
||||
response->spipe->done();
|
||||
}
|
||||
sink.done();
|
||||
SRV_DBG("%s", "http: stream ended\n");
|
||||
}
|
||||
return has_next;
|
||||
};
|
||||
const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable {
|
||||
response.reset(); // trigger the destruction of the response object
|
||||
request.reset(); // trigger the destruction of the request object
|
||||
// on a dropped peer, close() drains the rest of the generation into the ring buffer
|
||||
if (response->spipe) {
|
||||
response->spipe->close();
|
||||
}
|
||||
response.reset(); // spipe destructor finalizes the session if attached
|
||||
request.reset();
|
||||
};
|
||||
res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
|
||||
} else {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
@@ -10,6 +11,7 @@
|
||||
#include <unordered_map>
|
||||
|
||||
struct common_params;
|
||||
struct stream_pipe_producer; // defined in server-stream.h
|
||||
|
||||
// generator-like API for HTTP response generation
|
||||
// this object response with one of the 2 modes:
|
||||
@@ -23,12 +25,20 @@ struct server_http_res {
|
||||
std::string data;
|
||||
std::map<std::string, std::string> headers;
|
||||
|
||||
// TODO: move this to a virtual function once we have proper polymorphism support
|
||||
// if set, the stream survives a client disconnect: the producer pipe keeps draining into the
|
||||
// ring buffer and finalizes the session on destruction, so no explicit on_stream_end is needed.
|
||||
// shared_ptr (not unique_ptr) so the forward-declared type is safe to delete here.
|
||||
std::shared_ptr<stream_pipe_producer> spipe;
|
||||
|
||||
std::function<bool(std::string &)> next = nullptr;
|
||||
bool is_stream() const {
|
||||
return next != nullptr;
|
||||
}
|
||||
|
||||
// called when the session is cancelled (e.g. DELETE /v1/stream/<conv_id>).
|
||||
// server_res_generator overrides this to stop its reader; the default is a no-op.
|
||||
virtual void stop() {}
|
||||
|
||||
virtual ~server_http_res() = default;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
#include "server-common.h"
|
||||
#include "server-models.h"
|
||||
#include "server-context.h"
|
||||
#include "server-stream.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "preset.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
|
||||
#include <optional>
|
||||
#include <sheredom/subprocess.h>
|
||||
|
||||
#include <functional>
|
||||
@@ -92,6 +94,9 @@ struct server_subproc {
|
||||
}
|
||||
};
|
||||
|
||||
// short loopback budget for the resumable stream router to child JSON calls (probe, lookup,
|
||||
// delete). distinct from params.timeout_read/write which only applies to the generation proxy
|
||||
static constexpr int STREAM_LOOKUP_TIMEOUT_MS = 250;
|
||||
|
||||
static std::filesystem::path get_server_exec_path() {
|
||||
#if defined(_WIN32)
|
||||
@@ -1580,6 +1585,45 @@ static bool is_autoload(const common_params & params, const server_http_req & re
|
||||
}
|
||||
}
|
||||
|
||||
// percent encode one query or path component, covers reserved chars without pulling in
|
||||
// httplib::detail. used by the stream routes to forward conversation_id to children safely
|
||||
static std::string encode_qs(const std::string & in) {
|
||||
std::string out;
|
||||
out.reserve(in.size() * 3);
|
||||
for (unsigned char c : in) {
|
||||
bool safe = (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9')
|
||||
|| c == '-' || c == '_' || c == '.' || c == '~';
|
||||
if (safe) {
|
||||
out.push_back(char(c));
|
||||
} else {
|
||||
char buf[4];
|
||||
std::snprintf(buf, sizeof(buf), "%%%02X", c);
|
||||
out.append(buf, 3);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// resolve the child that owns a conversation's stream session via the conv_id -> model map
|
||||
// populated when the POST was routed. single map lookup then a meta lookup, no polling, no
|
||||
// parsing of the conv id. returns nullopt when nothing maps, the caller answers not found and
|
||||
// the client recovers
|
||||
static std::optional<server_model_meta> resolve_child_for_conv(
|
||||
server_models & models, const std::string & conversation_id) {
|
||||
if (conversation_id.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto tracked = models.conv_models.lookup(conversation_id);
|
||||
if (!tracked.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto meta = models.get_meta(*tracked);
|
||||
if (meta.has_value() && meta->is_ready()) {
|
||||
return meta;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void server_models_routes::init_routes() {
|
||||
this->get_router_props = [this](const server_http_req & req) {
|
||||
std::string name = req.get_param("model");
|
||||
@@ -1628,6 +1672,12 @@ void server_models_routes::init_routes() {
|
||||
if (!router_validate_model(name, models, autoload, error_res)) {
|
||||
return error_res;
|
||||
}
|
||||
// remember which child serves this conversation so the stream routes can route straight
|
||||
// to it without polling, keyed on the exact conv id from the header
|
||||
std::string conv_id = stream_conv_id_from_headers(req.headers);
|
||||
if (!conv_id.empty()) {
|
||||
models.conv_models.remember(conv_id, name);
|
||||
}
|
||||
return models.proxy_request(req, method, name, true); // update last usage for POST request only
|
||||
};
|
||||
|
||||
@@ -1819,6 +1869,128 @@ void server_models_routes::init_routes() {
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
};
|
||||
|
||||
this->router_stream_get = [this](const server_http_req & req) {
|
||||
// GET /v1/stream/<conv_id>?from=N. resolve the owning child from the conv_id -> model
|
||||
// map, 404 when nothing maps
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
std::string conv_id = req.get_param("conv_id");
|
||||
if (conv_id.empty()) {
|
||||
res_err(res, format_error_response("Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
std::optional<server_model_meta> owner = resolve_child_for_conv(models, conv_id);
|
||||
if (!owner.has_value()) {
|
||||
res_err(res, format_error_response("Stream not found or expired", ERROR_TYPE_NOT_FOUND));
|
||||
return res;
|
||||
}
|
||||
std::string from = req.get_param("from");
|
||||
std::string child_path = "/v1/stream/" + encode_qs(conv_id);
|
||||
if (!from.empty()) {
|
||||
child_path += "?from=" + from;
|
||||
}
|
||||
SRV_INF("proxying stream resume to model %s on port %d, path=%s\n",
|
||||
owner->name.c_str(), owner->port, child_path.c_str());
|
||||
auto proxy = std::make_unique<server_http_proxy>(
|
||||
"GET",
|
||||
"http",
|
||||
CHILD_ADDR,
|
||||
owner->port,
|
||||
child_path,
|
||||
req.headers,
|
||||
req.body,
|
||||
req.files,
|
||||
req.should_stop,
|
||||
params.timeout_read,
|
||||
params.timeout_write);
|
||||
return std::unique_ptr<server_http_res>(std::move(proxy));
|
||||
};
|
||||
|
||||
this->router_streams_lookup = [this](const server_http_req & req) {
|
||||
// POST /v1/streams/lookup. resolve each requested conv id to its owning child via the
|
||||
// map, group the ids per child, and query only the children that actually own some of
|
||||
// them instead of fanning out to every ready child. a child only answers for the ids
|
||||
// it owns, never lists anything else
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
std::vector<std::string> requested;
|
||||
try {
|
||||
json body = json::parse(req.body);
|
||||
if (body.contains("conversation_ids") && body["conversation_ids"].is_array()) {
|
||||
for (const auto & v : body["conversation_ids"]) {
|
||||
if (v.is_string() && !v.get<std::string>().empty()) {
|
||||
requested.push_back(v.get<std::string>());
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
res_ok(res, json::array());
|
||||
return res;
|
||||
}
|
||||
|
||||
// group requested ids by the child port that owns them, drop ids that map to nothing
|
||||
std::unordered_map<int, json> per_child;
|
||||
for (const auto & cid : requested) {
|
||||
auto owner = resolve_child_for_conv(models, cid);
|
||||
if (!owner.has_value()) {
|
||||
continue;
|
||||
}
|
||||
per_child[owner->port].push_back(cid);
|
||||
}
|
||||
|
||||
json aggregated = json::array();
|
||||
for (auto & [port, ids] : per_child) {
|
||||
json child_body = {{"conversation_ids", ids}};
|
||||
httplib::Client cli(CHILD_ADDR, port);
|
||||
cli.set_connection_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_read_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_write_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
auto resp = cli.Post("/v1/streams/lookup", child_body.dump(), "application/json");
|
||||
if (!resp || resp->status != 200) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
json child_arr = json::parse(resp->body);
|
||||
if (!child_arr.is_array()) {
|
||||
continue;
|
||||
}
|
||||
for (auto & entry : child_arr) {
|
||||
if (entry.is_object()) {
|
||||
aggregated.push_back(entry);
|
||||
}
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
res_ok(res, aggregated);
|
||||
return res;
|
||||
};
|
||||
|
||||
this->router_stream_delete = [this](const server_http_req & req) {
|
||||
// DELETE /v1/stream/<conv_id>. resolve the owning child via the map and forward only to
|
||||
// it, evict_and_cancel is idempotent on the child
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
std::string conv_id = req.get_param("conv_id");
|
||||
if (conv_id.empty()) {
|
||||
res_err(res, format_error_response("Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
std::string child_path = "/v1/stream/" + encode_qs(conv_id);
|
||||
auto owner = resolve_child_for_conv(models, conv_id);
|
||||
if (owner.has_value()) {
|
||||
httplib::Client cli(CHILD_ADDR, owner->port);
|
||||
cli.set_connection_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_read_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
cli.set_write_timeout(0, STREAM_LOOKUP_TIMEOUT_MS * 1000);
|
||||
auto resp = cli.Delete(child_path.c_str());
|
||||
(void) resp; // best effort, 404 and network errors are equivalent to no op
|
||||
}
|
||||
// drop the tracking entry, the session is being torn down
|
||||
models.conv_models.forget(conv_id);
|
||||
res->status = 204;
|
||||
res->content_type = "application/json";
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,10 @@
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
/**
|
||||
* state diagram:
|
||||
@@ -126,6 +129,44 @@ private:
|
||||
// if true, the next get_meta() will trigger a reload of model list
|
||||
bool need_reload = false;
|
||||
|
||||
// conv_id -> model name that currently serves its stream session, lets the resumable stream
|
||||
// routes go straight to the owning child instead of polling every one. populated when
|
||||
// proxy_request forwards a POST carrying an X-Conversation-Id. best effort: a stale entry just
|
||||
// makes the child answer not found and the client recovers. owns its lock, one mutex per struct
|
||||
struct conv_model_tracker {
|
||||
void remember(const std::string & conv_id, const std::string & model) {
|
||||
if (conv_id.empty() || model.empty()) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
map[conv_id] = model;
|
||||
}
|
||||
|
||||
std::optional<std::string> lookup(const std::string & conv_id) {
|
||||
if (conv_id.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
auto it = map.find(conv_id);
|
||||
if (it == map.end()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void forget(const std::string & conv_id) {
|
||||
if (conv_id.empty()) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
map.erase(conv_id);
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex mu;
|
||||
std::unordered_map<std::string, std::string> map;
|
||||
};
|
||||
|
||||
common_preset_context ctx_preset;
|
||||
|
||||
common_params base_params;
|
||||
@@ -145,6 +186,9 @@ private:
|
||||
void notify_sse(const std::string & event, const std::string & model_id, const json & data = nullptr);
|
||||
|
||||
public:
|
||||
// conv_id -> model tracker for the resumable stream routes, owns its lock
|
||||
conv_model_tracker conv_models;
|
||||
|
||||
server_models(const common_params & params, int argc, char ** argv);
|
||||
|
||||
server_response sse; // for real-time updates via SSE endpoint
|
||||
@@ -268,6 +312,12 @@ struct server_models_routes {
|
||||
server_http_context::handler_t get_router_models_sse;
|
||||
server_http_context::handler_t post_router_models;
|
||||
server_http_context::handler_t del_router_models;
|
||||
|
||||
// router side handlers for the resumable streaming routes. each resolves the child that owns
|
||||
// a conversation through the conv_id -> model map, no probing or fan out
|
||||
server_http_context::handler_t router_stream_get;
|
||||
server_http_context::handler_t router_streams_lookup;
|
||||
server_http_context::handler_t router_stream_delete;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,569 @@
|
||||
#include "server-stream.h"
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
namespace {
|
||||
constexpr int64_t STREAM_SESSION_TTL_SECONDS = 300;
|
||||
constexpr size_t STREAM_SESSION_MAX_BYTES = 4 * 1024 * 1024;
|
||||
constexpr int64_t STREAM_SESSION_GC_INTERVAL_SECONDS = 60;
|
||||
constexpr int64_t STREAM_READ_WAKE_INTERVAL_MS = 200;
|
||||
|
||||
// returns unix time in seconds
|
||||
int64_t now_seconds() {
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch()
|
||||
).count();
|
||||
}
|
||||
}
|
||||
|
||||
stream_session::stream_session(std::string conversation_id_, size_t max_bytes_)
|
||||
: conversation_id(std::move(conversation_id_))
|
||||
, started_ts(now_seconds())
|
||||
, prefix_dropped(0)
|
||||
, cap_bytes(max_bytes_)
|
||||
, done(false)
|
||||
, cancelled(false)
|
||||
, completed_ts(0) {
|
||||
buffer.reserve(64 * 1024);
|
||||
}
|
||||
|
||||
bool stream_session::append(const char * data, size_t len) {
|
||||
if (len == 0) {
|
||||
return true;
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
if (done.load(std::memory_order_relaxed)) {
|
||||
return false;
|
||||
}
|
||||
if (len >= cap_bytes) {
|
||||
// single chunk bigger than the cap, keep only the tail that fits
|
||||
size_t skip = len - cap_bytes;
|
||||
prefix_dropped += buffer.size() + skip;
|
||||
buffer.clear();
|
||||
buffer.insert(buffer.end(), data + skip, data + len);
|
||||
} else {
|
||||
size_t needed = buffer.size() + len;
|
||||
if (needed > cap_bytes) {
|
||||
size_t to_drop = needed - cap_bytes;
|
||||
buffer.erase(buffer.begin(), buffer.begin() + to_drop);
|
||||
prefix_dropped += to_drop;
|
||||
}
|
||||
buffer.insert(buffer.end(), data, data + len);
|
||||
}
|
||||
}
|
||||
cv.notify_all();
|
||||
return true;
|
||||
}
|
||||
|
||||
void stream_session::finalize() {
|
||||
bool was_done = done.exchange(true, std::memory_order_acq_rel);
|
||||
if (was_done) {
|
||||
return;
|
||||
}
|
||||
completed_ts.store(now_seconds(), std::memory_order_release);
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
stream_read_status stream_session::read_from(size_t offset,
|
||||
const std::function<bool(const char *, size_t)> & sink,
|
||||
const std::function<bool()> & should_stop) {
|
||||
std::unique_lock<std::mutex> lock(mu);
|
||||
while (true) {
|
||||
if (should_stop && should_stop()) {
|
||||
return stream_read_status::OK;
|
||||
}
|
||||
if (offset < prefix_dropped) {
|
||||
return stream_read_status::OFFSET_LOST;
|
||||
}
|
||||
size_t logical_end = prefix_dropped + buffer.size();
|
||||
if (offset < logical_end) {
|
||||
size_t local_off = offset - prefix_dropped;
|
||||
size_t n = buffer.size() - local_off;
|
||||
// copy the available chunk under the lock, release before calling the sink
|
||||
std::vector<char> chunk(buffer.begin() + local_off, buffer.begin() + local_off + n);
|
||||
offset += n;
|
||||
lock.unlock();
|
||||
bool keep_going = sink(chunk.data(), chunk.size());
|
||||
if (!keep_going) {
|
||||
return stream_read_status::OK;
|
||||
}
|
||||
lock.lock();
|
||||
continue;
|
||||
}
|
||||
if (done.load(std::memory_order_acquire)) {
|
||||
return stream_read_status::OK;
|
||||
}
|
||||
// wait for new bytes, finalize, or a periodic wake to re check should_stop
|
||||
cv.wait_for(lock, std::chrono::milliseconds(STREAM_READ_WAKE_INTERVAL_MS));
|
||||
}
|
||||
}
|
||||
|
||||
bool stream_session::is_done() const {
|
||||
return done.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
size_t stream_session::total_size() const {
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
return prefix_dropped + buffer.size();
|
||||
}
|
||||
|
||||
size_t stream_session::dropped_prefix() const {
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
return prefix_dropped;
|
||||
}
|
||||
|
||||
int64_t stream_session::completed_at() const {
|
||||
return completed_ts.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
void stream_session::set_stop_producer(std::function<void()> fn) {
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
stop_producer = std::move(fn);
|
||||
}
|
||||
|
||||
void stream_session::cancel() {
|
||||
// flip cancelled first so the producer-side stream_aware_should_stop can break out of the
|
||||
// recv() wait even if remove_waiting_task_ids does not notify the condvar (the cancel task
|
||||
// posted by rd.stop() will eventually notify, but we do not want to depend on that timing)
|
||||
cancelled.store(true, std::memory_order_release);
|
||||
// copy the hook under the lock then invoke outside, the producer side may grab queue locks
|
||||
// and we do not want to hold our mu across that path
|
||||
std::function<void()> fn;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mu);
|
||||
fn = stop_producer;
|
||||
}
|
||||
if (fn) {
|
||||
fn();
|
||||
}
|
||||
}
|
||||
|
||||
bool stream_session::is_cancelled() const {
|
||||
return cancelled.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
stream_session_manager::stream_session_manager()
|
||||
: running(false) {
|
||||
}
|
||||
|
||||
stream_session_manager::~stream_session_manager() {
|
||||
stop_gc();
|
||||
}
|
||||
|
||||
stream_session_ptr stream_session_manager::create_or_replace(const std::string & conversation_id) {
|
||||
// evict any previous session on the same conv, this guarantees the invariant
|
||||
// "one conv = at most one live session" and propagates cancel to its producer
|
||||
stream_session_ptr previous;
|
||||
auto fresh = std::make_shared<stream_session>(conversation_id, STREAM_SESSION_MAX_BYTES);
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it != sessions.end()) {
|
||||
previous = it->second;
|
||||
it->second = fresh;
|
||||
} else {
|
||||
sessions.emplace(conversation_id, fresh);
|
||||
}
|
||||
}
|
||||
if (previous) {
|
||||
previous->cancel();
|
||||
previous->finalize();
|
||||
}
|
||||
return fresh;
|
||||
}
|
||||
|
||||
stream_session_ptr stream_session_manager::get(const std::string & conversation_id) {
|
||||
std::shared_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it == sessions.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<stream_session_ptr> stream_session_manager::list_all() const {
|
||||
std::vector<stream_session_ptr> out;
|
||||
std::shared_lock<std::shared_mutex> lock(map_mu);
|
||||
out.reserve(sessions.size());
|
||||
for (auto & kv : sessions) {
|
||||
out.push_back(kv.second);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void stream_session_manager::evict(const std::string & conversation_id) {
|
||||
stream_session_ptr s;
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it == sessions.end()) {
|
||||
return;
|
||||
}
|
||||
s = it->second;
|
||||
sessions.erase(it);
|
||||
}
|
||||
// finalize outside the map lock so any pending readers wake up and exit
|
||||
s->finalize();
|
||||
}
|
||||
|
||||
void stream_session_manager::evict_and_cancel(const std::string & conversation_id) {
|
||||
stream_session_ptr s;
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
auto it = sessions.find(conversation_id);
|
||||
if (it == sessions.end()) {
|
||||
return;
|
||||
}
|
||||
s = it->second;
|
||||
sessions.erase(it);
|
||||
}
|
||||
// signal the producer side first so the inference is cancelled at the queue level,
|
||||
// then finalize, which wakes any pending HTTP reader and lets the drain exit naturally
|
||||
s->cancel();
|
||||
s->finalize();
|
||||
}
|
||||
|
||||
void stream_session_manager::start_gc() {
|
||||
if (running.exchange(true)) {
|
||||
return;
|
||||
}
|
||||
gc_thread = std::thread([this] { gc_loop(); });
|
||||
}
|
||||
|
||||
void stream_session_manager::stop_gc() {
|
||||
bool was_running = running.exchange(false);
|
||||
if (was_running) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(gc_wake_mu);
|
||||
}
|
||||
gc_wake_cv.notify_all();
|
||||
if (gc_thread.joinable()) {
|
||||
gc_thread.join();
|
||||
}
|
||||
}
|
||||
// finalize all live sessions so no reader ever hangs
|
||||
std::vector<stream_session_ptr> snapshot;
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
snapshot.reserve(sessions.size());
|
||||
for (auto & kv : sessions) {
|
||||
snapshot.push_back(kv.second);
|
||||
}
|
||||
sessions.clear();
|
||||
}
|
||||
for (auto & s : snapshot) {
|
||||
s->finalize();
|
||||
}
|
||||
}
|
||||
|
||||
void stream_session_manager::gc_loop() {
|
||||
while (running.load(std::memory_order_acquire)) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(gc_wake_mu);
|
||||
gc_wake_cv.wait_for(lock,
|
||||
std::chrono::seconds(STREAM_SESSION_GC_INTERVAL_SECONDS),
|
||||
[this] { return !running.load(std::memory_order_acquire); });
|
||||
}
|
||||
if (!running.load(std::memory_order_acquire)) {
|
||||
return;
|
||||
}
|
||||
int64_t cutoff = now_seconds() - STREAM_SESSION_TTL_SECONDS;
|
||||
std::vector<stream_session_ptr> to_drop;
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(map_mu);
|
||||
for (auto it = sessions.begin(); it != sessions.end(); ) {
|
||||
int64_t completed = it->second->completed_at();
|
||||
if (completed != 0 && completed <= cutoff) {
|
||||
to_drop.push_back(it->second);
|
||||
it = sessions.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
// finalize outside the map lock, idempotent if the session was already done
|
||||
for (auto & s : to_drop) {
|
||||
s->finalize();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process wide manager, lifecycle controlled by llama-server main() via start_gc/stop_gc
|
||||
stream_session_manager g_stream_sessions;
|
||||
|
||||
// stream_pipe ---------------------------------------------------------------------------------
|
||||
|
||||
stream_pipe::stream_pipe(stream_session_ptr session)
|
||||
: session_(std::move(session)) {
|
||||
}
|
||||
|
||||
bool stream_pipe::is_cancelled() const {
|
||||
return session_->is_cancelled();
|
||||
}
|
||||
|
||||
// stream_pipe_producer
|
||||
|
||||
stream_pipe_producer::stream_pipe_producer(stream_session_ptr session)
|
||||
: stream_pipe(std::move(session)) {
|
||||
}
|
||||
|
||||
stream_pipe_producer::~stream_pipe_producer() {
|
||||
cleanup();
|
||||
session_->finalize();
|
||||
}
|
||||
|
||||
void stream_pipe_producer::cleanup() {
|
||||
if (!alive_) {
|
||||
return;
|
||||
}
|
||||
alive_->store(false, std::memory_order_release);
|
||||
session_->set_stop_producer(nullptr);
|
||||
alive_.reset();
|
||||
}
|
||||
|
||||
bool stream_pipe_producer::write(const char * data, size_t len) {
|
||||
return session_->append(data, len);
|
||||
}
|
||||
|
||||
void stream_pipe_producer::done() {
|
||||
done_ = true;
|
||||
}
|
||||
|
||||
void stream_pipe_producer::close() {
|
||||
// httplib bails its content provider the moment is_peer_alive() goes false, so pump the rest
|
||||
// of the generation into the ring buffer here. a DELETE flips is_cancelled and cuts it short
|
||||
if (done_ || session_->is_cancelled()) {
|
||||
SRV_INF("stream_pipe close: skip drain (done=%d cancelled=%d) conv=%s\n",
|
||||
done_ ? 1 : 0, session_->is_cancelled() ? 1 : 0, session_->conversation_id.c_str());
|
||||
return;
|
||||
}
|
||||
SRV_INF("stream_pipe close: draining conv=%s\n", session_->conversation_id.c_str());
|
||||
size_t drained = 0;
|
||||
std::string chunk;
|
||||
while (true) {
|
||||
chunk.clear();
|
||||
bool has_next = res_->next(chunk);
|
||||
if (!chunk.empty()) {
|
||||
write(chunk.data(), chunk.size());
|
||||
drained += chunk.size();
|
||||
}
|
||||
if (!has_next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
SRV_INF("stream_pipe close: drain ended conv=%s bytes=%zu\n", session_->conversation_id.c_str(), drained);
|
||||
}
|
||||
|
||||
std::shared_ptr<stream_pipe_producer> stream_pipe_producer::create(stream_session_ptr session,
|
||||
server_http_res & res) {
|
||||
auto alive = std::make_shared<std::atomic<bool>>(true);
|
||||
auto * res_ptr = &res;
|
||||
session->set_stop_producer([alive, res_ptr]() {
|
||||
if (alive->load(std::memory_order_acquire)) {
|
||||
res_ptr->stop();
|
||||
}
|
||||
});
|
||||
auto pipe = std::shared_ptr<stream_pipe_producer>(new stream_pipe_producer(std::move(session)));
|
||||
pipe->alive_ = std::move(alive);
|
||||
pipe->res_ = res_ptr;
|
||||
return pipe;
|
||||
}
|
||||
|
||||
// stream_pipe_consumer
|
||||
|
||||
stream_pipe_consumer::stream_pipe_consumer(stream_session_ptr session)
|
||||
: stream_pipe(std::move(session)) {
|
||||
}
|
||||
|
||||
stream_read_status stream_pipe_consumer::read(size_t & offset,
|
||||
const std::function<bool(const char *, size_t)> & sink,
|
||||
const std::function<bool()> & should_stop) {
|
||||
return session_->read_from(offset, sink, should_stop);
|
||||
}
|
||||
|
||||
std::shared_ptr<stream_pipe_consumer> stream_pipe_consumer::create(stream_session_ptr session) {
|
||||
return std::shared_ptr<stream_pipe_consumer>(new stream_pipe_consumer(std::move(session)));
|
||||
}
|
||||
|
||||
// helper, builds the standard error response and assigns it to a brand new http_res
|
||||
static server_http_res_ptr make_error_response(int status, const std::string & message, error_type type) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
json err = format_error_response(message, type);
|
||||
res->status = json_value(err, "code", status);
|
||||
res->content_type = "application/json; charset=utf-8";
|
||||
res->data = safe_json_to_str({{"error", err}});
|
||||
return res;
|
||||
}
|
||||
|
||||
server_http_context::handler_t make_stream_get_handler() {
|
||||
return [](const server_http_req & req) -> server_http_res_ptr {
|
||||
// GET /v1/stream/<conv_id>?from=N replays the SSE bytes already buffered for the
|
||||
// session, blocks for more bytes when the session is still running, returns when
|
||||
// the session is finalized. the body is streamed back as text/event-stream so the
|
||||
// browser EventSource can attach to it like a fresh request
|
||||
std::string conv_id = req.get_param("conv_id");
|
||||
if (conv_id.empty()) {
|
||||
return make_error_response(400, "Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
auto session = g_stream_sessions.get(conv_id);
|
||||
if (!session) {
|
||||
return make_error_response(404, "Stream not found or expired", ERROR_TYPE_NOT_FOUND);
|
||||
}
|
||||
size_t from = 0;
|
||||
std::string from_str = req.get_param("from");
|
||||
if (!from_str.empty()) {
|
||||
try {
|
||||
from = static_cast<size_t>(std::stoull(from_str));
|
||||
} catch (const std::exception &) {
|
||||
return make_error_response(400, "Invalid 'from' offset", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
}
|
||||
if (from < session->dropped_prefix()) {
|
||||
return make_error_response(400, "Stream offset lost, please restart", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 200;
|
||||
res->content_type = "text/event-stream";
|
||||
// the next closure reads from the ring buffer at the requested offset, blocks until
|
||||
// bytes arrive or the session finalizes. exit each call after draining the available
|
||||
// chunk so set_chunked_content_provider gets a chance to flush to the socket
|
||||
auto offset_ptr = std::make_shared<size_t>(from);
|
||||
// consumer pipe: read-only, does not finalize the session on destruction
|
||||
auto pipe = stream_pipe_consumer::create(session);
|
||||
res->next = [pipe, offset_ptr, &req](std::string & output) -> bool {
|
||||
bool got_any = false;
|
||||
pipe->read(*offset_ptr,
|
||||
[&](const char * d, size_t n) {
|
||||
output.append(d, n);
|
||||
*offset_ptr += n;
|
||||
got_any = true;
|
||||
return false;
|
||||
},
|
||||
req.should_stop);
|
||||
return got_any;
|
||||
};
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
server_http_context::handler_t make_streams_lookup_handler() {
|
||||
return [](const server_http_req & req) -> server_http_res_ptr {
|
||||
// POST /v1/streams/lookup with body {"conversation_ids": ["X", "Y", ...]} returns the
|
||||
// matching sessions, only for ids the caller already knows. each id matches the exact key
|
||||
// and any "<id>::<model>" variant, so one lookup covers every per model session for a conv
|
||||
std::vector<std::string> requested;
|
||||
try {
|
||||
json body = json::parse(req.body);
|
||||
if (body.contains("conversation_ids") && body["conversation_ids"].is_array()) {
|
||||
for (const auto & v : body["conversation_ids"]) {
|
||||
if (v.is_string()) {
|
||||
std::string id = v.get<std::string>();
|
||||
if (!id.empty()) {
|
||||
requested.push_back(std::move(id));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->content_type = "application/json; charset=utf-8";
|
||||
res->data = safe_json_to_str({{"error", {{"message", std::string("invalid body: ") + e.what()},
|
||||
{"type", "invalid_request_error"}}}});
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<stream_session_ptr> sessions;
|
||||
if (!requested.empty()) {
|
||||
auto all = g_stream_sessions.list_all();
|
||||
for (const auto & rid : requested) {
|
||||
const std::string with_sep = rid + "::";
|
||||
for (auto & s : all) {
|
||||
if (s->conversation_id == rid ||
|
||||
s->conversation_id.compare(0, with_sep.size(), with_sep) == 0) {
|
||||
sessions.push_back(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
json arr = json::array();
|
||||
for (auto & s : sessions) {
|
||||
arr.push_back({
|
||||
{"conversation_id", s->conversation_id},
|
||||
{"is_done", s->is_done()},
|
||||
{"total_bytes", s->total_size()},
|
||||
{"started_at", s->started_ts},
|
||||
{"completed_at", s->completed_at()},
|
||||
});
|
||||
}
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 200;
|
||||
res->content_type = "application/json; charset=utf-8";
|
||||
res->data = safe_json_to_str(arr);
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
server_http_context::handler_t make_stream_delete_handler() {
|
||||
return [](const server_http_req & req) -> server_http_res_ptr {
|
||||
// DELETE /v1/stream/<conv_id> is the explicit user Stop, cancels the producer hook
|
||||
// wired by handle_completions_impl and evicts the buffer. idempotent, a session that
|
||||
// already finalized or was never created returns 204 either way
|
||||
std::string conv_id = req.get_param("conv_id");
|
||||
if (conv_id.empty()) {
|
||||
return make_error_response(400, "Missing conversation id in path", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
SRV_INF("DELETE /v1/stream/%s -> evict_and_cancel\n", conv_id.c_str());
|
||||
g_stream_sessions.evict_and_cancel(conv_id);
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 204;
|
||||
res->content_type = "application/json";
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
std::string stream_conv_id_from_headers(const std::map<std::string, std::string> & headers) {
|
||||
// case-insensitive scan for x-conversation-id
|
||||
static constexpr char target[] = "x-conversation-id";
|
||||
static constexpr size_t target_len = sizeof(target) - 1;
|
||||
for (const auto & [hk, hv] : headers) {
|
||||
if (hk.size() != target_len) continue;
|
||||
bool match = true;
|
||||
for (size_t i = 0; i < target_len; ++i) {
|
||||
char c = hk[i];
|
||||
if (c >= 'A' && c <= 'Z') c = char(c + 32);
|
||||
if (c != target[i]) { match = false; break; }
|
||||
}
|
||||
if (match) {
|
||||
return hv;
|
||||
}
|
||||
}
|
||||
return std::string();
|
||||
}
|
||||
|
||||
void stream_session_attach_pipe(server_http_res & res, const std::map<std::string, std::string> & headers) {
|
||||
std::string conversation_id = stream_conv_id_from_headers(headers);
|
||||
SRV_INF("stream_session_attach_pipe: conv_id=%s (empty=%d)\n",
|
||||
conversation_id.c_str(), conversation_id.empty() ? 1 : 0);
|
||||
if (conversation_id.empty()) {
|
||||
return;
|
||||
}
|
||||
auto session = g_stream_sessions.create_or_replace(conversation_id);
|
||||
res.spipe = stream_pipe_producer::create(session, res);
|
||||
}
|
||||
|
||||
std::function<bool()> stream_aware_should_stop(server_http_res * res, std::function<bool()> fallback) {
|
||||
return [res, fallback = std::move(fallback)]() -> bool {
|
||||
if (res->spipe) {
|
||||
return res->spipe->is_cancelled();
|
||||
}
|
||||
return fallback();
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-http.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
enum class stream_read_status {
|
||||
OK,
|
||||
OFFSET_LOST,
|
||||
};
|
||||
|
||||
// streaming buffer for one generation, survives HTTP disconnect. the producer appends raw SSE
|
||||
// bytes, readers drain from any offset via read_from and block until more bytes or finalize.
|
||||
// keyed by conversation_id: one conv = at most one live session
|
||||
struct stream_session {
|
||||
std::string conversation_id;
|
||||
int64_t started_ts; // unix seconds at construction, used by /v1/streams listing
|
||||
|
||||
stream_session(std::string conversation_id_, size_t max_bytes_);
|
||||
stream_session(const stream_session &) = delete;
|
||||
stream_session & operator=(const stream_session &) = delete;
|
||||
|
||||
// append raw bytes, drops from the front if the cap is reached.
|
||||
// returns false if the session is already finalized
|
||||
bool append(const char * data, size_t len);
|
||||
|
||||
// mark the session as complete, wakes all pending readers
|
||||
void finalize();
|
||||
|
||||
// drain bytes from offset, calling sink for each chunk. blocks until more
|
||||
// bytes arrive or finalize is called. returns OK on clean exit, OFFSET_LOST
|
||||
// if offset falls below the dropped prefix
|
||||
stream_read_status read_from(size_t offset,
|
||||
const std::function<bool(const char *, size_t)> & sink,
|
||||
const std::function<bool()> & should_stop);
|
||||
|
||||
bool is_done() const;
|
||||
bool is_cancelled() const;
|
||||
size_t total_size() const; // bytes that ever entered the session
|
||||
size_t dropped_prefix() const; // bytes evicted from the front due to cap
|
||||
int64_t completed_at() const; // 0 while alive, unix seconds after finalize
|
||||
|
||||
// attach the producer stop hook used to cancel its reader, pass an empty function to detach
|
||||
void set_stop_producer(std::function<void()> fn);
|
||||
|
||||
// signal the producer to abort its inference asap via the stop hook, idempotent
|
||||
void cancel();
|
||||
|
||||
private:
|
||||
mutable std::mutex mu;
|
||||
std::condition_variable cv;
|
||||
std::vector<char> buffer;
|
||||
size_t prefix_dropped;
|
||||
size_t cap_bytes;
|
||||
std::atomic<bool> done;
|
||||
std::atomic<bool> cancelled;
|
||||
std::atomic<int64_t> completed_ts;
|
||||
std::function<void()> stop_producer; // protected by mu
|
||||
};
|
||||
|
||||
using stream_session_ptr = std::shared_ptr<stream_session>;
|
||||
|
||||
// one end of a stream_session pipe. the base holds the session and the shared query, the
|
||||
// producer and consumer ends derive from it. virtual dtor so each end runs its own teardown:
|
||||
// the producer finalizes the session, the consumer leaves it untouched
|
||||
struct stream_pipe {
|
||||
virtual ~stream_pipe() = default;
|
||||
|
||||
// true if the session was cancelled (e.g. via DELETE /v1/stream/<conv_id>)
|
||||
bool is_cancelled() const;
|
||||
|
||||
protected:
|
||||
explicit stream_pipe(stream_session_ptr session);
|
||||
|
||||
stream_session_ptr session_;
|
||||
};
|
||||
|
||||
// producer end: writes chunks into the ring buffer and owns the session lifetime, finalizing it
|
||||
// on destruction.
|
||||
//
|
||||
// lifetime safety: holds a shared_ptr<atomic<bool>> alive also captured by the session's
|
||||
// stop_producer hook. cleanup() sets alive=false and clears the hook; it must run while the
|
||||
// response the hook calls stop() on is still alive. ~server_res_generator() does this explicitly.
|
||||
struct stream_pipe_producer : stream_pipe {
|
||||
~stream_pipe_producer() override;
|
||||
|
||||
// append raw bytes to the session's ring buffer, returns false if already finalized
|
||||
bool write(const char * data, size_t len);
|
||||
|
||||
// mark the natural end on the wire so a later close() is a no-op
|
||||
void done();
|
||||
|
||||
// on a peer drop, pump the response next() into the ring buffer until done. runs on the http
|
||||
// worker from on_complete, no-op after done() or cancel
|
||||
void close();
|
||||
|
||||
// disarm the stop hook and drop the alive guard, must run while the response the hook
|
||||
// references is still alive. idempotent, the destructor calls it too
|
||||
void cleanup();
|
||||
|
||||
// res.stop() is invoked when the session is cancelled, the alive guard ensures stop() is not
|
||||
// called after cleanup() has run
|
||||
static std::shared_ptr<stream_pipe_producer> create(stream_session_ptr session, server_http_res & res);
|
||||
|
||||
private:
|
||||
explicit stream_pipe_producer(stream_session_ptr session);
|
||||
|
||||
bool done_ = false;
|
||||
std::shared_ptr<std::atomic<bool>> alive_;
|
||||
server_http_res * res_ = nullptr;
|
||||
};
|
||||
|
||||
// consumer end: read-only replay of the ring buffer, the destructor does not finalize the session
|
||||
struct stream_pipe_consumer : stream_pipe {
|
||||
// drain bytes from offset, calling sink for each available chunk. blocks until more data
|
||||
// arrives or the session finalizes. should_stop is polled, returns OFFSET_LOST if offset
|
||||
// fell below the dropped prefix
|
||||
stream_read_status read(size_t & offset,
|
||||
const std::function<bool(const char *, size_t)> & sink,
|
||||
const std::function<bool()> & should_stop);
|
||||
|
||||
static std::shared_ptr<stream_pipe_consumer> create(stream_session_ptr session);
|
||||
|
||||
private:
|
||||
explicit stream_pipe_consumer(stream_session_ptr session);
|
||||
};
|
||||
|
||||
// owns all live sessions, runs a periodic GC to evict expired ones.
|
||||
// the map is keyed by conversation_id, so the invariant "one conv = at most one
|
||||
// live session" is enforced at the type level
|
||||
class stream_session_manager {
|
||||
public:
|
||||
stream_session_manager();
|
||||
~stream_session_manager();
|
||||
|
||||
stream_session_manager(const stream_session_manager &) = delete;
|
||||
stream_session_manager & operator=(const stream_session_manager &) = delete;
|
||||
|
||||
// install a new session for this conversation, evicting and cancelling any previous one.
|
||||
// the conversation_id must be non empty, the caller is responsible for that check.
|
||||
// returns the new session
|
||||
stream_session_ptr create_or_replace(const std::string & conversation_id);
|
||||
|
||||
// lookup, returns null if unknown or already evicted
|
||||
stream_session_ptr get(const std::string & conversation_id);
|
||||
|
||||
// list every live or recently completed session, used by GET /v1/streams without filter
|
||||
std::vector<stream_session_ptr> list_all() const;
|
||||
|
||||
// remove from the map and finalize, wakes any pending readers
|
||||
void evict(const std::string & conversation_id);
|
||||
|
||||
// signal the producer to cancel asap then evict, used by the explicit user Stop path
|
||||
void evict_and_cancel(const std::string & conversation_id);
|
||||
|
||||
void start_gc();
|
||||
void stop_gc();
|
||||
|
||||
private:
|
||||
void gc_loop();
|
||||
|
||||
mutable std::shared_mutex map_mu;
|
||||
std::unordered_map<std::string, stream_session_ptr> sessions; // key: conversation_id
|
||||
std::thread gc_thread;
|
||||
std::atomic<bool> running;
|
||||
std::mutex gc_wake_mu;
|
||||
std::condition_variable gc_wake_cv;
|
||||
};
|
||||
|
||||
// process wide manager, linked by both llama-server and llama-cli. llama-server main() drives
|
||||
// start_gc/stop_gc, llama-cli leaves it idle. the dtor calls stop_gc() unconditionally so exit
|
||||
// is safe whether or not the GC thread ran
|
||||
extern stream_session_manager g_stream_sessions;
|
||||
|
||||
// route handler factories operating on g_stream_sessions, wired under /v1/stream/* by server.cpp.
|
||||
// keeps the resumable stream surface confined to server-stream
|
||||
server_http_context::handler_t make_stream_get_handler();
|
||||
server_http_context::handler_t make_streams_lookup_handler();
|
||||
server_http_context::handler_t make_stream_delete_handler();
|
||||
|
||||
// extract the X-Conversation-Id header value (case-insensitive), empty when absent. exposed so
|
||||
// the router can track which child serves a forwarded POST
|
||||
std::string stream_conv_id_from_headers(const std::map<std::string, std::string> & headers);
|
||||
|
||||
// on an X-Conversation-Id header, create or replace the session and attach a producer pipe to
|
||||
// res. no-op when absent, called from the server_res_generator constructor
|
||||
void stream_session_attach_pipe(server_http_res & res, const std::map<std::string, std::string> & headers);
|
||||
|
||||
// should_stop closure that ignores peer disconnect when a pipe is attached, so only an explicit
|
||||
// DELETE stops the producer and generation keeps flowing into the ring buffer. without a pipe it
|
||||
// delegates to fallback, the legacy non-resumable flow
|
||||
std::function<bool()> stream_aware_should_stop(server_http_res * res, std::function<bool()> fallback);
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "server-http.h"
|
||||
#include "server-models.h"
|
||||
#include "server-cors-proxy.h"
|
||||
#include "server-stream.h"
|
||||
#include "server-tools.h"
|
||||
|
||||
#include "arg.h"
|
||||
@@ -82,6 +83,10 @@ int llama_server(int argc, char ** argv) {
|
||||
|
||||
common_init();
|
||||
|
||||
// start the stream session manager GC right after common init, before any HTTP route can
|
||||
// touch it. lifecycle is symmetric, stop_gc() runs in clean_up() before backend free
|
||||
g_stream_sessions.start_gc();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -239,6 +244,29 @@ int llama_server(int argc, char ** argv) {
|
||||
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
|
||||
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
|
||||
|
||||
// resumable streaming, the conversation_id is the session identity end to end. router and
|
||||
// child wire different handlers under the same paths: a child binds the local g_stream_sessions
|
||||
// backed factories, the router binds proxies that resolve the owning child through the
|
||||
// conv_id -> model map
|
||||
server_http_context::handler_t stream_get_h;
|
||||
server_http_context::handler_t streams_lookup_h;
|
||||
server_http_context::handler_t stream_delete_h;
|
||||
if (is_router_server) {
|
||||
stream_get_h = models_routes->router_stream_get;
|
||||
streams_lookup_h = models_routes->router_streams_lookup;
|
||||
stream_delete_h = models_routes->router_stream_delete;
|
||||
} else {
|
||||
stream_get_h = make_stream_get_handler();
|
||||
streams_lookup_h = make_streams_lookup_handler();
|
||||
stream_delete_h = make_stream_delete_handler();
|
||||
}
|
||||
ctx_http.get ("/v1/stream/:conv_id", ex_wrapper(stream_get_h));
|
||||
// POST /v1/streams/lookup with body {"conversation_ids": [...]}. you can only ask for ids
|
||||
// you already own (the WebUI passes the convs visible in its sidebar). the server never
|
||||
// lists ids it has not been asked about, so a random caller cannot enumerate live sessions
|
||||
ctx_http.post("/v1/streams/lookup", ex_wrapper(streams_lookup_h));
|
||||
ctx_http.del ("/v1/stream/:conv_id", ex_wrapper(stream_delete_h));
|
||||
|
||||
// Google Cloud Platform (Vertex AI) compat
|
||||
ctx_http.register_gcp_compat();
|
||||
|
||||
@@ -314,6 +342,8 @@ int llama_server(int argc, char ** argv) {
|
||||
|
||||
clean_up = [&models_routes]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
// stop the session GC first, it finalizes live sessions and wakes pending readers
|
||||
g_stream_sessions.stop_gc();
|
||||
if (models_routes.has_value()) {
|
||||
models_routes->stopping.store(true); // maybe redundant, but just to be safe
|
||||
models_routes->models.unload_all();
|
||||
@@ -340,6 +370,8 @@ int llama_server(int argc, char ** argv) {
|
||||
// setup clean up function, to be called before exit
|
||||
clean_up = [&ctx_http, &ctx_server]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
// stop the session GC first, it finalizes live sessions and wakes pending readers
|
||||
g_stream_sessions.stop_gc();
|
||||
ctx_http.stop();
|
||||
ctx_server.terminate();
|
||||
llama_backend_free();
|
||||
|
||||
+1
-1
@@ -33,7 +33,7 @@
|
||||
|
||||
{#if !readonly && onRemove}
|
||||
<div
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.()} />
|
||||
</div>
|
||||
|
||||
+1
-1
@@ -56,7 +56,7 @@
|
||||
<div class="relative flex h-6 items-center justify-between">
|
||||
<div class="right-0 flex items-center gap-2 opacity-100 transition-opacity">
|
||||
<div
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-hover:opacity-100"
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={Edit} tooltip="Edit" onclick={editCtx.handleEdit} />
|
||||
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
ChatMessages,
|
||||
ChatScreenDragOverlay,
|
||||
ChatScreenProcessingInfo,
|
||||
ChatScreenStreamResumeStatus,
|
||||
ServerLoadingSplash,
|
||||
ChatScreenServerError
|
||||
} from '$lib/components/app';
|
||||
@@ -281,6 +282,10 @@
|
||||
|
||||
<ChatScreenServerError />
|
||||
|
||||
{#if page.params.id}
|
||||
<ChatScreenStreamResumeStatus />
|
||||
{/if}
|
||||
|
||||
<div class="pointer-events-none flex flex-col gap-6 items-center w-full">
|
||||
{#if (isMobile.current ? mobileScrollDownHint || isMobileUserScrolledUp : autoScroll.userScrolledUp) && page.url.hash.includes(ROUTES.CHAT) && page.params.id}
|
||||
<ChatScreenActionScrollDown
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
<script lang="ts">
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { StreamConnectionState } from '$lib/enums';
|
||||
import { Loader2 } from '@lucide/svelte';
|
||||
|
||||
let state = $derived(chatStore.streamConnectionState);
|
||||
</script>
|
||||
|
||||
{#if state === StreamConnectionState.RESUMING}
|
||||
<div
|
||||
class="pointer-events-auto mx-auto mt-2 mb-2 flex max-w-[48rem] items-center gap-2 rounded-md border border-blue-400/40 bg-blue-50/60 px-3 py-1.5 text-sm text-blue-700 dark:bg-blue-950/40 dark:text-blue-200"
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
>
|
||||
<Loader2 class="h-3.5 w-3.5 animate-spin" />
|
||||
<span>Reconnecting to the stream...</span>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -683,3 +683,11 @@ export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProc
|
||||
* Rendered inside ChatScreen when `serverError` store has a value.
|
||||
*/
|
||||
export { default as ChatScreenServerError } from './ChatScreen/ChatScreenServerError.svelte';
|
||||
|
||||
/**
|
||||
* Stream resume status indicator. Shows a small "Reconnecting to the stream..."
|
||||
* banner with a spinner while `chatStore.streamConnectionState` is `resuming`,
|
||||
* i.e. after a dropped connection is reattaching to the live SSE replay buffer.
|
||||
* Renders nothing otherwise. Shown inside ChatScreen only on an active conversation route.
|
||||
*/
|
||||
export { default as ChatScreenStreamResumeStatus } from './ChatScreen/ChatScreenStreamResumeStatus.svelte';
|
||||
|
||||
+56
-81
@@ -39,7 +39,6 @@
|
||||
depth = 0
|
||||
}: Props = $props();
|
||||
|
||||
let renderActionsDropdown = $state(false);
|
||||
let dropdownOpen = $state(false);
|
||||
|
||||
let isLoading = $derived(getAllLoadingChats().includes(conversation.id));
|
||||
@@ -71,26 +70,10 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseLeave() {
|
||||
if (!dropdownOpen) {
|
||||
renderActionsDropdown = false;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseOver() {
|
||||
renderActionsDropdown = true;
|
||||
}
|
||||
|
||||
function handleSelect() {
|
||||
onSelect?.(conversation.id);
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!dropdownOpen) {
|
||||
renderActionsDropdown = false;
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
document.addEventListener('edit-active-conversation', handleGlobalEditEvent as EventListener);
|
||||
|
||||
@@ -103,23 +86,19 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<!-- svelte-ignore a11y_mouse_events_have_key_events -->
|
||||
<button
|
||||
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
|
||||
<div
|
||||
class="conversation-item group relative flex min-h-9 w-full items-center justify-between space-x-3 rounded-lg py-1.5 transition-colors hover:bg-foreground/10 {isActive
|
||||
? 'bg-foreground/5 text-accent-foreground'
|
||||
: ''} px-3"
|
||||
onclick={handleSelect}
|
||||
onmouseover={handleMouseOver}
|
||||
onmouseleave={handleMouseLeave}
|
||||
onfocusin={handleMouseOver}
|
||||
onfocusout={(e) => {
|
||||
if (!e.currentTarget.contains(e.relatedTarget as Node | null)) {
|
||||
handleMouseLeave();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<button
|
||||
class="absolute inset-0 z-0 cursor-pointer rounded-lg focus:outline-none focus-visible:ring-2 focus-visible:ring-ring"
|
||||
onclick={handleSelect}
|
||||
aria-label={conversation.name}
|
||||
>
|
||||
</button>
|
||||
<div
|
||||
class="flex min-w-0 flex-1 items-center gap-2"
|
||||
class="pointer-events-none relative z-10 flex min-w-0 flex-1 items-center gap-2"
|
||||
style:padding-left="{depth * FORK_TREE_DEPTH_PADDING}px"
|
||||
>
|
||||
{#if depth > 0}
|
||||
@@ -130,7 +109,7 @@
|
||||
<a
|
||||
{...props}
|
||||
href={RouterService.chat(conversation.forkedFromConversationId)}
|
||||
class="flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
class="pointer-events-auto flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
>
|
||||
<GitBranch class="h-3.5 w-3.5" />
|
||||
</a>
|
||||
@@ -146,18 +125,15 @@
|
||||
{#if isLoading}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<div
|
||||
class="stop-button flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
<button
|
||||
class="stop-button pointer-events-auto flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
onclick={handleStop}
|
||||
onkeydown={(e) => e.key === 'Enter' && handleStop(e)}
|
||||
role="button"
|
||||
tabindex="0"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<Loader2 class="loading-icon h-3.5 w-3.5 animate-spin" />
|
||||
|
||||
<Square class="stop-icon hidden h-3 w-3 fill-current text-destructive" />
|
||||
</div>
|
||||
</button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
@@ -169,52 +145,50 @@
|
||||
<TruncatedText text={conversation.name} class="text-sm font-medium" showTooltip={false} />
|
||||
</div>
|
||||
|
||||
{#if renderActionsDropdown}
|
||||
<div class="actions flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
},
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
<div class="actions pointer-events-auto relative z-20 flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
},
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
button {
|
||||
.conversation-item {
|
||||
:global([data-slot='dropdown-menu-trigger']:not([data-state='open'])) {
|
||||
opacity: 0;
|
||||
}
|
||||
@@ -239,7 +213,8 @@
|
||||
}
|
||||
}
|
||||
|
||||
&:is(:hover) .stop-button {
|
||||
&:is(:hover) .stop-button,
|
||||
&:focus-within .stop-button {
|
||||
:global(.stop-icon) {
|
||||
display: block;
|
||||
}
|
||||
|
||||
@@ -21,5 +21,11 @@ export const API_TOOLS = {
|
||||
EXECUTE: '/tools'
|
||||
};
|
||||
|
||||
// resumable stream routes, the conv::model identity is appended as a path segment
|
||||
export const API_STREAM = {
|
||||
BASE: './v1/stream',
|
||||
LOOKUP: './v1/streams/lookup'
|
||||
};
|
||||
|
||||
/** CORS proxy endpoint path */
|
||||
export const CORS_PROXY_ENDPOINT = '/cors-proxy';
|
||||
|
||||
@@ -46,6 +46,7 @@ export * from './routes';
|
||||
export * from './sandbox';
|
||||
export * from './settings-keys';
|
||||
export * from './settings-registry';
|
||||
export * from './stream';
|
||||
export * from './supported-file-types';
|
||||
export * from './table-html-restorer';
|
||||
export * from './title-generation';
|
||||
|
||||
@@ -26,6 +26,9 @@ export const THINKING_ENABLED_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.th
|
||||
export const REASONING_EFFORT_DEFAULT_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.reasoningEffortDefault`;
|
||||
export const USER_OVERRIDES_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME}.userOverrides`;
|
||||
|
||||
/** Key prefix for per-conversation resumable stream state, conversationId is appended */
|
||||
export const STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX = `${STORAGE_APP_NAME}.streamResume.`;
|
||||
|
||||
// Deprecated old key names (kept for backward compat while users migrate)
|
||||
/** @deprecated Use {@link ALWAYS_ALLOWED_TOOLS_LOCALSTORAGE_KEY} instead */
|
||||
export const DEPRECATED_ALWAYS_ALLOWED_TOOLS_LOCALSTORAGE_KEY = `${STORAGE_APP_NAME_DEPRECATED}.alwaysAllowedTools`;
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
// grace window after a visibilitychange before we kick a reader whose socket likely died
|
||||
// while the tab was hidden. covers brief background pauses without thrashing live streams
|
||||
export const STREAM_VISIBILITY_KICK_MS = 1000;
|
||||
@@ -5,6 +5,15 @@ export enum ChatMessageStatsView {
|
||||
SUMMARY = 'summary'
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection state of a streamed completion, drives the resume status indicator.
|
||||
*/
|
||||
export enum StreamConnectionState {
|
||||
STREAMING = 'streaming',
|
||||
RESUMING = 'resuming',
|
||||
LOST = 'lost'
|
||||
}
|
||||
|
||||
/**
|
||||
* Reasoning format options for API requests.
|
||||
*/
|
||||
|
||||
@@ -10,6 +10,7 @@ export { AgenticSectionType, ContinueIntentKind, ToolCallType } from './agentic.
|
||||
|
||||
export {
|
||||
ChatMessageStatsView,
|
||||
StreamConnectionState,
|
||||
ContentPartType,
|
||||
ConversationSelectionMode,
|
||||
ErrorDialogType,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { getJsonHeaders } from '$lib/utils/api-headers';
|
||||
import { getAuthHeaders, getJsonHeaders } from '$lib/utils/api-headers';
|
||||
import { formatAttachmentText } from '$lib/utils/formatters';
|
||||
import { isAbortError } from '$lib/utils/abort';
|
||||
import { streamIdentity } from '$lib/utils/stream-identity';
|
||||
import {
|
||||
ATTACHMENT_LABEL_PDF_FILE,
|
||||
ATTACHMENT_LABEL_MCP_PROMPT,
|
||||
@@ -13,7 +14,10 @@ import {
|
||||
CONTROL_ACTION,
|
||||
SSE_LINE_SEPARATOR,
|
||||
SSE_DATA_PREFIX,
|
||||
SSE_DONE_MARKER
|
||||
SSE_DONE_MARKER,
|
||||
STREAM_VISIBILITY_KICK_MS,
|
||||
STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX,
|
||||
API_STREAM
|
||||
} from '$lib/constants';
|
||||
import {
|
||||
AttachmentType,
|
||||
@@ -21,12 +25,14 @@ import {
|
||||
FileTypeAudio,
|
||||
MessageRole,
|
||||
MimeTypeAudio,
|
||||
ReasoningFormat
|
||||
ReasoningFormat,
|
||||
StreamConnectionState
|
||||
} from '$lib/enums';
|
||||
import type {
|
||||
ApiChatMessageContentPart,
|
||||
ApiChatMessageData,
|
||||
ApiChatCompletionToolCall
|
||||
ApiChatCompletionToolCall,
|
||||
ApiStreamSession
|
||||
} from '$lib/types/api';
|
||||
import type {
|
||||
AudioInputFormat,
|
||||
@@ -54,6 +60,19 @@ function getAudioInputFormat(mimeType: string): AudioInputFormat {
|
||||
return FileTypeAudio.MP3;
|
||||
}
|
||||
|
||||
interface ResumableStreamState {
|
||||
bytesReceived: number;
|
||||
updatedAt: number;
|
||||
|
||||
// model frozen at POST time, lets a reload rebuild the exact conv::model identity the
|
||||
// server keyed the session under. null when the POST carried no explicit model
|
||||
model?: string | null;
|
||||
}
|
||||
|
||||
function streamStorageKey(conversationId: string): string {
|
||||
return STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX + conversationId;
|
||||
}
|
||||
|
||||
export class ChatService {
|
||||
/**
|
||||
*
|
||||
@@ -128,6 +147,7 @@ export class ChatService {
|
||||
onChunk,
|
||||
onComplete,
|
||||
onError,
|
||||
onConnectionState,
|
||||
onReasoningChunk,
|
||||
onToolCallChunk,
|
||||
onModel,
|
||||
@@ -312,9 +332,16 @@ export class ChatService {
|
||||
}
|
||||
|
||||
try {
|
||||
const headers: Record<string, string> = { ...getJsonHeaders() };
|
||||
// tag streaming requests with the conversation id, this single header is the opt in for the
|
||||
// server side replay buffer and powers discoverActiveStream on tab reopen. with an explicit
|
||||
// model the ::model suffix keeps the per model session distinct
|
||||
if (stream && conversationId) {
|
||||
headers['X-Conversation-Id'] = streamIdentity(conversationId, options.model);
|
||||
}
|
||||
const response = await fetch(API_CHAT.COMPLETIONS, {
|
||||
method: 'POST',
|
||||
headers: getJsonHeaders(),
|
||||
headers,
|
||||
body: JSON.stringify(requestBody),
|
||||
signal
|
||||
});
|
||||
@@ -341,7 +368,9 @@ export class ChatService {
|
||||
onCompletionId,
|
||||
onTimings,
|
||||
conversationId,
|
||||
signal
|
||||
signal,
|
||||
onConnectionState,
|
||||
options.model
|
||||
);
|
||||
|
||||
return;
|
||||
@@ -473,6 +502,116 @@ export class ChatService {
|
||||
* @param excludeReasoning - Whether to strip reasoning content (should match excludeReasoningFromContext setting)
|
||||
* @param signal - Optional AbortSignal to cancel the pre-encode request
|
||||
*/
|
||||
static async cancelServerStream(conversationId: string, model?: string | null): Promise<void> {
|
||||
if (!conversationId) return;
|
||||
try {
|
||||
const id = streamIdentity(conversationId, model);
|
||||
await fetch(`${API_STREAM.BASE}/${encodeURIComponent(id)}`, {
|
||||
method: 'DELETE',
|
||||
headers: getAuthHeaders()
|
||||
});
|
||||
} catch (e) {
|
||||
console.warn('cancelServerStream failed:', e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pick the running session to splice into when discoverActiveStream lists candidates for a
|
||||
* conversation. Finalized sessions are not candidates: their final content was already written
|
||||
* to the DB by the original onComplete handler, so attaching to them would replay a buffer that
|
||||
* may not match what the DB holds. A continue session's buffer holds only the appended deltas,
|
||||
* not the pre continue prefix, so replaying it as a fresh generation would erase the original.
|
||||
*
|
||||
* Among running sessions we tie break on the most recent started_at, which covers the case of
|
||||
* multiple inferences left running on the same conversation.
|
||||
*/
|
||||
static selectActiveStream(
|
||||
sessions: ApiStreamSession[] | null | undefined
|
||||
): ApiStreamSession | null {
|
||||
if (!Array.isArray(sessions) || sessions.length === 0) {
|
||||
return null;
|
||||
}
|
||||
const running = sessions.filter((s) => !s.is_done);
|
||||
if (running.length === 0) {
|
||||
return null;
|
||||
}
|
||||
return running.reduce((best, cur) => (cur.started_at > best.started_at ? cur : best));
|
||||
}
|
||||
|
||||
// persist the running byte count and the frozen model for a conversation, a later visit
|
||||
// resumes the SSE replay at the right offset under the same conv::model identity
|
||||
static saveStreamState(
|
||||
conversationId: string,
|
||||
bytesReceived: number,
|
||||
model?: string | null
|
||||
): void {
|
||||
if (!conversationId) return;
|
||||
try {
|
||||
const state: ResumableStreamState = {
|
||||
bytesReceived,
|
||||
updatedAt: Date.now(),
|
||||
model: model ?? null
|
||||
};
|
||||
localStorage.setItem(streamStorageKey(conversationId), JSON.stringify(state));
|
||||
} catch {
|
||||
// localStorage may be full or disabled, silently ignore
|
||||
}
|
||||
}
|
||||
|
||||
static getStreamState(conversationId: string): ResumableStreamState | null {
|
||||
if (!conversationId) return null;
|
||||
try {
|
||||
const raw = localStorage.getItem(streamStorageKey(conversationId));
|
||||
if (!raw) return null;
|
||||
const parsed = JSON.parse(raw) as ResumableStreamState;
|
||||
if (!parsed || typeof parsed.bytesReceived !== 'number') return null;
|
||||
return parsed;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
static clearStreamState(conversationId: string): void {
|
||||
if (!conversationId) return;
|
||||
try {
|
||||
localStorage.removeItem(streamStorageKey(conversationId));
|
||||
} catch {
|
||||
// nothing to do
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Rebuild the stream identity for a resume. The model persisted at POST time wins, including a
|
||||
* stored null which means the POST carried no explicit model so the identity stays the bare conv
|
||||
* id. Only fall back to the caller supplied current model when nothing was persisted.
|
||||
*/
|
||||
static resumeStreamIdentity(
|
||||
conversationId: string,
|
||||
state: ResumableStreamState | null,
|
||||
fallbackModel: string | null
|
||||
): string {
|
||||
const model = state && state.model !== undefined ? state.model : fallbackModel;
|
||||
return streamIdentity(conversationId, model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconnect to an interrupted stream for this conversation. Returns the fetch Response so the
|
||||
* existing SSE parser drains it like a fresh stream. The server returns 200 on success, 404 if
|
||||
* no session exists for the conv_id, and 400 if the offset is below the dropped prefix.
|
||||
*/
|
||||
static async resumeStream(
|
||||
conversationId: string,
|
||||
signal?: AbortSignal,
|
||||
model?: string | null
|
||||
): Promise<Response | null> {
|
||||
if (!conversationId) return null;
|
||||
const state = ChatService.getStreamState(conversationId);
|
||||
const from = state?.bytesReceived ?? 0;
|
||||
const id = streamIdentity(conversationId, model);
|
||||
const url = `${API_STREAM.BASE}/${encodeURIComponent(id)}?from=${from}`;
|
||||
return await fetch(url, { method: 'GET', signal, headers: getAuthHeaders() });
|
||||
}
|
||||
|
||||
static async preEncode(
|
||||
messages: ApiChatMessageData[] | (DatabaseMessage & { extra?: DatabaseMessageExtra[] })[],
|
||||
model?: string | null,
|
||||
@@ -557,7 +696,7 @@ export class ChatService {
|
||||
* @returns {Promise<void>} Promise that resolves when streaming is complete
|
||||
* @throws {Error} if the stream cannot be read or parsed
|
||||
*/
|
||||
private static async handleStreamResponse(
|
||||
static async handleStreamResponse(
|
||||
response: Response,
|
||||
onChunk?: (chunk: string) => void,
|
||||
onComplete?: (
|
||||
@@ -573,15 +712,34 @@ export class ChatService {
|
||||
onCompletionId?: (id: string) => void,
|
||||
onTimings?: (timings?: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void,
|
||||
conversationId?: string,
|
||||
abortSignal?: AbortSignal
|
||||
abortSignal?: AbortSignal,
|
||||
onConnectionState?: (state: StreamConnectionState) => void,
|
||||
streamModel?: string | null
|
||||
): Promise<void> {
|
||||
const reader = response.body?.getReader();
|
||||
let reader = response.body?.getReader();
|
||||
|
||||
if (!reader) {
|
||||
throw new Error('No response body');
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
// bytesParsed is the absolute server side buffer offset of the next byte to parse
|
||||
// segmentStartOffset is the absolute offset where the current reader started, reset on resume
|
||||
// segmentBytesRead is wire bytes read by the current reader
|
||||
let bytesParsed = 0;
|
||||
let segmentStartOffset = 0;
|
||||
let segmentBytesRead = 0;
|
||||
let lastByteAt = Date.now();
|
||||
// each resume must produce at least one byte to be retried again
|
||||
// if a resume returns 200 but yields nothing, we abandon
|
||||
// since the session has a bounded size, the total number of retries is bounded by construction
|
||||
let madeProgress = true;
|
||||
const encoder = new TextEncoder();
|
||||
if (conversationId) {
|
||||
ChatService.saveStreamState(conversationId, 0, streamModel);
|
||||
}
|
||||
onConnectionState?.(StreamConnectionState.STREAMING);
|
||||
|
||||
let decoder = new TextDecoder();
|
||||
let aggregatedContent = '';
|
||||
let fullReasoningContent = '';
|
||||
let aggregatedToolCalls: ApiChatCompletionToolCall[] = [];
|
||||
@@ -633,84 +791,180 @@ export class ChatService {
|
||||
}
|
||||
};
|
||||
|
||||
const onVisibilityChange = () => {
|
||||
if (typeof document === 'undefined') return;
|
||||
if (document.visibilityState !== 'visible') return;
|
||||
if (streamFinished) return;
|
||||
if (!conversationId) return;
|
||||
// the bytes have been quiet for too long, the OS likely killed the socket
|
||||
// kicking the reader unblocks reader.read with done=true so the outer loop can resume
|
||||
if (Date.now() - lastByteAt > STREAM_VISIBILITY_KICK_MS) {
|
||||
reader!.cancel().catch(() => {});
|
||||
}
|
||||
};
|
||||
if (typeof document !== 'undefined') {
|
||||
document.addEventListener('visibilitychange', onVisibilityChange);
|
||||
}
|
||||
|
||||
try {
|
||||
let chunk = '';
|
||||
// outer loop drives the resume cycle, swaps reader on premature end of stream
|
||||
while (true) {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
chunk += decoder.decode(value, { stream: true });
|
||||
const lines = chunk.split(SSE_LINE_SEPARATOR);
|
||||
chunk = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
while (true) {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
if (line.startsWith(SSE_DATA_PREFIX)) {
|
||||
const data = line.slice(SSE_DATA_PREFIX.length).trim();
|
||||
if (data === SSE_DONE_MARKER) {
|
||||
streamFinished = true;
|
||||
|
||||
continue;
|
||||
let done: boolean;
|
||||
let value: Uint8Array | undefined;
|
||||
try {
|
||||
const r = await reader.read();
|
||||
done = r.done;
|
||||
value = r.value;
|
||||
} catch (readErr) {
|
||||
// reader.read() rejects with TypeError when the underlying connection drops
|
||||
// instead of just resolving with done=true. treat it like done so the outer
|
||||
// loop swaps reader via the resume path
|
||||
if (isAbortError(readErr)) {
|
||||
throw readErr;
|
||||
}
|
||||
console.warn('reader.read() rejected, treating as premature end:', readErr);
|
||||
done = true;
|
||||
value = undefined;
|
||||
}
|
||||
if (done) break;
|
||||
|
||||
try {
|
||||
const parsed: ApiChatCompletionStreamChunk = JSON.parse(data);
|
||||
const choice = parsed.choices?.[0];
|
||||
const content = choice?.delta?.content;
|
||||
const reasoningContent = choice?.delta?.reasoning_content;
|
||||
const toolCalls = choice?.delta?.tool_calls;
|
||||
const timings = parsed.timings;
|
||||
const promptProgress = parsed.prompt_progress;
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
const chunkModel = ChatService.extractModelName(parsed);
|
||||
if (chunkModel && !modelEmitted) {
|
||||
modelEmitted = true;
|
||||
onModel?.(chunkModel);
|
||||
}
|
||||
|
||||
if (parsed.id && !idEmitted) {
|
||||
idEmitted = true;
|
||||
onCompletionId?.(parsed.id);
|
||||
}
|
||||
|
||||
if (promptProgress) {
|
||||
ChatService.notifyTimings(undefined, promptProgress, onTimings);
|
||||
}
|
||||
|
||||
if (timings) {
|
||||
ChatService.notifyTimings(timings, promptProgress, onTimings);
|
||||
lastTimings = timings;
|
||||
}
|
||||
|
||||
if (content) {
|
||||
finalizeOpenToolCallBatch();
|
||||
aggregatedContent += content;
|
||||
if (!abortSignal?.aborted) {
|
||||
onChunk?.(content);
|
||||
}
|
||||
}
|
||||
|
||||
if (reasoningContent) {
|
||||
finalizeOpenToolCallBatch();
|
||||
fullReasoningContent += reasoningContent;
|
||||
if (!abortSignal?.aborted) {
|
||||
onReasoningChunk?.(reasoningContent);
|
||||
}
|
||||
}
|
||||
|
||||
processToolCallDelta(toolCalls);
|
||||
} catch (e) {
|
||||
console.error('Error parsing JSON chunk:', e);
|
||||
if (value && value.byteLength > 0) {
|
||||
segmentBytesRead += value.byteLength;
|
||||
lastByteAt = Date.now();
|
||||
if (!madeProgress) {
|
||||
madeProgress = true;
|
||||
onConnectionState?.(StreamConnectionState.STREAMING);
|
||||
}
|
||||
}
|
||||
|
||||
chunk += decoder.decode(value, { stream: true });
|
||||
const lines = chunk.split(SSE_LINE_SEPARATOR);
|
||||
chunk = lines.pop() || '';
|
||||
|
||||
// the persisted offset must point right after the last fully parsed line,
|
||||
// the trailing `chunk` is partial bytes still waiting for a newline
|
||||
if (conversationId) {
|
||||
const tailBytes = encoder.encode(chunk).byteLength;
|
||||
bytesParsed = segmentStartOffset + segmentBytesRead - tailBytes;
|
||||
ChatService.saveStreamState(conversationId, bytesParsed, streamModel);
|
||||
}
|
||||
|
||||
for (const line of lines) {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
if (line.startsWith(SSE_DATA_PREFIX)) {
|
||||
const data = line.slice(SSE_DATA_PREFIX.length).trim();
|
||||
if (data === SSE_DONE_MARKER) {
|
||||
streamFinished = true;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed: ApiChatCompletionStreamChunk = JSON.parse(data);
|
||||
const choice = parsed.choices?.[0];
|
||||
const content = choice?.delta?.content;
|
||||
const reasoningContent = choice?.delta?.reasoning_content;
|
||||
const toolCalls = choice?.delta?.tool_calls;
|
||||
const timings = parsed.timings;
|
||||
const promptProgress = parsed.prompt_progress;
|
||||
|
||||
const chunkModel = ChatService.extractModelName(parsed);
|
||||
if (chunkModel && !modelEmitted) {
|
||||
modelEmitted = true;
|
||||
onModel?.(chunkModel);
|
||||
}
|
||||
|
||||
if (parsed.id && !idEmitted) {
|
||||
idEmitted = true;
|
||||
onCompletionId?.(parsed.id);
|
||||
}
|
||||
|
||||
if (promptProgress) {
|
||||
ChatService.notifyTimings(undefined, promptProgress, onTimings);
|
||||
}
|
||||
|
||||
if (timings) {
|
||||
ChatService.notifyTimings(timings, promptProgress, onTimings);
|
||||
lastTimings = timings;
|
||||
}
|
||||
|
||||
if (content) {
|
||||
finalizeOpenToolCallBatch();
|
||||
aggregatedContent += content;
|
||||
if (!abortSignal?.aborted) {
|
||||
onChunk?.(content);
|
||||
}
|
||||
}
|
||||
|
||||
if (reasoningContent) {
|
||||
finalizeOpenToolCallBatch();
|
||||
fullReasoningContent += reasoningContent;
|
||||
if (!abortSignal?.aborted) {
|
||||
onReasoningChunk?.(reasoningContent);
|
||||
}
|
||||
}
|
||||
|
||||
processToolCallDelta(toolCalls);
|
||||
} catch (e) {
|
||||
console.error('Error parsing JSON chunk:', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (abortSignal?.aborted) break;
|
||||
if (streamFinished) break;
|
||||
}
|
||||
|
||||
// inner reader done, decide whether to try a resume
|
||||
if (abortSignal?.aborted) break;
|
||||
if (streamFinished) break;
|
||||
if (!conversationId) break;
|
||||
|
||||
if (!madeProgress) {
|
||||
onConnectionState?.(StreamConnectionState.LOST);
|
||||
onError?.(new Error('Stream resume produced no new bytes, giving up'));
|
||||
break;
|
||||
}
|
||||
|
||||
onConnectionState?.(StreamConnectionState.RESUMING);
|
||||
madeProgress = false;
|
||||
|
||||
// the server resends starting at bytesParsed, discard any partial line we held, it
|
||||
// will be retransmitted from a clean line boundary. reuse the frozen model, not the
|
||||
// live dropdown
|
||||
const resumeResp = await ChatService.resumeStream(
|
||||
conversationId,
|
||||
abortSignal,
|
||||
streamModel
|
||||
).catch(() => null);
|
||||
// an abort landing during the resume request is intentional, not a lost connection
|
||||
if (abortSignal?.aborted) break;
|
||||
if (!resumeResp || resumeResp.status !== 200) {
|
||||
onConnectionState?.(StreamConnectionState.LOST);
|
||||
onError?.(new Error('Stream connection lost and could not be resumed'));
|
||||
break;
|
||||
}
|
||||
const newReader = resumeResp.body?.getReader();
|
||||
if (!newReader) break;
|
||||
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
reader = newReader;
|
||||
decoder = new TextDecoder();
|
||||
chunk = '';
|
||||
segmentStartOffset = bytesParsed;
|
||||
segmentBytesRead = 0;
|
||||
lastByteAt = Date.now();
|
||||
}
|
||||
|
||||
if (abortSignal?.aborted) return;
|
||||
@@ -718,6 +972,10 @@ export class ChatService {
|
||||
if (streamFinished) {
|
||||
finalizeOpenToolCallBatch();
|
||||
|
||||
if (conversationId) {
|
||||
ChatService.clearStreamState(conversationId);
|
||||
}
|
||||
|
||||
const finalToolCalls =
|
||||
aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined;
|
||||
|
||||
@@ -735,7 +993,14 @@ export class ChatService {
|
||||
|
||||
throw err;
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
if (typeof document !== 'undefined') {
|
||||
document.removeEventListener('visibilitychange', onVisibilityChange);
|
||||
}
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -628,19 +628,20 @@ export class MCPService {
|
||||
);
|
||||
|
||||
const runtimeErrorHandler = (error: Error) => {
|
||||
// Ignore errors that are expected when the SDK's transport is closed,
|
||||
// or when connecting to servers that don't support SSE (stateless-only
|
||||
// endpoints returning 405). The SDK wraps the original AbortError in
|
||||
// a new Error with the message "SSE stream disconnected: AbortError",
|
||||
// and also produces "Cannot cancel a stream locked by a reader".
|
||||
// DOMException is thrown by the browser when aborting fetch requests.
|
||||
const msg = error.message || String(error);
|
||||
// the SDK reports any post initialize error here, including the abort we trigger
|
||||
// ourselves on the next health check cycle, on tab unload, or on server teardown.
|
||||
// these are lifecycle aborts, not actionable errors, so we keep them out of the red console.
|
||||
// the SDK wraps the original AbortError in a generic Error like
|
||||
// "SSE stream disconnected: AbortError: The operation was aborted."
|
||||
// which isAbortError cannot recognize by name alone, so we also pattern match on the message
|
||||
if (isAbortError(error)) {
|
||||
return;
|
||||
}
|
||||
const msg = error?.message ?? '';
|
||||
if (
|
||||
error.name === 'AbortError' ||
|
||||
error instanceof DOMException ||
|
||||
msg.includes('SSE stream disconnected') ||
|
||||
msg.includes('stream locked by a reader') ||
|
||||
msg.includes('The operation was aborted')
|
||||
/SSE stream disconnected:.*AbortError/i.test(msg) ||
|
||||
/AbortError: .*aborted/i.test(msg) ||
|
||||
/stream locked by a reader/i.test(msg)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -614,7 +614,7 @@ class AgenticStore {
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
undefined,
|
||||
conversationId,
|
||||
signal
|
||||
);
|
||||
|
||||
|
||||
@@ -11,9 +11,11 @@
|
||||
* @see ChatService in services/chat.service.ts for API operations
|
||||
*/
|
||||
|
||||
import { SvelteMap } from 'svelte/reactivity';
|
||||
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
|
||||
import { DatabaseService } from '$lib/services/database.service';
|
||||
import { ChatService } from '$lib/services/chat.service';
|
||||
import { streamIdentity } from '$lib/utils/stream-identity';
|
||||
import { getAuthHeaders } from '$lib/utils/api-headers';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { agenticStore } from '$lib/stores/agentic.svelte';
|
||||
@@ -49,10 +51,17 @@ import type {
|
||||
import type {
|
||||
ApiChatMessageData,
|
||||
ApiProcessingState,
|
||||
ApiStreamSession,
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra
|
||||
} from '$lib/types';
|
||||
import { ContinueIntentKind, ErrorDialogType, MessageRole, MessageType } from '$lib/enums';
|
||||
import {
|
||||
ContinueIntentKind,
|
||||
ErrorDialogType,
|
||||
MessageRole,
|
||||
MessageType,
|
||||
StreamConnectionState
|
||||
} from '$lib/enums';
|
||||
|
||||
interface ConversationStateEntry {
|
||||
lastAccessed: number;
|
||||
@@ -65,9 +74,25 @@ class ChatStore {
|
||||
isLoading = $state(false);
|
||||
// true while the active conversation streams reasoning content but no visible content yet
|
||||
isReasoning = $state(false);
|
||||
// resumable stream connection state for the active conversation
|
||||
// streaming -> bytes flowing normally, resuming -> waiting on /v1/stream/:id reconnect, lost -> unrecoverable
|
||||
streamConnectionState = $state<StreamConnectionState>(StreamConnectionState.STREAMING);
|
||||
chatLoadingStates = new SvelteMap<string, boolean>();
|
||||
chatReasoningStates = new SvelteMap<string, boolean>();
|
||||
chatStreamingStates = new SvelteMap<string, { response: string; messageId: string }>();
|
||||
chatStreamingStates = new SvelteMap<
|
||||
string,
|
||||
{ response: string; messageId: string; model?: string | null }
|
||||
>();
|
||||
// convs that the backend reports as having a running session, populated by the global sync
|
||||
// at app mount and on visibilitychange. it does not overlap with chatLoadingStates which
|
||||
// tracks inferences driven by this browser, both are unioned to feed the sidebar spinners
|
||||
private remoteRunningConvs = new SvelteSet<string>();
|
||||
// per conv attach lifecycle, used to derive the global streaming flag without flipping it
|
||||
// off when one conv finishes while another is still streaming. mirrors chatLoadingStates
|
||||
// in scope but tracks the attach + tee replay path specifically
|
||||
private attachingConvs = new SvelteSet<string>();
|
||||
// in-flight discoverActiveStream guard, keyed by conv id
|
||||
private discoveringConvs = new SvelteSet<string>();
|
||||
private abortControllers = new SvelteMap<string, AbortController>();
|
||||
private preEncodeAbortController: AbortController | null = null;
|
||||
private processingStates = new SvelteMap<string, ApiProcessingState | null>();
|
||||
@@ -98,6 +123,11 @@ class ChatStore {
|
||||
this.chatLoadingStates.delete(convId);
|
||||
if (convId === conversationsStore.activeConversation?.id) this.isLoading = false;
|
||||
this.setChatReasoning(convId, false);
|
||||
// the local pipe is the authoritative observer of session end: when it finishes (clean
|
||||
// onComplete or explicit Stop), the backend session is finalized too, so we drop the
|
||||
// sidebar hint for this conv right away instead of waiting for the next visibilitychange
|
||||
// snapshot. without this the spinner ghosts until the user toggles the tab
|
||||
this.remoteRunningConvs.delete(convId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,9 +140,18 @@ class ChatStore {
|
||||
if (convId === conversationsStore.activeConversation?.id) this.isReasoning = false;
|
||||
}
|
||||
}
|
||||
private setChatStreaming(convId: string, response: string, messageId: string): void {
|
||||
private setChatStreaming(
|
||||
convId: string,
|
||||
response: string,
|
||||
messageId: string,
|
||||
model?: string | null
|
||||
): void {
|
||||
this.touchConversationState(convId);
|
||||
this.chatStreamingStates.set(convId, { response, messageId });
|
||||
this.chatStreamingStates.set(convId, {
|
||||
response,
|
||||
messageId,
|
||||
model: model ?? this.chatStreamingStates.get(convId)?.model
|
||||
});
|
||||
if (convId === conversationsStore.activeConversation?.id) this.currentResponse = response;
|
||||
}
|
||||
private clearChatStreaming(convId: string): void {
|
||||
@@ -137,6 +176,314 @@ class ChatStore {
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Server side stream discovery, split in three pieces:
|
||||
*
|
||||
* probeServerStream(convId) -> hits POST /v1/streams/lookup with the conv id, returns the session to attach
|
||||
* to or null. Pure read, no side effect, no UI lock. Safe to fire in parallel with anything.
|
||||
*
|
||||
* attachServerStream(convId) -> flips the spinner immediately, fetches the replay stream
|
||||
* from byte 0, finds the assistant slot to splice into (creates a placeholder if the conv has
|
||||
* no assistant message yet, for cross device or fresh local DB cases), and pipes the SSE bytes
|
||||
* into the message via handleStreamResponse.
|
||||
*
|
||||
* discoverActiveStream(convId) -> probe + attach in one call. Used by callers that do not need
|
||||
* to overlap the probe with other async work.
|
||||
*
|
||||
* The mount of the chat page in +page.svelte calls probeServerStream in parallel with
|
||||
* loadConversation, then attachServerStream once both have settled. This gives the earliest
|
||||
* possible time to spinner and avoids racing against an empty activeMessages array.
|
||||
*/
|
||||
async probeServerStream(convId: string): Promise<ApiStreamSession | null> {
|
||||
if (!convId) return null;
|
||||
let listResp: Response;
|
||||
try {
|
||||
// POST the one conv id we are probing
|
||||
listResp = await fetch(`./v1/streams/lookup`, {
|
||||
method: 'POST',
|
||||
headers: { ...getAuthHeaders(), 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ conversation_ids: [convId] })
|
||||
});
|
||||
} catch (e) {
|
||||
console.warn('probeServerStream fetch failed:', e);
|
||||
return null;
|
||||
}
|
||||
if (!listResp.ok) {
|
||||
console.warn(`probeServerStream got HTTP ${listResp.status} for conv ${convId}`);
|
||||
return null;
|
||||
}
|
||||
let sessions: ApiStreamSession[];
|
||||
try {
|
||||
sessions = (await listResp.json()) as ApiStreamSession[];
|
||||
} catch (e) {
|
||||
console.warn('probeServerStream JSON parse failed:', e);
|
||||
return null;
|
||||
}
|
||||
return ChatService.selectActiveStream(sessions);
|
||||
}
|
||||
|
||||
async attachServerStream(convId: string, streamId?: string): Promise<void> {
|
||||
if (!convId) return;
|
||||
if (this.chatStreamingStates.has(convId)) return;
|
||||
|
||||
// flip the spinner immediately, the user sees activity as soon as the conv becomes active.
|
||||
// the global isStreamingActive flag is derived from attachingConvs.size, so adding here
|
||||
// turns it on, and removing in unlock only turns it off when this is the last attach
|
||||
this.setChatLoading(convId, true);
|
||||
this.attachingConvs.add(convId);
|
||||
this.setStreamingActive(true);
|
||||
// only set the active processing conv if we are looking at it, otherwise a background
|
||||
// attach would steal the indicator from the conv the user is currently viewing
|
||||
if (convId === conversationsStore.activeConversation?.id) {
|
||||
this.setActiveProcessingConversation(convId);
|
||||
}
|
||||
|
||||
const unlock = () => {
|
||||
this.attachingConvs.delete(convId);
|
||||
// flip the global flag off only when no other conv is still attaching
|
||||
if (this.attachingConvs.size === 0) {
|
||||
this.setStreamingActive(false);
|
||||
}
|
||||
this.setChatLoading(convId, false);
|
||||
this.clearChatStreaming(convId);
|
||||
};
|
||||
|
||||
// fetch the replay stream from byte 0, rebuild the assistant message from scratch.
|
||||
// resolve the server side identity, fall back to streamIdentity when the caller does not
|
||||
// pass a streamId. probeServerStream returns the full id (with ::model suffix when present)
|
||||
const id = streamId || streamIdentity(convId, selectedModelName());
|
||||
let response: Response;
|
||||
try {
|
||||
response = await fetch(`./v1/stream/${encodeURIComponent(id)}?from=0`, {
|
||||
headers: getAuthHeaders()
|
||||
});
|
||||
} catch (e) {
|
||||
console.error('attachServerStream replay fetch failed:', e);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
if (!response.ok) {
|
||||
console.warn(`attachServerStream replay got HTTP ${response.status} for conv ${convId}`);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
|
||||
// load the target conversation messages by id, not via the active store. when multiple
|
||||
// attaches run in parallel the active store may reflect another conv and writing through
|
||||
// its index mixes content across convs (CoT flicker, message bleed). by going through the
|
||||
// DB we stay isolated, and only mirror into the active store when the attached conv is
|
||||
// the one currently displayed
|
||||
let messages: DatabaseMessage[];
|
||||
try {
|
||||
messages = await DatabaseService.getConversationMessages(convId);
|
||||
} catch (e) {
|
||||
console.error('attachServerStream load messages failed:', e);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
|
||||
// locate the slot to splice into, create a placeholder assistant message if there is none.
|
||||
// we use the conv-scoped findLastAssistantIdx helpers, they only depend on the array
|
||||
let targetIdx = this.findLastAssistantIdx(messages);
|
||||
if (targetIdx === -1) {
|
||||
const lastUserIdx = this.findLastUserIdx(messages);
|
||||
if (lastUserIdx === -1) {
|
||||
console.warn(
|
||||
`attachServerStream: conv ${convId} has no user or assistant message, cannot splice`
|
||||
);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const placeholder = await DatabaseService.createMessageBranch(
|
||||
{
|
||||
convId,
|
||||
role: MessageRole.ASSISTANT,
|
||||
content: '',
|
||||
type: MessageType.TEXT,
|
||||
timestamp: Date.now(),
|
||||
parent: messages[lastUserIdx].id,
|
||||
children: [],
|
||||
toolCalls: ''
|
||||
} as Omit<DatabaseMessage, 'id'>,
|
||||
messages[lastUserIdx].id
|
||||
);
|
||||
messages = [...messages, placeholder];
|
||||
targetIdx = messages.length - 1;
|
||||
// only push into the active store when this conv is the one displayed right now
|
||||
if (convId === conversationsStore.activeConversation?.id) {
|
||||
conversationsStore.addMessageToActive(placeholder);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('attachServerStream placeholder creation failed:', e);
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (targetIdx === -1) {
|
||||
unlock();
|
||||
return;
|
||||
}
|
||||
const targetMessage = messages[targetIdx];
|
||||
const targetMessageId = targetMessage.id;
|
||||
// when the assistant slot already has content, the running session is a continue or
|
||||
// another append flow and its buffer holds only the appended deltas. preserve the prefix
|
||||
// and let the replay add to it. when the slot is empty the session buffer holds the whole
|
||||
// message so we wipe and rebuild from byte 0
|
||||
const existingContent = targetMessage.content ?? '';
|
||||
const existingReasoning = targetMessage.reasoningContent ?? '';
|
||||
const isAppendMode = existingContent.length > 0;
|
||||
|
||||
// helper: write to the active store only when the attached conv is currently displayed.
|
||||
// the lookup by message id is robust to reordering of activeMessages, two parallel attaches
|
||||
// can no longer step on each other's indices
|
||||
const writeActive = (updates: Partial<DatabaseMessage>) => {
|
||||
if (convId !== conversationsStore.activeConversation?.id) {
|
||||
return;
|
||||
}
|
||||
const liveIdx = conversationsStore.findMessageIndex(targetMessageId);
|
||||
if (liveIdx === -1) return;
|
||||
conversationsStore.updateMessageAtIndex(liveIdx, updates);
|
||||
};
|
||||
|
||||
if (!isAppendMode) {
|
||||
writeActive({ content: '', reasoningContent: undefined });
|
||||
}
|
||||
|
||||
// extract the model suffix, the resume calls in handleStreamResponse must reuse the model
|
||||
// the session was tagged with, not the live dropdown
|
||||
const sepIdx = id.indexOf('::');
|
||||
const attachedModel: string | null = sepIdx === -1 ? null : id.slice(sepIdx + 2);
|
||||
this.setChatStreaming(convId, existingContent, targetMessageId, attachedModel);
|
||||
const abortController = this.getOrCreateAbortController(convId);
|
||||
|
||||
let streamedContent = '';
|
||||
let streamedReasoningContent = '';
|
||||
|
||||
const cleanup = () => {
|
||||
unlock();
|
||||
this.setProcessingState(convId, null);
|
||||
};
|
||||
|
||||
try {
|
||||
await ChatService.handleStreamResponse(
|
||||
response,
|
||||
(chunk: string) => {
|
||||
streamedContent += chunk;
|
||||
const displayed = isAppendMode ? existingContent + streamedContent : streamedContent;
|
||||
writeActive({ content: displayed });
|
||||
this.setChatStreaming(convId, displayed, targetMessageId);
|
||||
},
|
||||
async (
|
||||
finalContent?: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings,
|
||||
toolCalls?: string
|
||||
) => {
|
||||
const streamed = streamedContent || finalContent || '';
|
||||
const streamedR = streamedReasoningContent || reasoningContent || '';
|
||||
const content = isAppendMode ? existingContent + streamed : streamed;
|
||||
const reasoning = isAppendMode ? existingReasoning + streamedR : streamedR;
|
||||
// the DB write is the source of truth, mirror to the active store only when
|
||||
// the conv is currently displayed
|
||||
await DatabaseService.updateMessage(targetMessageId, {
|
||||
content,
|
||||
reasoningContent: reasoning || undefined,
|
||||
toolCalls: toolCalls || '',
|
||||
timings
|
||||
});
|
||||
writeActive({
|
||||
content,
|
||||
reasoningContent: reasoning || undefined,
|
||||
timings
|
||||
});
|
||||
cleanup();
|
||||
},
|
||||
(err: Error) => {
|
||||
console.error('attachServerStream pipe error:', err);
|
||||
cleanup();
|
||||
},
|
||||
(chunk: string) => {
|
||||
streamedReasoningContent += chunk;
|
||||
const displayed = isAppendMode
|
||||
? existingReasoning + streamedReasoningContent
|
||||
: streamedReasoningContent;
|
||||
writeActive({ reasoningContent: displayed });
|
||||
},
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
convId,
|
||||
abortController.signal,
|
||||
(connState: StreamConnectionState) => {
|
||||
if (convId === conversationsStore.activeConversation?.id) {
|
||||
this.streamConnectionState = connState;
|
||||
}
|
||||
},
|
||||
attachedModel
|
||||
);
|
||||
} catch (e) {
|
||||
console.error('attachServerStream pipe crashed:', e);
|
||||
cleanup();
|
||||
}
|
||||
}
|
||||
|
||||
async discoverActiveStream(convId: string): Promise<void> {
|
||||
if (!convId) return;
|
||||
if (this.chatStreamingStates.has(convId)) return;
|
||||
if (this.chatLoadingStates.get(convId)) return;
|
||||
// concurrency guard: another discover may already be running for this conv (typical race
|
||||
// between mount and visibilitychange on tab switch). a second concurrent fetch on the same
|
||||
// /v1/stream/<id> would duplicate every byte into the DB message, this guard bounces it
|
||||
if (this.discoveringConvs.has(convId)) return;
|
||||
this.discoveringConvs.add(convId);
|
||||
|
||||
try {
|
||||
// the model is frozen at POST time, rebuild the exact conv::model identity from the
|
||||
// persisted state so the lookup key matches what the server stored. null means a single
|
||||
// model conv with no ::suffix, only guess from the dropdown with no persisted state
|
||||
const localState = ChatService.getStreamState(convId);
|
||||
const streamId = ChatService.resumeStreamIdentity(convId, localState, selectedModelName());
|
||||
|
||||
// primary path: ask the server which sessions exist for this identity
|
||||
const serverTarget = await this.probeServerStream(streamId);
|
||||
if (serverTarget) {
|
||||
// pass the full server side identity (may carry a ::model suffix) so the GET routes
|
||||
// straight to the owning session, no probe or fan out
|
||||
await this.attachServerStream(convId, serverTarget.conversation_id);
|
||||
return;
|
||||
}
|
||||
|
||||
// fallback: local state remembers an interrupted byte offset for this conv, the server may
|
||||
// still have a live session matching that identity (we just lost the bytes mid stream). retry
|
||||
// with the frozen identity, the server probe inside attachServerStream tells us if it exists
|
||||
if (!localState) {
|
||||
return;
|
||||
}
|
||||
await this.attachServerStream(convId, streamId);
|
||||
// if attachServerStream failed (session gone, TTL expired), clear the local state to avoid retrying forever
|
||||
if (!this.chatStreamingStates.has(convId) && !this.chatLoadingStates.get(convId)) {
|
||||
ChatService.clearStreamState(convId);
|
||||
}
|
||||
} finally {
|
||||
this.discoveringConvs.delete(convId);
|
||||
}
|
||||
}
|
||||
|
||||
private findLastAssistantIdx(messages: DatabaseMessage[]): number {
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].role === MessageRole.ASSISTANT) return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
private findLastUserIdx(messages: DatabaseMessage[]): number {
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].role === MessageRole.USER) return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
clearUIState(): void {
|
||||
this.isLoading = false;
|
||||
@@ -265,13 +612,83 @@ class ChatStore {
|
||||
}
|
||||
|
||||
getAllLoadingChats(): string[] {
|
||||
return Array.from(this.chatLoadingStates.keys());
|
||||
// union of local (this browser is piping) and remote (backend reports a running session
|
||||
// for this conv but no local pipe yet) sources. the sidebar shows one spinner per entry
|
||||
const out = new SvelteSet<string>(this.chatLoadingStates.keys());
|
||||
for (const id of this.remoteRunningConvs) {
|
||||
out.add(id);
|
||||
}
|
||||
return Array.from(out);
|
||||
}
|
||||
|
||||
getAllStreamingChats(): string[] {
|
||||
return Array.from(this.chatStreamingStates.keys());
|
||||
}
|
||||
|
||||
/**
|
||||
* Resync the remote running convs set from the backend. Called by the layout at mount and on
|
||||
* visibilitychange, no polling. A snapshot semantic: the set is replaced wholesale, stale entries
|
||||
* for sessions that finalized while the browser was elsewhere are dropped naturally.
|
||||
*/
|
||||
async syncRemoteRunningStreams(): Promise<void> {
|
||||
// the conversations store loads from IndexedDB asynchronously, the +layout onMount caller
|
||||
// fires before that finishes. read ids straight from the DB so the result does not depend
|
||||
// on the store init race, and the sidebar spinners light up at first paint for every conv
|
||||
// the user owns even if it has not been hydrated into the store yet
|
||||
let ids: string[];
|
||||
try {
|
||||
const all = await DatabaseService.getAllConversations();
|
||||
ids = all.map((c) => c.id).filter((id) => !!id);
|
||||
} catch (e) {
|
||||
console.warn('syncRemoteRunningStreams DB read failed:', e);
|
||||
return;
|
||||
}
|
||||
// only ask about conv ids the user already owns
|
||||
if (ids.length === 0) {
|
||||
for (const id of Array.from(this.remoteRunningConvs)) {
|
||||
this.remoteRunningConvs.delete(id);
|
||||
}
|
||||
return;
|
||||
}
|
||||
// rebuild the frozen conv::model identity per conv so a session started with a model still
|
||||
// matches. the server response is mapped back to the bare id below for the sidebar set
|
||||
const lookupIds = ids.map((id) =>
|
||||
ChatService.resumeStreamIdentity(id, ChatService.getStreamState(id), null)
|
||||
);
|
||||
let sessions: ApiStreamSession[];
|
||||
try {
|
||||
const resp = await fetch('./v1/streams/lookup', {
|
||||
method: 'POST',
|
||||
headers: { ...getAuthHeaders(), 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ conversation_ids: lookupIds })
|
||||
});
|
||||
if (!resp.ok) return;
|
||||
const body = (await resp.json()) as unknown;
|
||||
if (!Array.isArray(body)) return;
|
||||
sessions = body as ApiStreamSession[];
|
||||
} catch (e) {
|
||||
console.warn('syncRemoteRunningStreams fetch failed:', e);
|
||||
return;
|
||||
}
|
||||
const running = new SvelteSet<string>();
|
||||
for (const s of sessions) {
|
||||
if (s && !s.is_done && typeof s.conversation_id === 'string' && s.conversation_id) {
|
||||
// strip the optional ::model suffix, the sidebar set is keyed by the bare conv id
|
||||
const sepIdx = s.conversation_id.indexOf('::');
|
||||
const bareId = sepIdx === -1 ? s.conversation_id : s.conversation_id.slice(0, sepIdx);
|
||||
running.add(bareId);
|
||||
}
|
||||
}
|
||||
for (const id of Array.from(this.remoteRunningConvs)) {
|
||||
if (!running.has(id)) {
|
||||
this.remoteRunningConvs.delete(id);
|
||||
}
|
||||
}
|
||||
for (const id of running) {
|
||||
this.remoteRunningConvs.add(id);
|
||||
}
|
||||
}
|
||||
|
||||
getChatStreamingPublic(convId: string): { response: string; messageId: string } | undefined {
|
||||
return this.getChatStreaming(convId);
|
||||
}
|
||||
@@ -922,6 +1339,11 @@ class ChatStore {
|
||||
onModel: streamCallbacks.onModel,
|
||||
onCompletionId: streamCallbacks.onCompletionId,
|
||||
onTimings: streamCallbacks.onTimings,
|
||||
onConnectionState: (state: StreamConnectionState) => {
|
||||
if (convId === conversationsStore.activeConversation?.id) {
|
||||
this.streamConnectionState = state;
|
||||
}
|
||||
},
|
||||
onComplete: async (
|
||||
finalContent?: string,
|
||||
reasoningContent?: string,
|
||||
@@ -979,6 +1401,12 @@ class ChatStore {
|
||||
async stopGenerationForChat(convId: string): Promise<void> {
|
||||
await this.savePartialResponseIfNeeded(convId);
|
||||
this.setStreamingActive(false);
|
||||
// tell the server to stop the generation, not just drop the HTTP socket. without this the
|
||||
// detached drain keeps producing tokens until eos or max_tokens. use the frozen identity
|
||||
// captured when the session started, not the live dropdown
|
||||
const streamStateForStop = this.chatStreamingStates.get(convId);
|
||||
const modelForStop = streamStateForStop?.model ?? selectedModelName();
|
||||
void ChatService.cancelServerStream(convId, modelForStop);
|
||||
this.abortRequest(convId);
|
||||
this.setChatLoading(convId, false);
|
||||
this.clearChatStreaming(convId);
|
||||
@@ -1393,7 +1821,11 @@ class ChatStore {
|
||||
|
||||
const updateStreamingContent = (fullContent: string) => {
|
||||
this.setChatStreaming(msg.convId, fullContent, msg.id);
|
||||
conversationsStore.updateMessageAtIndex(idx, { content: fullContent });
|
||||
// resolve the row by id on every write, switching to another conv mid continue makes
|
||||
// this a no op instead of writing positionally into the now displayed conversation
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
content: fullContent
|
||||
});
|
||||
};
|
||||
|
||||
const abortController = this.getOrCreateAbortController(msg.convId);
|
||||
@@ -1403,6 +1835,11 @@ class ChatStore {
|
||||
{
|
||||
...this.getApiOptions(),
|
||||
continueFinalMessage: true,
|
||||
onConnectionState: (state: StreamConnectionState) => {
|
||||
if (msg.convId === conversationsStore.activeConversation?.id) {
|
||||
this.streamConnectionState = state;
|
||||
}
|
||||
},
|
||||
onChunk: (chunk: string) => {
|
||||
appendedContent += chunk;
|
||||
hasReceivedContent = true;
|
||||
@@ -1414,7 +1851,7 @@ class ChatStore {
|
||||
hasReceivedContent = true;
|
||||
// mark streaming state so a stop mid-thinking can persist the partial reasoning
|
||||
this.setChatStreaming(msg.convId, originalContent + appendedContent, msg.id);
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
reasoningContent: originalReasoning + appendedReasoning
|
||||
});
|
||||
this.setChatReasoning(msg.convId, true);
|
||||
@@ -1455,7 +1892,7 @@ class ChatStore {
|
||||
timings
|
||||
});
|
||||
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
content: fullContent,
|
||||
reasoningContent: fullReasoning,
|
||||
timestamp: Date.now(),
|
||||
@@ -1477,11 +1914,14 @@ class ChatStore {
|
||||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
content: originalContent + appendedContent,
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
conversationsStore.updateMessageAtIndex(
|
||||
conversationsStore.findMessageIndex(msg.id),
|
||||
{
|
||||
content: originalContent + appendedContent,
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
this.setChatLoading(msg.convId, false);
|
||||
@@ -1498,7 +1938,7 @@ class ChatStore {
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
conversationsStore.updateMessageAtIndex(idx, {
|
||||
conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(msg.id), {
|
||||
content: originalContent + appendedContent,
|
||||
reasoningContent: originalReasoning + appendedReasoning || undefined,
|
||||
timestamp: Date.now()
|
||||
|
||||
Vendored
+15
@@ -512,3 +512,18 @@ export interface ApiRouterModelsUnloadResponse {
|
||||
success: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Entry returned by POST /v1/streams/lookup. The client passes the conv ids it owns in the body
|
||||
* and the server returns one entry per matching live or recently completed background streaming
|
||||
* session, keyed by conversation_id. The WebUI uses this at mount and on visibilitychange to
|
||||
* populate sidebar spinners and to reattach to an ongoing inference for the active conversation.
|
||||
* The server never lists ids the client did not ask about, so foreign random UUIDs stay private.
|
||||
*/
|
||||
export interface ApiStreamSession {
|
||||
conversation_id: string;
|
||||
is_done: boolean;
|
||||
total_bytes: number;
|
||||
started_at: number;
|
||||
completed_at: number;
|
||||
}
|
||||
|
||||
@@ -34,7 +34,8 @@ export type {
|
||||
ApiRouterModelsListResponse,
|
||||
ApiRouterModelsUnloadRequest,
|
||||
ApiRouterModelsUnloadResponse,
|
||||
AudioInputFormat
|
||||
AudioInputFormat,
|
||||
ApiStreamSession
|
||||
} from './api';
|
||||
|
||||
// Chat types
|
||||
|
||||
Vendored
+4
-2
@@ -4,9 +4,10 @@ import type { OpenAIToolDefinition } from './mcp';
|
||||
import type { DatabaseMessageExtra } from './database';
|
||||
import type {
|
||||
ParameterSource,
|
||||
ReasoningEffort,
|
||||
SyncableParameterType,
|
||||
SettingsFieldType
|
||||
SettingsFieldType,
|
||||
StreamConnectionState,
|
||||
ReasoningEffort
|
||||
} from '$lib/enums';
|
||||
import type { Icon } from '@lucide/svelte';
|
||||
import type { Component } from 'svelte';
|
||||
@@ -119,6 +120,7 @@ export interface SettingsChatServiceOptions {
|
||||
toolCalls?: string
|
||||
) => void;
|
||||
onError?: (error: Error) => void;
|
||||
onConnectionState?: (state: StreamConnectionState) => void;
|
||||
}
|
||||
|
||||
export type SettingsConfigType = typeof SETTING_CONFIG_DEFAULT & {
|
||||
|
||||
@@ -6,6 +6,17 @@
|
||||
* when needed (e.g., user stops generation, navigates away, etc.).
|
||||
*/
|
||||
|
||||
// the standard DOMException name for a cancelled operation
|
||||
const ABORT_ERROR_NAME = 'AbortError';
|
||||
|
||||
// browser specific TypeError messages emitted when a fetch reader is cut by page unload,
|
||||
// navigation, or a transient network drop. functionally aborts, not actionable errors
|
||||
const ABORT_LIKE_MESSAGE_PATTERNS = [
|
||||
/input stream/i, // Firefox: stream cut at unload
|
||||
/network connection was lost/i, // Safari: transient network drop
|
||||
/load failed/i // Safari: page navigation during fetch
|
||||
];
|
||||
|
||||
/**
|
||||
* Throws an AbortError if the signal is aborted.
|
||||
* Use this at the start of async operations to fail fast.
|
||||
@@ -23,7 +34,7 @@
|
||||
*/
|
||||
export function throwIfAborted(signal?: AbortSignal): void {
|
||||
if (signal?.aborted) {
|
||||
throw new DOMException('Operation was aborted', 'AbortError');
|
||||
throw new DOMException('Operation was aborted', ABORT_ERROR_NAME);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,11 +59,18 @@ export function throwIfAborted(signal?: AbortSignal): void {
|
||||
* ```
|
||||
*/
|
||||
export function isAbortError(error: unknown): boolean {
|
||||
if (error instanceof DOMException && error.name === 'AbortError') {
|
||||
if (error instanceof DOMException && error.name === ABORT_ERROR_NAME) {
|
||||
return true;
|
||||
}
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
return true;
|
||||
if (error instanceof Error) {
|
||||
if (error.name === ABORT_ERROR_NAME) {
|
||||
return true;
|
||||
}
|
||||
// these patterns are functionally aborts, keep them out of the red console
|
||||
if (error instanceof TypeError) {
|
||||
const msg = error.message ?? '';
|
||||
if (ABORT_LIKE_MESSAGE_PATTERNS.some((re) => re.test(msg))) return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -133,7 +151,7 @@ export async function withAbortSignal<T>(promise: Promise<T>, signal?: AbortSign
|
||||
|
||||
return new Promise<T>((resolve, reject) => {
|
||||
const abortHandler = () => {
|
||||
reject(new DOMException('Operation was aborted', 'AbortError'));
|
||||
reject(new DOMException('Operation was aborted', ABORT_ERROR_NAME));
|
||||
};
|
||||
|
||||
signal.addEventListener('abort', abortHandler, { once: true });
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
/**
|
||||
* Build the conversation identity used by the server side replay buffer.
|
||||
*
|
||||
* The server identifies a stream session by a conversation id sent in the
|
||||
* X-Conversation-Id header. When the user has explicitly picked a model the
|
||||
* client appends ::modelName, so a per model session stays distinct and the
|
||||
* router resolves the owning child through its conv_id -> model map.
|
||||
*/
|
||||
export function streamIdentity(conversationId: string, model?: string | null): string {
|
||||
if (!conversationId) return '';
|
||||
if (!model) return conversationId;
|
||||
return `${conversationId}::${model}`;
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import { DialogModelNotAvailable } from '$lib/components/app';
|
||||
import { APP_NAME, ROUTES } from '$lib/constants';
|
||||
import { chatStore, isLoading } from '$lib/stores/chat.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { modelsStore, modelOptions } from '$lib/stores/models.svelte';
|
||||
|
||||
@@ -83,7 +83,7 @@
|
||||
|
||||
// Skip loading if this conversation is already active (e.g., just created)
|
||||
if (activeConversation()?.id === chatId) {
|
||||
// Still handle URL params even if conversation is active
|
||||
void chatStore.discoverActiveStream(chatId);
|
||||
if ((qParam !== null || modelParam !== null) && !urlParamsProcessed) {
|
||||
handleUrlParams();
|
||||
}
|
||||
@@ -92,35 +92,33 @@
|
||||
|
||||
(async () => {
|
||||
const success = await conversationsStore.loadConversation(chatId);
|
||||
if (success) {
|
||||
chatStore.syncLoadingStateForChat(chatId);
|
||||
|
||||
// Handle URL params after conversation is loaded
|
||||
if ((qParam !== null || modelParam !== null) && !urlParamsProcessed) {
|
||||
await handleUrlParams();
|
||||
}
|
||||
} else {
|
||||
if (!success) {
|
||||
await goto(ROUTES.START);
|
||||
return;
|
||||
}
|
||||
chatStore.syncLoadingStateForChat(chatId);
|
||||
// server probe (with localStorage fallback) and attach
|
||||
await chatStore.discoverActiveStream(chatId);
|
||||
|
||||
if ((qParam !== null || modelParam !== null) && !urlParamsProcessed) {
|
||||
await handleUrlParams();
|
||||
}
|
||||
})();
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (typeof window !== 'undefined') {
|
||||
const handleBeforeUnload = () => {
|
||||
if (isLoading()) {
|
||||
console.log('Page unload detected while streaming - aborting stream');
|
||||
chatStore.stopGeneration();
|
||||
}
|
||||
};
|
||||
if (typeof window === 'undefined' || typeof document === 'undefined') return;
|
||||
|
||||
window.addEventListener('beforeunload', handleBeforeUnload);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('beforeunload', handleBeforeUnload);
|
||||
};
|
||||
}
|
||||
// when the tab comes back to the foreground, re-run discovery to catch any race
|
||||
// where the initial mount probe missed an active session
|
||||
const onVisibility = () => {
|
||||
if (document.visibilityState !== 'visible') return;
|
||||
if (!chatId) return;
|
||||
void chatStore.discoverActiveStream(chatId);
|
||||
};
|
||||
document.addEventListener('visibilitychange', onVisibility);
|
||||
return () => document.removeEventListener('visibilitychange', onVisibility);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
import { PwaMetaTags, PwaRefreshAlert } from '$lib/components/pwa';
|
||||
import { pwaAssetsHead } from 'virtual:pwa-assets/head';
|
||||
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { isRouterMode, serverStore } from '$lib/stores/server.svelte';
|
||||
@@ -154,8 +155,18 @@
|
||||
|
||||
onMount(() => {
|
||||
updateFavicon();
|
||||
// snapshot of every backend running stream on first load, populates the sidebar spinners
|
||||
// so the user sees each conv that has a live inference, even ones not opened yet
|
||||
void chatStore.syncRemoteRunningStreams();
|
||||
});
|
||||
|
||||
// refresh that snapshot when the tab returns to the foreground, a stream may have advanced
|
||||
// or ended while it was hidden. snapshot only, no polling
|
||||
function handleVisibilityChange() {
|
||||
if (document.visibilityState !== 'visible') return;
|
||||
void chatStore.syncRemoteRunningStreams();
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
void theme.isSystemDark;
|
||||
|
||||
@@ -280,6 +291,7 @@
|
||||
</svelte:head>
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} bind:innerHeight bind:innerWidth />
|
||||
<svelte:document onvisibilitychange={handleVisibilityChange} />
|
||||
|
||||
<Tooltip.Provider delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<div class="flex flex-col md:flex-row">
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { isAbortError } from '$lib/utils/abort';
|
||||
|
||||
describe('isAbortError', () => {
|
||||
it('returns false for null, undefined and non-error values', () => {
|
||||
expect(isAbortError(null)).toBe(false);
|
||||
expect(isAbortError(undefined)).toBe(false);
|
||||
expect(isAbortError('string error')).toBe(false);
|
||||
expect(isAbortError({ name: 'AbortError' })).toBe(false);
|
||||
expect(isAbortError(42)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns true for DOMException with AbortError name', () => {
|
||||
const err = new DOMException('Operation was aborted', 'AbortError');
|
||||
expect(isAbortError(err)).toBe(true);
|
||||
});
|
||||
|
||||
it('returns true for plain Error with AbortError name', () => {
|
||||
const err = new Error('aborted');
|
||||
err.name = 'AbortError';
|
||||
expect(isAbortError(err)).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false for unrelated Error instances', () => {
|
||||
expect(isAbortError(new Error('something failed'))).toBe(false);
|
||||
expect(isAbortError(new TypeError('not related'))).toBe(false);
|
||||
expect(isAbortError(new RangeError('out of range'))).toBe(false);
|
||||
});
|
||||
|
||||
it('recognizes Firefox TypeError "Error in input stream" emitted at page unload', () => {
|
||||
expect(isAbortError(new TypeError('Error in input stream'))).toBe(true);
|
||||
expect(isAbortError(new TypeError('TypeError: Error in input stream'))).toBe(true);
|
||||
});
|
||||
|
||||
it('recognizes Safari "The network connection was lost" during transient drop', () => {
|
||||
expect(isAbortError(new TypeError('The network connection was lost.'))).toBe(true);
|
||||
});
|
||||
|
||||
it('recognizes Safari "Load failed" during page navigation', () => {
|
||||
expect(isAbortError(new TypeError('Load failed'))).toBe(true);
|
||||
});
|
||||
|
||||
it('does NOT recognize generic TypeError messages as aborts', () => {
|
||||
// matching too broadly would hide real bugs, the predicate must stay conservative
|
||||
expect(isAbortError(new TypeError('Failed to fetch'))).toBe(false);
|
||||
expect(isAbortError(new TypeError('Cannot read property of undefined'))).toBe(false);
|
||||
expect(isAbortError(new TypeError('NetworkError when attempting to fetch resource'))).toBe(
|
||||
false
|
||||
);
|
||||
});
|
||||
|
||||
it('is case insensitive on the matched substrings', () => {
|
||||
expect(isAbortError(new TypeError('error in INPUT STREAM'))).toBe(true);
|
||||
expect(isAbortError(new TypeError('the network connection WAS LOST'))).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,74 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { ChatService } from '$lib/services/chat.service';
|
||||
import type { ApiStreamSession } from '$lib/types';
|
||||
|
||||
function makeSession(overrides: Partial<ApiStreamSession>): ApiStreamSession {
|
||||
return {
|
||||
conversation_id: 'conv',
|
||||
is_done: true,
|
||||
total_bytes: 0,
|
||||
started_at: 0,
|
||||
completed_at: 0,
|
||||
...overrides
|
||||
};
|
||||
}
|
||||
|
||||
describe('selectActiveStream', () => {
|
||||
it('returns null on empty input', () => {
|
||||
expect(ChatService.selectActiveStream([])).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null on null or undefined input', () => {
|
||||
expect(ChatService.selectActiveStream(null)).toBeNull();
|
||||
expect(ChatService.selectActiveStream(undefined)).toBeNull();
|
||||
});
|
||||
|
||||
it('returns the single session when it is running', () => {
|
||||
const s = makeSession({ conversation_id: 'only', is_done: false, started_at: 42 });
|
||||
expect(ChatService.selectActiveStream([s])).toBe(s);
|
||||
});
|
||||
|
||||
it('returns null when the single session is finalized', () => {
|
||||
const s = makeSession({ conversation_id: 'only', is_done: true, started_at: 42 });
|
||||
expect(ChatService.selectActiveStream([s])).toBeNull();
|
||||
});
|
||||
|
||||
it('prefers a still running session over a finalized one regardless of started_at', () => {
|
||||
const finalized = makeSession({ conversation_id: 'old', is_done: true, started_at: 1000 });
|
||||
const running = makeSession({ conversation_id: 'new', is_done: false, started_at: 10 });
|
||||
expect(ChatService.selectActiveStream([finalized, running])?.conversation_id).toBe('new');
|
||||
expect(ChatService.selectActiveStream([running, finalized])?.conversation_id).toBe('new');
|
||||
});
|
||||
|
||||
it('among running sessions, picks the most recently started one', () => {
|
||||
const a = makeSession({ conversation_id: 'a', is_done: false, started_at: 100 });
|
||||
const b = makeSession({ conversation_id: 'b', is_done: false, started_at: 200 });
|
||||
const c = makeSession({ conversation_id: 'c', is_done: false, started_at: 150 });
|
||||
expect(ChatService.selectActiveStream([a, b, c])?.conversation_id).toBe('b');
|
||||
expect(ChatService.selectActiveStream([c, a, b])?.conversation_id).toBe('b');
|
||||
});
|
||||
|
||||
it('returns null when all sessions are finalized, the DB already holds the content', () => {
|
||||
const a = makeSession({ conversation_id: 'a', is_done: true, started_at: 10 });
|
||||
const b = makeSession({ conversation_id: 'b', is_done: true, started_at: 30 });
|
||||
const c = makeSession({ conversation_id: 'c', is_done: true, started_at: 20 });
|
||||
expect(ChatService.selectActiveStream([a, b, c])).toBeNull();
|
||||
});
|
||||
|
||||
it('keeps the first match on ties when both are running with identical started_at', () => {
|
||||
// reduce visits left to right, the initial accumulator stays unless a strictly greater value appears
|
||||
const a = makeSession({ conversation_id: 'first', is_done: false, started_at: 50 });
|
||||
const b = makeSession({ conversation_id: 'second', is_done: false, started_at: 50 });
|
||||
expect(ChatService.selectActiveStream([a, b])?.conversation_id).toBe('first');
|
||||
});
|
||||
|
||||
it('handles a typical realistic mix: two finalized old, one freshly running, one freshly finalized', () => {
|
||||
const old1 = makeSession({ conversation_id: 'old1', is_done: true, started_at: 100 });
|
||||
const old2 = makeSession({ conversation_id: 'old2', is_done: true, started_at: 200 });
|
||||
const freshFin = makeSession({ conversation_id: 'freshFin', is_done: true, started_at: 500 });
|
||||
const running = makeSession({ conversation_id: 'running', is_done: false, started_at: 400 });
|
||||
expect(ChatService.selectActiveStream([old1, old2, freshFin, running])?.conversation_id).toBe(
|
||||
'running'
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,128 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it } from 'vitest';
|
||||
|
||||
// node env unit project has no DOM, install a minimal localStorage backed by a Map
|
||||
beforeAll(() => {
|
||||
const store = new Map<string, string>();
|
||||
const polyfill: Storage = {
|
||||
get length() {
|
||||
return store.size;
|
||||
},
|
||||
clear: () => store.clear(),
|
||||
getItem: (k) => (store.has(k) ? store.get(k)! : null),
|
||||
key: (i) => Array.from(store.keys())[i] ?? null,
|
||||
removeItem: (k) => {
|
||||
store.delete(k);
|
||||
},
|
||||
setItem: (k, v) => {
|
||||
store.set(k, String(v));
|
||||
}
|
||||
};
|
||||
(globalThis as unknown as { localStorage: Storage }).localStorage = polyfill;
|
||||
});
|
||||
|
||||
import { ChatService } from '$lib/services/chat.service';
|
||||
import { STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX } from '$lib/constants';
|
||||
|
||||
describe('ChatService stream resume', () => {
|
||||
beforeEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
afterEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
it('returns null when no state exists for the conversation', () => {
|
||||
expect(ChatService.getStreamState('conv-a')).toBeNull();
|
||||
});
|
||||
|
||||
it('saves and reads back the byte count', () => {
|
||||
ChatService.saveStreamState('conv-a', 4242);
|
||||
const got = ChatService.getStreamState('conv-a');
|
||||
expect(got).not.toBeNull();
|
||||
expect(got!.bytesReceived).toBe(4242);
|
||||
expect(typeof got!.updatedAt).toBe('number');
|
||||
});
|
||||
|
||||
it('overwrites the previous byte count on a new save for the same conversation', () => {
|
||||
ChatService.saveStreamState('conv-a', 100);
|
||||
ChatService.saveStreamState('conv-a', 200);
|
||||
const got = ChatService.getStreamState('conv-a');
|
||||
expect(got!.bytesReceived).toBe(200);
|
||||
});
|
||||
|
||||
it('keeps states for distinct conversations isolated', () => {
|
||||
ChatService.saveStreamState('conv-a', 10);
|
||||
ChatService.saveStreamState('conv-b', 20);
|
||||
expect(ChatService.getStreamState('conv-a')!.bytesReceived).toBe(10);
|
||||
expect(ChatService.getStreamState('conv-b')!.bytesReceived).toBe(20);
|
||||
});
|
||||
|
||||
it('clears the state for a given conversation', () => {
|
||||
ChatService.saveStreamState('conv-a', 10);
|
||||
ChatService.clearStreamState('conv-a');
|
||||
expect(ChatService.getStreamState('conv-a')).toBeNull();
|
||||
});
|
||||
|
||||
it('ignores empty conversation id on save', () => {
|
||||
ChatService.saveStreamState('', 1);
|
||||
expect(ChatService.getStreamState('')).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null on corrupted storage payload', () => {
|
||||
localStorage.setItem(`${STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX}conv-a`, '{not-json');
|
||||
expect(ChatService.getStreamState('conv-a')).toBeNull();
|
||||
});
|
||||
|
||||
it('persists the model alongside the byte count', () => {
|
||||
ChatService.saveStreamState('conv-a', 10, 'model-x');
|
||||
expect(ChatService.getStreamState('conv-a')!.model).toBe('model-x');
|
||||
});
|
||||
|
||||
it('stores a null model when none is provided', () => {
|
||||
ChatService.saveStreamState('conv-a', 10);
|
||||
expect(ChatService.getStreamState('conv-a')!.model).toBeNull();
|
||||
});
|
||||
|
||||
it('overwrites the model on a new save for the same conversation', () => {
|
||||
ChatService.saveStreamState('conv-a', 10, 'model-x');
|
||||
ChatService.saveStreamState('conv-a', 20, 'model-y');
|
||||
expect(ChatService.getStreamState('conv-a')!.model).toBe('model-y');
|
||||
});
|
||||
|
||||
describe('resumeStreamIdentity', () => {
|
||||
it('appends the persisted model so the resume key matches the frozen POST identity', () => {
|
||||
ChatService.saveStreamState('conv-a', 10, 'model-x');
|
||||
expect(
|
||||
ChatService.resumeStreamIdentity('conv-a', ChatService.getStreamState('conv-a'), 'dropdown')
|
||||
).toBe('conv-a::model-x');
|
||||
});
|
||||
|
||||
it('keeps the bare conv id when the persisted model is null', () => {
|
||||
ChatService.saveStreamState('conv-a', 10);
|
||||
expect(
|
||||
ChatService.resumeStreamIdentity('conv-a', ChatService.getStreamState('conv-a'), 'dropdown')
|
||||
).toBe('conv-a');
|
||||
});
|
||||
|
||||
it('falls back to the current model only when no state is persisted', () => {
|
||||
expect(ChatService.resumeStreamIdentity('conv-a', null, 'dropdown')).toBe('conv-a::dropdown');
|
||||
});
|
||||
|
||||
it('ignores the fallback when a state exists, the persisted value is authoritative', () => {
|
||||
ChatService.saveStreamState('conv-a', 10, 'model-x');
|
||||
expect(
|
||||
ChatService.resumeStreamIdentity('conv-a', ChatService.getStreamState('conv-a'), 'dropdown')
|
||||
).toBe('conv-a::model-x');
|
||||
});
|
||||
|
||||
it('falls back when a legacy state has no model field', () => {
|
||||
localStorage.setItem(
|
||||
`${STREAM_RESUME_LOCALSTORAGE_KEY_PREFIX}conv-a`,
|
||||
JSON.stringify({ bytesReceived: 10, updatedAt: 1 })
|
||||
);
|
||||
expect(
|
||||
ChatService.resumeStreamIdentity('conv-a', ChatService.getStreamState('conv-a'), 'dropdown')
|
||||
).toBe('conv-a::dropdown');
|
||||
});
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user