Refactor Activity Page (#710)
UI Tests / run-tests (push) Successful in 1m16s
Linux CI / run-tests (push) Successful in 3m59s
Close inactive issues / close-issues (push) Successful in 8s
Build Unified Docker Image / setup (push) Successful in 3s
Build Containers / build-and-push (cpu) (push) Failing after 13s
Build Containers / build-and-push (cuda) (push) Failing after 11s
Build Containers / build-and-push (cuda13) (push) Failing after 13s
Build Containers / build-and-push (intel) (push) Failing after 11s
Build Containers / build-and-push (musa) (push) Failing after 12s
Build Containers / build-and-push (rocm) (push) Failing after 12s
Build Containers / build-and-push (vulkan) (push) Failing after 11s
Build Containers / delete-untagged-containers (push) Has been skipped
Build Unified Docker Image / build (push) Failing after 10s
Windows CI / run-tests (push) Has been cancelled

- inference handles to store an activity record for all inference endpoints
- add path, status code, and content type to Activities page
- toggle on/off columns no Activities page 
- add configurable capture level for inference endpoints so large binary blobs are not stored in memory
- store captures in compressed binary format
This commit is contained in:
Benson Wong
2026-04-28 20:33:03 -07:00
committed by GitHub
parent a846c4f18c
commit fd3c28ffc5
16 changed files with 1397 additions and 651 deletions
+2
View File
@@ -4,6 +4,7 @@ go 1.26.1
require (
github.com/billziss-gh/golib v0.2.0
github.com/fxamacker/cbor/v2 v2.9.1
github.com/gin-gonic/gin v1.10.0
github.com/klauspost/compress v1.18.5
github.com/stretchr/testify v1.9.0
@@ -36,6 +37,7 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
+4
View File
@@ -11,6 +11,8 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -77,6 +79,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
+102
View File
@@ -0,0 +1,102 @@
package cache
import (
"errors"
"sync"
)
var (
ErrExceedsMaxSize = errors.New("item exceeds maximum cache size")
ErrNotFound = errors.New("item not found")
)
type Cache struct {
mu sync.Mutex
items map[int][]byte
order []int
size int
maxSize int
}
func New(maxBytes int) *Cache {
return &Cache{
items: make(map[int][]byte),
order: make([]int, 0),
maxSize: maxBytes,
}
}
func (c *Cache) Add(id int, data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
dataSize := len(data)
if dataSize > c.maxSize {
return ErrExceedsMaxSize
}
// If key already exists, remove old entry from size and order
if old, exists := c.items[id]; exists {
c.size -= len(old)
c.removeOrder(id)
}
// Evict oldest (FIFO) until room available
for c.size+dataSize > c.maxSize && len(c.order) > 0 {
oldestID := c.order[0]
c.order = c.order[1:]
if evicted, exists := c.items[oldestID]; exists {
c.size -= len(evicted)
delete(c.items, oldestID)
}
}
c.items[id] = data
c.order = append(c.order, id)
c.size += dataSize
return nil
}
func (c *Cache) removeOrder(id int) {
for i, v := range c.order {
if v == id {
c.order = append(c.order[:i], c.order[i+1:]...)
return
}
}
}
func (c *Cache) Get(id int) ([]byte, error) {
c.mu.Lock()
defer c.mu.Unlock()
data, exists := c.items[id]
if !exists {
return nil, ErrNotFound
}
return data, nil
}
func (c *Cache) Has(id int) bool {
c.mu.Lock()
defer c.mu.Unlock()
_, exists := c.items[id]
return exists
}
func (c *Cache) Size() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.size
}
func (c *Cache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.items = make(map[int][]byte)
c.order = c.order[:0]
c.size = 0
}
+130
View File
@@ -0,0 +1,130 @@
package cache
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCache_Add(t *testing.T) {
t.Run("adds and retrieves item", func(t *testing.T) {
c := New(1024)
data := []byte("hello")
require.NoError(t, c.Add(1, data))
got, err := c.Get(1)
require.NoError(t, err)
assert.Equal(t, data, got)
})
t.Run("returns error for oversized item", func(t *testing.T) {
c := New(10)
err := c.Add(1, make([]byte, 20))
assert.ErrorIs(t, err, ErrExceedsMaxSize)
})
t.Run("evicts oldest items to make room", func(t *testing.T) {
c := New(100)
require.NoError(t, c.Add(1, make([]byte, 40)))
require.NoError(t, c.Add(2, make([]byte, 40)))
// Adding item 3 should evict item 1
require.NoError(t, c.Add(3, make([]byte, 40)))
assert.False(t, c.Has(1))
assert.True(t, c.Has(2))
assert.True(t, c.Has(3))
})
t.Run("overwrites existing key", func(t *testing.T) {
c := New(100)
require.NoError(t, c.Add(1, []byte("old")))
require.NoError(t, c.Add(1, []byte("new")))
got, err := c.Get(1)
require.NoError(t, err)
assert.Equal(t, []byte("new"), got)
assert.Equal(t, 3, c.Size())
})
}
func TestCache_Get(t *testing.T) {
t.Run("returns ErrNotFound for missing key", func(t *testing.T) {
c := New(100)
_, err := c.Get(99)
assert.ErrorIs(t, err, ErrNotFound)
})
}
func TestCache_Has(t *testing.T) {
t.Run("returns true for existing key", func(t *testing.T) {
c := New(100)
require.NoError(t, c.Add(1, []byte("data")))
assert.True(t, c.Has(1))
})
t.Run("returns false for missing key", func(t *testing.T) {
c := New(100)
assert.False(t, c.Has(1))
})
}
func TestCache_Size(t *testing.T) {
t.Run("tracks byte usage", func(t *testing.T) {
c := New(1000)
assert.Equal(t, 0, c.Size())
require.NoError(t, c.Add(1, make([]byte, 100)))
assert.Equal(t, 100, c.Size())
require.NoError(t, c.Add(2, make([]byte, 200)))
assert.Equal(t, 300, c.Size())
})
t.Run("updates on eviction", func(t *testing.T) {
c := New(150)
require.NoError(t, c.Add(1, make([]byte, 100)))
require.NoError(t, c.Add(2, make([]byte, 100)))
// Item 1 should be evicted, size = 100
assert.Equal(t, 100, c.Size())
})
}
func TestCache_Clear(t *testing.T) {
t.Run("removes all items and resets size", func(t *testing.T) {
c := New(1000)
require.NoError(t, c.Add(1, []byte("a")))
require.NoError(t, c.Add(2, []byte("b")))
c.Clear()
assert.Equal(t, 0, c.Size())
assert.False(t, c.Has(1))
assert.False(t, c.Has(2))
})
}
func TestCache_Concurrent(t *testing.T) {
t.Run("concurrent operations are safe", func(t *testing.T) {
c := New(10000)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 100; j++ {
key := id*100 + j
_ = c.Add(key, []byte("data"))
_, _ = c.Get(key)
_ = c.Has(key)
_ = c.Size()
}
}(i)
}
wg.Wait()
})
}
+1 -1
View File
@@ -6,7 +6,7 @@ const ProcessStateChangeEventID = 0x01
const ChatCompletionStatsEventID = 0x02
const ConfigFileChangedEventID = 0x03
const LogDataEventID = 0x04
const TokenMetricsEventID = 0x05
const ActivityLogEventID = 0x05
const ModelPreloadedEventID = 0x06
const InFlightRequestsEventID = 0x07
+181 -130
View File
@@ -12,9 +12,11 @@ import (
"sync"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/cache"
"github.com/tidwall/gjson"
)
@@ -42,37 +44,53 @@ var zstdDecPool = &sync.Pool{
},
}
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd.
// Returns compressed bytes and the original JSON byte count for logging.
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
// Returns compressed bytes and the original CBOR byte count for logging.
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
jsonBytes, err := json.Marshal(c)
cborBytes, err := cbor.Marshal(c)
if err != nil {
return nil, 0, fmt.Errorf("marshal capture: %w", err)
}
enc := zstdEncPool.Get().(*zstd.Encoder)
defer zstdEncPool.Put(enc)
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil
zenc := zstdEncPool.Get().(*zstd.Encoder)
defer zstdEncPool.Put(zenc)
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
}
// decompressCapture decompresses zstd-compressed JSON and returns it.
func decompressCapture(data []byte) ([]byte, error) {
// decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture.
func decompressCapture(data []byte) (*ReqRespCapture, error) {
dec := zstdDecPool.Get().(*zstd.Decoder)
defer zstdDecPool.Put(dec)
return dec.DecodeAll(data, nil)
cborBytes, err := dec.DecodeAll(data, nil)
if err != nil {
return nil, fmt.Errorf("decompress capture: %w", err)
}
var capture ReqRespCapture
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
return nil, fmt.Errorf("unmarshal capture: %w", err)
}
return &capture, nil
}
// TokenMetrics represents parsed token statistics from llama-server logs
// TokenMetrics holds token usage and performance metrics
type TokenMetrics struct {
ID int `json:"id"`
Timestamp time.Time `json:"timestamp"`
Model string `json:"model"`
CachedTokens int `json:"cache_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
PromptPerSecond float64 `json:"prompt_per_second"`
TokensPerSecond float64 `json:"tokens_per_second"`
DurationMs int `json:"duration_ms"`
HasCapture bool `json:"has_capture"`
CachedTokens int `json:"cache_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
PromptPerSecond float64 `json:"prompt_per_second"`
TokensPerSecond float64 `json:"tokens_per_second"`
}
// ActivityLogEntry represents parsed token statistics from llama-server logs
type ActivityLogEntry struct {
ID int `json:"id"`
Timestamp time.Time `json:"timestamp"`
Model string `json:"model"`
ReqPath string `json:"req_path"`
RespContentType string `json:"resp_content_type"`
RespStatusCode int `json:"resp_status_code"`
Tokens TokenMetrics `json:"tokens"`
DurationMs int `json:"duration_ms"`
HasCapture bool `json:"has_capture"`
}
type ReqRespCapture struct {
@@ -84,48 +102,45 @@ type ReqRespCapture struct {
RespBody []byte `json:"resp_body"`
}
// TokenMetricsEvent represents a token metrics event
type TokenMetricsEvent struct {
Metrics TokenMetrics
// ActivityLogEvent represents a token metrics event
type ActivityLogEvent struct {
Metrics ActivityLogEntry
}
func (e TokenMetricsEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go
func (e ActivityLogEvent) Type() uint32 {
return ActivityLogEventID // defined in events.go
}
// metricsMonitor parses llama-server output for token statistics
type metricsMonitor struct {
mu sync.RWMutex
metrics []TokenMetrics
metrics []ActivityLogEntry
maxMetrics int
nextID int
logger *LogMonitor
// capture fields
enableCaptures bool
captures map[int][]byte // zstd-compressed JSON of ReqRespCapture
captureOrder []int // track insertion order for FIFO eviction
captureSize int // current total compressed size in bytes
maxCaptureSize int // max bytes for captures (uncompressed)
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
}
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
// capture buffer size in megabytes; 0 disables captures.
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
return &metricsMonitor{
mm := &metricsMonitor{
logger: logger,
maxMetrics: maxMetrics,
enableCaptures: captureBufferMB > 0,
captures: make(map[int][]byte),
captureOrder: make([]int, 0),
captureSize: 0,
maxCaptureSize: captureBufferMB * 1024 * 1024,
}
if captureBufferMB > 0 {
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
}
return mm
}
// addMetrics adds a new metric to the collection and publishes an event.
// Returns the assigned metric ID.
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
// queueMetrics adds a new metric to the collection without emitting an event.
// Returns the assigned metric ID. Call emitMetric after capture setup.
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
mp.mu.Lock()
defer mp.mu.Unlock()
@@ -135,93 +150,75 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
if len(mp.metrics) > mp.maxMetrics {
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
}
event.Emit(TokenMetricsEvent{Metrics: metric})
return metric.ID
}
// addCapture adds a new capture to the buffer with size-based eviction.
// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize.
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
// emitMetric publishes an ActivityLogEvent for the given metric.
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
event.Emit(ActivityLogEvent{Metrics: metric})
}
// addCapture compresses and stores a capture in the cache.
// Returns true if the capture was stored, false otherwise.
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
if !mp.enableCaptures {
return
return false
}
compressed, uncompressedBytes, err := compressCapture(&capture)
if err != nil {
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
return
return false
}
captureSize := len(compressed)
if captureSize > mp.maxCaptureSize {
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
return
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
return false
}
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100
mp.mu.Lock()
defer mp.mu.Unlock()
// Evict oldest (FIFO) until room available for the compressed data
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
oldestID := mp.captureOrder[0]
mp.captureOrder = mp.captureOrder[1:]
if evicted, exists := mp.captures[oldestID]; exists {
l := len(evicted)
mp.captureSize -= l
delete(mp.captures, oldestID)
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
}
}
mp.captures[capture.ID] = compressed
mp.captureOrder = append(mp.captureOrder, capture.ID)
mp.captureSize += captureSize
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
return true
}
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
mp.mu.RLock()
defer mp.mu.RUnlock()
data, exists := mp.captures[id]
return data, exists
if mp.captureCache == nil {
return nil, false
}
data, err := mp.captureCache.Get(id)
if err != nil {
return nil, false
}
return data, true
}
// getCaptureByID returns decompressed capture bytes if found and decompress=true.
// If decompress=false, returns the raw zstd-compressed bytes.
// Returns nil if the capture is not found.
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte {
mp.mu.RLock()
defer mp.mu.RUnlock()
data, exists := mp.captures[id]
// getCaptureByID decompresses and unmarshals a capture by ID.
// Returns nil if the capture is not found or decompression fails.
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
if mp.captureCache == nil {
return nil
}
data, exists := mp.getCompressedBytes(id)
if !exists {
return nil
}
if !decompress {
return data
}
decompressed, err := decompressCapture(data)
capture, err := decompressCapture(data)
if err != nil {
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
return nil
}
return decompressed
return capture
}
// getMetrics returns a copy of the current metrics
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
mp.mu.RLock()
defer mp.mu.RUnlock()
result := make([]TokenMetrics, len(mp.metrics))
result := make([]ActivityLogEntry, len(mp.metrics))
copy(result, mp.metrics)
return result
}
@@ -230,22 +227,52 @@ func (mp *metricsMonitor) getMetrics() []TokenMetrics {
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
if mp.captureCache == nil {
return json.Marshal(mp.metrics)
}
// Make a copy with up-to-date has_capture from cache
result := make([]ActivityLogEntry, len(mp.metrics))
for i, m := range mp.metrics {
m.HasCapture = mp.captureCache.Has(m.ID)
result[i] = m
}
return json.Marshal(result)
}
// wrapHandler wraps the proxy handler to extract token metrics
// Capture field flags for controlling what is saved in ReqRespCapture.
type captureFields uint
const (
captureNone captureFields = 1 << iota
captureReqHeaders
captureReqBody
captureRespHeaders
captureRespBody
)
const (
captureReqAll = captureReqHeaders | captureReqBody
captureRespAll = captureRespHeaders | captureRespBody
captureAll = captureReqAll | captureRespAll
)
// wrapHandler wraps the proxy handler to extract token metrics.
// captureFields controls what is saved in the ReqRespCapture using bitwise flags.
// if wrapHandler returns an error it is safe to assume that no
// data was sent to the client
func (mp *metricsMonitor) wrapHandler(
modelID string,
writer gin.ResponseWriter,
request *http.Request,
captureFields captureFields,
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error {
// Capture request body and headers if captures enabled
var reqBody []byte
var reqHeaders map[string]string
if mp.enableCaptures {
if mp.enableCaptures && (captureFields&captureReqBody) != 0 {
if request.Body != nil {
var err error
reqBody, err = io.ReadAll(request.Body)
@@ -255,6 +282,8 @@ func (mp *metricsMonitor) wrapHandler(
request.Body.Close()
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
}
}
if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 {
reqHeaders = make(map[string]string)
for key, values := range request.Header {
if len(values) > 0 {
@@ -278,22 +307,28 @@ func (mp *metricsMonitor) wrapHandler(
// after this point we have to assume that data was sent to the client
// and we can only log errors but not send them to clients
if recorder.Status() != http.StatusOK {
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
return nil
// Initialize default metrics - recorded for every request
tm := ActivityLogEntry{
Timestamp: time.Now(),
Model: modelID,
ReqPath: request.URL.Path,
RespContentType: recorder.Header().Get("Content-Type"),
RespStatusCode: recorder.Status(),
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
}
// Initialize default metrics - these will always be recorded
tm := TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
if recorder.Status() != http.StatusOK {
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path)
tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm)
return nil
}
body := recorder.body.Bytes()
if len(body) == 0 {
mp.logger.Warn("metrics: empty body, recording minimal metrics")
mp.addMetrics(tm)
tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm)
return nil
}
@@ -303,7 +338,8 @@ func (mp *metricsMonitor) wrapHandler(
body, err = decompressBody(body, encoding)
if err != nil {
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
mp.addMetrics(tm)
tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm)
return nil
}
}
@@ -311,7 +347,8 @@ func (mp *metricsMonitor) wrapHandler(
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsed
tm.Tokens = parsed.Tokens
tm.DurationMs = parsed.DurationMs
}
} else {
if gjson.ValidBytes(body) {
@@ -331,7 +368,8 @@ func (mp *metricsMonitor) wrapHandler(
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsedMetrics
tm.Tokens = parsedMetrics.Tokens
tm.DurationMs = parsedMetrics.DurationMs
}
}
} else {
@@ -342,39 +380,50 @@ func (mp *metricsMonitor) wrapHandler(
// Build capture if enabled and determine if it will be stored
var capture *ReqRespCapture
if mp.enableCaptures {
respHeaders := make(map[string]string)
for key, values := range recorder.Header() {
if len(values) > 0 {
respHeaders[key] = values[0]
var respHeaders map[string]string
var respBody []byte
if (captureFields & captureRespHeaders) != 0 {
respHeaders = make(map[string]string)
for key, values := range recorder.Header() {
if len(values) > 0 {
respHeaders[key] = values[0]
}
}
redactHeaders(respHeaders)
delete(respHeaders, "Content-Encoding")
}
if (captureFields & captureRespBody) != 0 {
respBody = body
}
redactHeaders(respHeaders)
delete(respHeaders, "Content-Encoding")
capture = &ReqRespCapture{
ReqPath: request.URL.Path,
ReqHeaders: reqHeaders,
ReqBody: reqBody,
RespHeaders: respHeaders,
RespBody: body,
}
compressed, _, err := compressCapture(capture)
if err == nil && len(compressed) <= mp.maxCaptureSize {
tm.HasCapture = true
RespBody: respBody,
}
}
metricID := mp.addMetrics(tm)
metricID := mp.queueMetrics(tm)
tm.ID = metricID
// Store capture if enabled
if capture != nil {
capture.ID = metricID
mp.addCapture(*capture)
if mp.addCapture(*capture) {
tm.HasCapture = true
mp.mu.Lock()
mp.metrics[len(mp.metrics)-1].HasCapture = true
mp.mu.Unlock()
}
}
mp.emitMetric(tm)
return nil
}
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
// Iterate **backwards** through the body looking for the data payload with
// usage data. This avoids allocating a slice of all lines via bytes.Split.
@@ -428,10 +477,10 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok
}
}
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
}
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
wallDurationMs := int(time.Since(start).Milliseconds())
// default values
@@ -481,15 +530,17 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
}
}
return TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
return ActivityLogEntry{
Timestamp: time.Now(),
Model: modelID,
Tokens: TokenMetrics{
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
},
DurationMs: durationMs,
}, nil
}
+343 -170
View File
@@ -12,8 +12,10 @@ import (
"testing"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/cache"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
@@ -22,27 +24,29 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
t.Run("adds metrics and assigns ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
metric := ActivityLogEntry{
Model: "test-model",
Tokens: TokenMetrics{
InputTokens: 100,
OutputTokens: 50,
},
}
mm.addMetrics(metric)
mm.queueMetrics(metric)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 0, metrics[0].ID)
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
})
t.Run("increments ID for each metric", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{Model: "model"})
mm.queueMetrics(ActivityLogEntry{Model: "model"})
}
metrics := mm.getMetrics()
@@ -57,9 +61,11 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
// Add 5 metrics
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{
Model: "model",
InputTokens: i,
mm.queueMetrics(ActivityLogEntry{
Model: "model",
Tokens: TokenMetrics{
InputTokens: i,
},
})
}
@@ -72,29 +78,32 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
assert.Equal(t, 4, metrics[2].ID)
})
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
t.Run("emits ActivityLogEvent", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
receivedEvent := make(chan TokenMetricsEvent, 1)
cancel := event.On(func(e TokenMetricsEvent) {
receivedEvent := make(chan ActivityLogEvent, 1)
cancel := event.On(func(e ActivityLogEvent) {
receivedEvent <- e
})
defer cancel()
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
metric := ActivityLogEntry{
Model: "test-model",
Tokens: TokenMetrics{
InputTokens: 100,
OutputTokens: 50,
},
}
mm.addMetrics(metric)
mm.queueMetrics(metric)
mm.emitMetric(metric)
select {
case evt := <-receivedEvent:
assert.Equal(t, 0, evt.Metrics.ID)
assert.Equal(t, "test-model", evt.Metrics.Model)
assert.Equal(t, 100, evt.Metrics.InputTokens)
assert.Equal(t, 50, evt.Metrics.OutputTokens)
assert.Equal(t, 100, evt.Metrics.Tokens.InputTokens)
assert.Equal(t, 50, evt.Metrics.Tokens.OutputTokens)
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for event")
}
@@ -111,8 +120,8 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
t.Run("returns copy of metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
mm.addMetrics(TokenMetrics{Model: "model1"})
mm.addMetrics(TokenMetrics{Model: "model2"})
mm.queueMetrics(ActivityLogEntry{Model: "model1"})
mm.queueMetrics(ActivityLogEntry{Model: "model2"})
metrics1 := mm.getMetrics()
metrics2 := mm.getMetrics()
@@ -135,7 +144,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, jsonData)
var metrics []TokenMetrics
var metrics []ActivityLogEntry
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 0, len(metrics))
@@ -143,23 +152,27 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
t.Run("returns valid JSON with metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
mm.addMetrics(TokenMetrics{
Model: "model1",
InputTokens: 100,
OutputTokens: 50,
TokensPerSecond: 25.5,
mm.queueMetrics(ActivityLogEntry{
Model: "model1",
Tokens: TokenMetrics{
InputTokens: 100,
OutputTokens: 50,
TokensPerSecond: 25.5,
},
})
mm.addMetrics(TokenMetrics{
Model: "model2",
InputTokens: 200,
OutputTokens: 100,
TokensPerSecond: 30.0,
mm.queueMetrics(ActivityLogEntry{
Model: "model2",
Tokens: TokenMetrics{
InputTokens: 200,
OutputTokens: 100,
TokensPerSecond: 30.0,
},
})
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
var metrics []TokenMetrics
var metrics []ActivityLogEntry
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 2, len(metrics))
@@ -190,14 +203,14 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
})
t.Run("successful request with timings data", func(t *testing.T) {
@@ -226,17 +239,17 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 20, metrics[0].CachedTokens)
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
assert.Equal(t, 20, metrics[0].Tokens.CachedTokens)
assert.Equal(t, 150.5, metrics[0].Tokens.PromptPerSecond)
assert.Equal(t, 25.5, metrics[0].Tokens.TokensPerSecond)
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
})
@@ -265,18 +278,18 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
// When timings data is present, it takes precedence
assert.Equal(t, 10, metrics[0].InputTokens)
assert.Equal(t, 20, metrics[0].OutputTokens)
assert.Equal(t, 10, metrics[0].Tokens.InputTokens)
assert.Equal(t, 20, metrics[0].Tokens.OutputTokens)
})
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
t.Run("non-OK status code records partial metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
@@ -289,11 +302,16 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, "/test", metrics[0].ReqPath)
assert.Equal(t, http.StatusBadRequest, metrics[0].RespStatusCode)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
})
t.Run("empty response body records minimal metrics", func(t *testing.T) {
@@ -308,14 +326,14 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
})
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
@@ -332,14 +350,14 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
})
t.Run("next handler error is propagated", func(t *testing.T) {
@@ -354,7 +372,7 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.Equal(t, expectedErr, err)
metrics := mm.getMetrics()
@@ -377,14 +395,14 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
})
t.Run("infill request extracts timings from last array element", func(t *testing.T) {
@@ -416,17 +434,17 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 150, metrics[0].InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens)
assert.Equal(t, 30, metrics[0].CachedTokens)
assert.Equal(t, 200.5, metrics[0].PromptPerSecond)
assert.Equal(t, 35.5, metrics[0].TokensPerSecond)
assert.Equal(t, 150, metrics[0].Tokens.InputTokens)
assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
assert.Equal(t, 30, metrics[0].Tokens.CachedTokens)
assert.Equal(t, 200.5, metrics[0].Tokens.PromptPerSecond)
assert.Equal(t, 35.5, metrics[0].Tokens.TokensPerSecond)
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
})
@@ -446,14 +464,14 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
})
}
@@ -507,7 +525,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
}
func TestMetricsMonitor_Concurrent(t *testing.T) {
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
t.Run("concurrent queueMetrics is safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 1000, 0)
var wg sync.WaitGroup
@@ -519,10 +537,12 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
go func(id int) {
defer wg.Done()
for j := 0; j < metricsPerGoroutine; j++ {
mm.addMetrics(TokenMetrics{
Model: "test-model",
InputTokens: id*1000 + j,
OutputTokens: j,
mm.queueMetrics(ActivityLogEntry{
Model: "test-model",
Tokens: TokenMetrics{
InputTokens: id*1000 + j,
OutputTokens: j,
},
})
}
}(i)
@@ -542,7 +562,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
// Writer goroutine
go func() {
for i := 0; i < 50; i++ {
mm.addMetrics(TokenMetrics{Model: "test-model"})
mm.queueMetrics(ActivityLogEntry{Model: "test-model"})
time.Sleep(1 * time.Millisecond)
}
done <- true
@@ -586,10 +606,10 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
metrics, err := parseMetrics("test-model", start, usage, timings)
assert.NoError(t, err)
assert.Equal(t, 5, metrics.InputTokens)
assert.Equal(t, 1, metrics.OutputTokens)
assert.Equal(t, 10.0, metrics.PromptPerSecond)
assert.Equal(t, 2.0, metrics.TokensPerSecond)
assert.Equal(t, 5, metrics.Tokens.InputTokens)
assert.Equal(t, 1, metrics.Tokens.OutputTokens)
assert.Equal(t, 10.0, metrics.Tokens.PromptPerSecond)
assert.Equal(t, 2.0, metrics.Tokens.TokensPerSecond)
assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
})
@@ -623,14 +643,14 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
// Should use timings values, not usage values
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
})
t.Run("handles missing cache_n in timings", func(t *testing.T) {
@@ -658,12 +678,12 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
assert.Equal(t, -1, metrics[0].Tokens.CachedTokens) // Default value when not present
})
}
@@ -693,13 +713,13 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
})
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
@@ -722,14 +742,14 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
})
t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
@@ -751,14 +771,14 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 17, metrics[0].InputTokens)
assert.Equal(t, 23, metrics[0].OutputTokens)
assert.Equal(t, 17, metrics[0].Tokens.InputTokens)
assert.Equal(t, 23, metrics[0].Tokens.OutputTokens)
})
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
@@ -777,14 +797,14 @@ data: [DONE]
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
})
}
@@ -792,20 +812,22 @@ data: [DONE]
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
mm := newMetricsMonitor(testLogger, 1000, 0)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
metric := ActivityLogEntry{
Model: "test-model",
Tokens: TokenMetrics{
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
},
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
mm.queueMetrics(metric)
}
}
@@ -813,20 +835,22 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
// Test performance with a smaller buffer where wrapping occurs more frequently
mm := newMetricsMonitor(testLogger, 100, 0)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
metric := ActivityLogEntry{
Model: "test-model",
Tokens: TokenMetrics{
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
},
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
mm.queueMetrics(metric)
}
}
@@ -855,14 +879,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
})
t.Run("deflate encoded response", func(t *testing.T) {
@@ -889,14 +913,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 200, metrics[0].InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens)
assert.Equal(t, 200, metrics[0].Tokens.InputTokens)
assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
})
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
@@ -917,14 +941,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) // Should not return error, just log warning
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
})
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
@@ -944,13 +968,13 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 300, metrics[0].InputTokens)
assert.Equal(t, 100, metrics[0].OutputTokens)
assert.Equal(t, 300, metrics[0].Tokens.InputTokens)
assert.Equal(t, 100, metrics[0].Tokens.OutputTokens)
})
}
@@ -989,7 +1013,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
mm.addCapture(capture)
// Should not store capture
assert.Nil(t, mm.getCaptureByID(0, false))
assert.Nil(t, mm.getCaptureByID(0))
})
t.Run("adds capture when enabled", func(t *testing.T) {
@@ -1002,22 +1026,18 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
}
mm.addCapture(capture)
retrieved := mm.getCaptureByID(0, true)
assert.NotNil(t, retrieved)
var decoded ReqRespCapture
err := json.Unmarshal(retrieved, &decoded)
assert.NoError(t, err)
assert.Equal(t, 0, decoded.ID)
assert.Equal(t, []byte("test request"), decoded.ReqBody)
assert.Equal(t, []byte("test response"), decoded.RespBody)
captured := mm.getCaptureByID(0)
assert.NotNil(t, captured)
assert.Equal(t, 0, captured.ID)
assert.Equal(t, []byte("test request"), captured.ReqBody)
assert.Equal(t, []byte("test response"), captured.RespBody)
})
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
// Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
mm.maxCaptureSize = 450
mm.captureCache = cache.New(450)
// Use random-looking data that doesn't compress well with zstd
rng := rand.New(rand.NewSource(42))
@@ -1033,16 +1053,14 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
// Adding capture3 should evict capture1
mm.addCapture(capture3)
assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted")
retrieved := mm.getCaptureByID(1, true)
assert.NotNil(t, retrieved, "capture 1 should exist")
retrieved = mm.getCaptureByID(2, true)
assert.NotNil(t, retrieved, "capture 2 should exist")
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
})
t.Run("skips capture larger than max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100
mm.captureCache = cache.New(100)
// Use random data that doesn't compress well to create an oversized capture
rng := rand.New(rand.NewSource(99))
@@ -1050,7 +1068,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
rng.Read(largeCapture.ReqBody)
mm.addCapture(largeCapture)
assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored")
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
})
}
@@ -1058,7 +1076,7 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
t.Run("returns nil for non-existent ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
assert.Nil(t, mm.getCaptureByID(999, false))
assert.Nil(t, mm.getCaptureByID(999))
})
t.Run("returns decompressed capture by ID", func(t *testing.T) {
@@ -1071,18 +1089,14 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
}
mm.addCapture(capture)
retrieved := mm.getCaptureByID(42, true)
assert.NotNil(t, retrieved)
var decoded ReqRespCapture
err := json.Unmarshal(retrieved, &decoded)
assert.NoError(t, err)
assert.Equal(t, 42, decoded.ID)
assert.Equal(t, []byte("test request"), decoded.ReqBody)
assert.Equal(t, []byte("test response"), decoded.RespBody)
captured := mm.getCaptureByID(42)
assert.NotNil(t, captured)
assert.Equal(t, 42, captured.ID)
assert.Equal(t, []byte("test request"), captured.ReqBody)
assert.Equal(t, []byte("test response"), captured.RespBody)
})
t.Run("returns compressed bytes when decompress=false", func(t *testing.T) {
t.Run("stores data as compressed bytes", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{
@@ -1092,10 +1106,12 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
}
mm.addCapture(capture)
compressed := mm.getCaptureByID(42, false)
compressed, exists := mm.getCompressedBytes(42)
assert.True(t, exists)
assert.NotNil(t, compressed)
// Compressed data should not be valid JSON (it's zstd-compressed)
assert.False(t, gjson.ValidBytes(compressed))
// Compressed data should not be valid CBOR (it's zstd-compressed)
var decoded ReqRespCapture
assert.Error(t, cbor.Unmarshal(compressed, &decoded))
})
}
@@ -1164,7 +1180,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
// Check metric was recorded
@@ -1173,12 +1189,8 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
metricID := metrics[0].ID
// Check capture was stored with same ID (decompressed)
captureData := mm.getCaptureByID(metricID, true)
assert.NotNil(t, captureData)
var capture ReqRespCapture
err = json.Unmarshal(captureData, &capture)
assert.NoError(t, err)
capture := mm.getCaptureByID(metricID)
assert.NotNil(t, capture)
assert.Equal(t, metricID, capture.ID)
assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Equal(t, []byte(responseBody), capture.RespBody)
@@ -1206,7 +1218,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err)
// Metrics should still be recorded
@@ -1214,7 +1226,168 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
assert.Equal(t, 1, len(metrics))
// But no capture
capture := mm.getCaptureByID(metrics[0].ID, false)
assert.Nil(t, capture)
assert.Nil(t, mm.getCaptureByID(metrics[0].ID))
})
}
func TestMetricsMonitor_WrapHandler_PartialCaptures(t *testing.T) {
requestBody := `{"model": "test"}`
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Custom", "header-value")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
t.Run("only request headers", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("only request body", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqBody, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("only response headers", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespHeaders, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
assert.Nil(t, capture.RespBody)
})
t.Run("only response body", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespBody, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Equal(t, []byte(responseBody), capture.RespBody)
})
t.Run("captureReqAll", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqAll, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("captureRespAll", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespAll, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
assert.Equal(t, []byte(responseBody), capture.RespBody)
})
t.Run("no flags", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureFields(0), nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("mixed flags req headers and resp body", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders|captureRespBody, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Equal(t, []byte(responseBody), capture.RespBody)
})
}
+324 -274
View File
@@ -332,41 +332,77 @@ func (pm *ProxyManager) setupGinEngine() {
// Set up routes using the Gin engine
// Protected routes use pm.apiKeyAuth() middleware
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
llmHandler := pm.mkProxyJSONHandler(captureAll)
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support anthropic count_tokens API (Also added in the above PR)
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support embeddings and reranking
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// llama-server's /reranking endpoint + aliases
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// llama-server's /infill endpoint for code infilling
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// llama-server's /completion endpoint
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST(
"/v1/audio/speech",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
)
pm.ginEngine.POST(
"/v1/audio/voices",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqHeaders|captureRespAll),
)
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
pm.ginEngine.POST(
"/v1/audio/transcriptions",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders|captureRespBody),
)
pm.ginEngine.POST(
"/v1/images/generations",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
)
pm.ginEngine.POST(
"/v1/images/edits",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders),
)
// sd.cpp /sdapi/v1 endpoints
pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/sdapi/v1/txt2img",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
)
pm.ginEngine.POST("/sdapi/v1/img2img",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqHeaders|captureRespHeaders),
)
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
@@ -686,7 +722,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, captureNone, handler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
return
@@ -700,280 +736,294 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
}
}
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
return
}
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
// Look for a matching local model first
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
modelID, found := pm.config.RealModelName(requestedModel)
if found {
var localHandler func(string, http.ResponseWriter, *http.Request) error
if pm.matrix != nil {
localHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
localHandler = processGroup.ProxyRequest
func (pm *ProxyManager) mkProxyJSONHandler(cf captureFields) func(*gin.Context) {
return func(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
return
}
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[modelID].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
} else {
// Look for a matching local model first
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
modelID, found := pm.config.RealModelName(requestedModel)
if found {
var localHandler func(string, http.ResponseWriter, *http.Request) error
if pm.matrix != nil {
localHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
localHandler = processGroup.ProxyRequest
}
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[modelID].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
return
}
}
}
// issue #453 set/override parameters in the JSON body
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
for _, key := range setParamsByIDKeys {
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = localHandler
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
// issue #453 apply filters for peer requests
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
// Apply stripParams - remove specified parameters from request
stripParams := peerFilters.SanitizedStripParams()
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
return
}
}
// Apply setParams - set/override specified parameters in request
setParams, setParamKeys := peerFilters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, cf, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
}
}
}
// mkPostFormHandler creates a POST form handler for inference backends
// with a custom captureFields to filter out large binary requests or responses.
func (pm *ProxyManager) mkPostFormHandler(cf captureFields) func(*gin.Context) {
return func(c *gin.Context) {
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Look for a matching local model first, then check peers
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
var useModelName string
modelID, found := pm.config.RealModelName(requestedModel)
if found {
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// issue #453 set/override parameters in the JSON body
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
for _, key := range setParamsByIDKeys {
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = localHandler
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
// issue #453 apply filters for peer requests
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
// Apply stripParams - remove specified parameters from request
stripParams := peerFilters.SanitizedStripParams()
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
return
}
}
// Apply setParams - set/override specified parameters in request
setParams, setParamKeys := peerFilters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
}
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Look for a matching local model first, then check peers
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
var useModelName string
modelID, found := pm.config.RealModelName(requestedModel)
if found {
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
// Copy all files from the original request
for key, fileHeaders := range c.Request.MultipartForm.File {
for _, fileHeader := range fileHeaders {
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// Copy all files from the original request
for key, fileHeaders := range c.Request.MultipartForm.File {
for _, fileHeader := range fileHeaders {
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
file, err := fileHeader.Open()
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
return
}
file, err := fileHeader.Open()
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
return
}
if _, err = io.Copy(formFile, file); err != nil {
if _, err = io.Copy(formFile, file); err != nil {
file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
return
}
file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
}
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if pm.metricsMonitor != nil {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, modifiedReq, cf, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
file.Close()
}
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
}
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
+10 -20
View File
@@ -158,7 +158,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
}
}
sendMetrics := func(metrics []TokenMetrics) {
sendMetrics := func(metrics []ActivityLogEntry) {
jsonData, err := json.Marshal(metrics)
if err == nil {
select {
@@ -205,8 +205,8 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
/**
* Send Metrics data
*/
defer event.On(func(e TokenMetricsEvent) {
sendMetrics([]TokenMetrics{e.Metrics})
defer event.On(func(e ActivityLogEvent) {
sendMetrics([]ActivityLogEntry{e.Metrics})
})()
/**
@@ -290,26 +290,16 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
return
}
data, exists := pm.metricsMonitor.getCompressedBytes(id)
if !exists {
capture := pm.metricsMonitor.getCaptureByID(id)
if capture == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
return
}
c.Header("Vary", "Accept-Encoding")
// ¯\_(ツ)_/¯ quality weights are too fancy for us anyway
hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd")
if hasZstd {
c.Header("Content-Encoding", "zstd")
c.Data(http.StatusOK, "application/json", data)
} else {
decompressed, err := decompressCapture(data)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"})
return
}
c.Data(http.StatusOK, "application/json", decompressed)
jsonBytes, err := json.Marshal(capture)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal capture"})
return
}
c.Data(http.StatusOK, "application/json", jsonBytes)
}
+58
View File
@@ -1721,3 +1721,61 @@ models:
assert.Contains(t, w.Body.String(), "could not find suitable handler")
})
}
func TestProxyManager_AudioTranscriptionCapture(t *testing.T) {
cfg := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
captureBuffer: 5
models:
TheExpectedModel:
cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel
`)
proxy := New(cfg)
defer proxy.StopProcesses(StopWaitForInflightRequest)
injectTestHandlers(proxy, nil)
var b bytes.Buffer
w := multipart.NewWriter(&b)
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte("TheExpectedModel"))
assert.NoError(t, err)
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test audio content"))
assert.NoError(t, err)
w.Close()
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
req.Header.Set("Authorization", "Bearer mysecret")
req.Header.Set("X-Custom-Req", "req-value")
rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// Verify capture exists
metrics := proxy.metricsMonitor.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.True(t, metrics[0].HasCapture)
capture := proxy.metricsMonitor.getCaptureByID(metrics[0].ID)
assert.NotNil(t, capture)
// Should capture request headers (sensitive ones redacted)
assert.NotEmpty(t, capture.ReqHeaders)
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Equal(t, "req-value", capture.ReqHeaders["X-Custom-Req"])
// Should capture response headers
assert.NotNil(t, capture.RespHeaders)
// Should NOT capture request bodies but get response bodies (text
assert.Nil(t, capture.ReqBody)
assert.NotNil(t, capture.RespBody)
}
+3 -3
View File
@@ -2788,9 +2788,9 @@
}
},
"node_modules/postcss": {
"version": "8.5.8",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz",
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==",
"version": "8.5.12",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz",
"integrity": "sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==",
"dev": true,
"funding": [
{
@@ -9,13 +9,13 @@
let stats = $derived.by(() => {
const totalRequests = $metrics.length;
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.cache_tokens, 0);
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.tokens.input_tokens, 0);
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.tokens.output_tokens, 0);
const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.tokens.cache_tokens, 0);
const promptPerSecond = $metrics.filter((m) => m.prompt_per_second > 0).map((m) => m.prompt_per_second);
const promptPerSecond = $metrics.filter((m) => m.tokens.prompt_per_second > 0).map((m) => m.tokens.prompt_per_second);
const tokensPerSecond = $metrics.filter((m) => m.tokens_per_second > 0).map((m) => m.tokens_per_second);
const tokensPerSecond = $metrics.filter((m) => m.tokens.tokens_per_second > 0).map((m) => m.tokens.tokens_per_second);
const promptHistogramData =
promptPerSecond.length > 0 ? calculateHistogramData(promptPerSecond) : null;
+11 -4
View File
@@ -12,15 +12,22 @@ export interface Model {
aliases?: string[];
}
export interface Metrics {
id: number;
timestamp: string;
model: string;
export interface TokenMetrics {
cache_tokens: number;
input_tokens: number;
output_tokens: number;
prompt_per_second: number;
tokens_per_second: number;
}
export interface ActivityLogEntry {
id: number;
timestamp: string;
model: string;
req_path: string;
resp_content_type: string;
resp_status_code: number;
tokens: TokenMetrics;
duration_ms: number;
has_capture: boolean;
}
+209 -38
View File
@@ -3,8 +3,87 @@
import ActivityStats from "../components/ActivityStats.svelte";
import Tooltip from "../components/Tooltip.svelte";
import CaptureDialog from "../components/CaptureDialog.svelte";
import { persistentStore } from "../stores/persistent";
import { onMount } from "svelte";
import type { ReqRespCapture } from "../lib/types";
type ColumnKey =
| "id"
| "time"
| "model"
| "req_path"
| "resp_status_code"
| "resp_content_type"
| "cached"
| "prompt"
| "generated"
| "prompt_speed"
| "gen_speed"
| "duration"
| "capture";
interface ColumnDef {
key: ColumnKey;
label: string;
defaultVisible: boolean;
}
const columns: ColumnDef[] = [
{ key: "id", label: "ID", defaultVisible: true },
{ key: "time", label: "Time", defaultVisible: true },
{ key: "model", label: "Model", defaultVisible: true },
{ key: "req_path", label: "Path", defaultVisible: false },
{ key: "resp_status_code", label: "Status", defaultVisible: false },
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
{ key: "cached", label: "Cached", defaultVisible: true },
{ key: "prompt", label: "Prompt", defaultVisible: true },
{ key: "generated", label: "Generated", defaultVisible: true },
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
{ key: "duration", label: "Duration", defaultVisible: true },
{ key: "capture", label: "Capture", defaultVisible: true },
];
const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key);
const visibleColumns = persistentStore<ColumnKey[]>(
"activity-columns",
defaultVisibleKeys
);
let columnsMenuOpen = $state(false);
let dropdownContainer: HTMLDivElement | null = null;
onMount(() => {
function handleKeydown(e: KeyboardEvent) {
if (e.key === "Escape" && columnsMenuOpen) {
columnsMenuOpen = false;
}
}
function handleClick(e: MouseEvent) {
if (columnsMenuOpen && dropdownContainer && !dropdownContainer.contains(e.target as Node)) {
columnsMenuOpen = false;
}
}
document.addEventListener("keydown", handleKeydown);
document.addEventListener("click", handleClick);
return () => {
document.removeEventListener("keydown", handleKeydown);
document.removeEventListener("click", handleClick);
};
});
function toggleColumn(key: ColumnKey) {
const current = $visibleColumns;
if (current.includes(key)) {
if (current.length > 1) {
visibleColumns.set(current.filter((k) => k !== key));
}
} else {
visibleColumns.set([...current, key]);
}
}
function formatSpeed(speed: number): string {
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
}
@@ -67,58 +146,150 @@
<ActivityStats />
</div>
<div class="card overflow-auto">
<div class="card overflow-auto relative min-h-[30rem]">
<div class="flex justify-end px-4" bind:this={dropdownContainer}>
<div class="relative">
<button
class="w-8 h-8 flex items-center justify-center rounded hover:bg-secondary-hover transition-colors"
onclick={() => (columnsMenuOpen = !columnsMenuOpen)}
title="Select columns"
>
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 6V4m0 2a2 2 0 100 4m0-4a2 2 0 110 4m-6 8a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4m6 6v10m6-2a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4"></path>
</svg>
</button>
{#if columnsMenuOpen}
<div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]">
<div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10">
Columns
</div>
{#each columns as col (col.key)}
<label
class="flex items-center gap-2 px-3 py-1.5 text-sm cursor-pointer hover:bg-secondary-hover transition-colors"
>
<input
type="checkbox"
checked={$visibleColumns.includes(col.key)}
onchange={() => toggleColumn(col.key)}
class="rounded"
/>
{col.label}
</label>
{/each}
</div>
{/if}
</div>
</div>
<table class="min-w-full divide-y">
<thead class="border-gray-200 dark:border-white/10">
<tr class="text-left text-xs uppercase tracking-wider">
<th class="px-6 py-3">ID</th>
<th class="px-6 py-3">Time</th>
<th class="px-6 py-3">Model</th>
<th class="px-6 py-3">
Cached <Tooltip content="prompt tokens from cache" />
</th>
<th class="px-6 py-3">
Prompt <Tooltip content="new prompt tokens processed" />
</th>
<th class="px-6 py-3">Generated</th>
<th class="px-6 py-3">Prompt Processing</th>
<th class="px-6 py-3">Generation Speed</th>
<th class="px-6 py-3">Duration</th>
<th class="px-6 py-3">Capture</th>
{#if $visibleColumns.includes("id")}
<th class="px-6 py-3">ID</th>
{/if}
{#if $visibleColumns.includes("time")}
<th class="px-6 py-3">Time</th>
{/if}
{#if $visibleColumns.includes("model")}
<th class="px-6 py-3">Model</th>
{/if}
{#if $visibleColumns.includes("req_path")}
<th class="px-6 py-3">Path</th>
{/if}
{#if $visibleColumns.includes("resp_status_code")}
<th class="px-6 py-3">Status</th>
{/if}
{#if $visibleColumns.includes("resp_content_type")}
<th class="px-6 py-3">Content-Type</th>
{/if}
{#if $visibleColumns.includes("cached")}
<th class="px-6 py-3">
Cached <Tooltip content="prompt tokens from cache" />
</th>
{/if}
{#if $visibleColumns.includes("prompt")}
<th class="px-6 py-3">
Prompt <Tooltip content="new prompt tokens processed" />
</th>
{/if}
{#if $visibleColumns.includes("generated")}
<th class="px-6 py-3">Generated</th>
{/if}
{#if $visibleColumns.includes("prompt_speed")}
<th class="px-6 py-3">Prompt Speed</th>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<th class="px-6 py-3">Gen Speed</th>
{/if}
{#if $visibleColumns.includes("duration")}
<th class="px-6 py-3">Duration</th>
{/if}
{#if $visibleColumns.includes("capture")}
<th class="px-6 py-3">Capture</th>
{/if}
</tr>
</thead>
<tbody class="divide-y">
{#if sortedMetrics.length === 0}
<tr>
<td colspan="10" class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
<td colspan={$visibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
No activity recorded
</td>
</tr>
{:else}
{#each sortedMetrics as metric (metric.id)}
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
<td class="px-4 py-4">{metric.id + 1}</td>
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
<td class="px-6 py-4">{metric.model}</td>
<td class="px-6 py-4">{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}</td>
<td class="px-6 py-4">{metric.input_tokens.toLocaleString()}</td>
<td class="px-6 py-4">{metric.output_tokens.toLocaleString()}</td>
<td class="px-6 py-4">{formatSpeed(metric.prompt_per_second)}</td>
<td class="px-6 py-4">{formatSpeed(metric.tokens_per_second)}</td>
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
<td class="px-6 py-4">
{#if metric.has_capture}
<button
onclick={() => viewCapture(metric.id)}
disabled={loadingCaptureId === metric.id}
class="btn btn--sm"
>
{loadingCaptureId === metric.id ? "..." : "View"}
</button>
{:else}
<span class="text-txtsecondary">-</span>
{/if}
</td>
{#if $visibleColumns.includes("id")}
<td class="px-4 py-4">{metric.id + 1}</td>
{/if}
{#if $visibleColumns.includes("time")}
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
{/if}
{#if $visibleColumns.includes("model")}
<td class="px-6 py-4">{metric.model}</td>
{/if}
{#if $visibleColumns.includes("req_path")}
<td class="px-6 py-4">{metric.req_path || "-"}</td>
{/if}
{#if $visibleColumns.includes("resp_status_code")}
<td class="px-6 py-4">{metric.resp_status_code || "-"}</td>
{/if}
{#if $visibleColumns.includes("resp_content_type")}
<td class="px-6 py-4">{metric.resp_content_type || "-"}</td>
{/if}
{#if $visibleColumns.includes("cached")}
<td class="px-6 py-4">{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}</td>
{/if}
{#if $visibleColumns.includes("prompt")}
<td class="px-6 py-4">{metric.tokens.input_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("generated")}
<td class="px-6 py-4">{metric.tokens.output_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("prompt_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.prompt_per_second)}</td>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.tokens_per_second)}</td>
{/if}
{#if $visibleColumns.includes("duration")}
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
{/if}
{#if $visibleColumns.includes("capture")}
<td class="px-6 py-4">
{#if metric.has_capture}
<button
onclick={() => viewCapture(metric.id)}
disabled={loadingCaptureId === metric.id}
class="btn btn--sm"
>
{loadingCaptureId === metric.id ? "..." : "View"}
</button>
{:else}
<span class="text-txtsecondary">-</span>
{/if}
</td>
{/if}
</tr>
{/each}
{/if}
+2 -2
View File
@@ -10,7 +10,7 @@
const viewModeStore = persistentStore<ViewMode>("logviewer-view-mode", "panels");
let direction = $derived<"horizontal" | "vertical">(
$screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal"
$screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal",
);
</script>
@@ -30,7 +30,7 @@
class:bg-primary={$viewModeStore === "proxy"}
class:text-btn-primary-text={$viewModeStore === "proxy"}
>
Panel
Proxy
</button>
<button
onclick={() => viewModeStore.set("upstream")}
+12 -4
View File
@@ -1,5 +1,13 @@
import { writable } from "svelte/store";
import type { Model, Metrics, VersionInfo, LogData, APIEventEnvelope, ReqRespCapture, InFlightStats } from "../lib/types";
import type {
Model,
ActivityLogEntry,
VersionInfo,
LogData,
APIEventEnvelope,
ReqRespCapture,
InFlightStats,
} from "../lib/types";
import { connectionState } from "./theme";
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
@@ -8,7 +16,7 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
export const models = writable<Model[]>([]);
export const proxyLogs = writable<string>("");
export const upstreamLogs = writable<string>("");
export const metrics = writable<Metrics[]>([]);
export const metrics = writable<ActivityLogEntry[]>([]);
export const inFlightRequests = writable<number>(0);
export const versionInfo = writable<VersionInfo>({
build_date: "unknown",
@@ -62,7 +70,7 @@ export function enableAPIEvents(enabled: boolean): void {
const newModels = JSON.parse(message.data) as Model[];
// Sort models by name and id
newModels.sort((a, b) => {
return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric : true} );
return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric: true });
});
models.set(newModels);
break;
@@ -82,7 +90,7 @@ export function enableAPIEvents(enabled: boolean): void {
}
case "metrics": {
const newMetrics = JSON.parse(message.data) as Metrics[];
const newMetrics = JSON.parse(message.data) as ActivityLogEntry[];
metrics.update((prevMetrics) => [...newMetrics, ...prevMetrics]);
break;
}