mirror of
https://github.com/mostlygeek/llama-swap.git
synced 2026-06-09 06:46:34 +02:00
Refactor Activity Page (#710)
UI Tests / run-tests (push) Successful in 1m16s
Linux CI / run-tests (push) Successful in 3m59s
Close inactive issues / close-issues (push) Successful in 8s
Build Unified Docker Image / setup (push) Successful in 3s
Build Containers / build-and-push (cpu) (push) Failing after 13s
Build Containers / build-and-push (cuda) (push) Failing after 11s
Build Containers / build-and-push (cuda13) (push) Failing after 13s
Build Containers / build-and-push (intel) (push) Failing after 11s
Build Containers / build-and-push (musa) (push) Failing after 12s
Build Containers / build-and-push (rocm) (push) Failing after 12s
Build Containers / build-and-push (vulkan) (push) Failing after 11s
Build Containers / delete-untagged-containers (push) Has been skipped
Build Unified Docker Image / build (push) Failing after 10s
Windows CI / run-tests (push) Has been cancelled
UI Tests / run-tests (push) Successful in 1m16s
Linux CI / run-tests (push) Successful in 3m59s
Close inactive issues / close-issues (push) Successful in 8s
Build Unified Docker Image / setup (push) Successful in 3s
Build Containers / build-and-push (cpu) (push) Failing after 13s
Build Containers / build-and-push (cuda) (push) Failing after 11s
Build Containers / build-and-push (cuda13) (push) Failing after 13s
Build Containers / build-and-push (intel) (push) Failing after 11s
Build Containers / build-and-push (musa) (push) Failing after 12s
Build Containers / build-and-push (rocm) (push) Failing after 12s
Build Containers / build-and-push (vulkan) (push) Failing after 11s
Build Containers / delete-untagged-containers (push) Has been skipped
Build Unified Docker Image / build (push) Failing after 10s
Windows CI / run-tests (push) Has been cancelled
- inference handles to store an activity record for all inference endpoints - add path, status code, and content type to Activities page - toggle on/off columns no Activities page - add configurable capture level for inference endpoints so large binary blobs are not stored in memory - store captures in compressed binary format
This commit is contained in:
@@ -4,6 +4,7 @@ go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0
|
||||
github.com/fxamacker/cbor/v2 v2.9.1
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/klauspost/compress v1.18.5
|
||||
github.com/stretchr/testify v1.9.0
|
||||
@@ -36,6 +37,7 @@ require (
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
|
||||
@@ -11,6 +11,8 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
@@ -77,6 +79,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
|
||||
Vendored
+102
@@ -0,0 +1,102 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrExceedsMaxSize = errors.New("item exceeds maximum cache size")
|
||||
ErrNotFound = errors.New("item not found")
|
||||
)
|
||||
|
||||
type Cache struct {
|
||||
mu sync.Mutex
|
||||
items map[int][]byte
|
||||
order []int
|
||||
size int
|
||||
maxSize int
|
||||
}
|
||||
|
||||
func New(maxBytes int) *Cache {
|
||||
return &Cache{
|
||||
items: make(map[int][]byte),
|
||||
order: make([]int, 0),
|
||||
maxSize: maxBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) Add(id int, data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
dataSize := len(data)
|
||||
if dataSize > c.maxSize {
|
||||
return ErrExceedsMaxSize
|
||||
}
|
||||
|
||||
// If key already exists, remove old entry from size and order
|
||||
if old, exists := c.items[id]; exists {
|
||||
c.size -= len(old)
|
||||
c.removeOrder(id)
|
||||
}
|
||||
|
||||
// Evict oldest (FIFO) until room available
|
||||
for c.size+dataSize > c.maxSize && len(c.order) > 0 {
|
||||
oldestID := c.order[0]
|
||||
c.order = c.order[1:]
|
||||
if evicted, exists := c.items[oldestID]; exists {
|
||||
c.size -= len(evicted)
|
||||
delete(c.items, oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
c.items[id] = data
|
||||
c.order = append(c.order, id)
|
||||
c.size += dataSize
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache) removeOrder(id int) {
|
||||
for i, v := range c.order {
|
||||
if v == id {
|
||||
c.order = append(c.order[:i], c.order[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) Get(id int) ([]byte, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
data, exists := c.items[id]
|
||||
if !exists {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *Cache) Has(id int) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
_, exists := c.items[id]
|
||||
return exists
|
||||
}
|
||||
|
||||
func (c *Cache) Size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.size
|
||||
}
|
||||
|
||||
func (c *Cache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.items = make(map[int][]byte)
|
||||
c.order = c.order[:0]
|
||||
c.size = 0
|
||||
}
|
||||
Vendored
+130
@@ -0,0 +1,130 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCache_Add(t *testing.T) {
|
||||
t.Run("adds and retrieves item", func(t *testing.T) {
|
||||
c := New(1024)
|
||||
data := []byte("hello")
|
||||
require.NoError(t, c.Add(1, data))
|
||||
|
||||
got, err := c.Get(1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, data, got)
|
||||
})
|
||||
|
||||
t.Run("returns error for oversized item", func(t *testing.T) {
|
||||
c := New(10)
|
||||
err := c.Add(1, make([]byte, 20))
|
||||
assert.ErrorIs(t, err, ErrExceedsMaxSize)
|
||||
})
|
||||
|
||||
t.Run("evicts oldest items to make room", func(t *testing.T) {
|
||||
c := New(100)
|
||||
|
||||
require.NoError(t, c.Add(1, make([]byte, 40)))
|
||||
require.NoError(t, c.Add(2, make([]byte, 40)))
|
||||
// Adding item 3 should evict item 1
|
||||
require.NoError(t, c.Add(3, make([]byte, 40)))
|
||||
|
||||
assert.False(t, c.Has(1))
|
||||
assert.True(t, c.Has(2))
|
||||
assert.True(t, c.Has(3))
|
||||
})
|
||||
|
||||
t.Run("overwrites existing key", func(t *testing.T) {
|
||||
c := New(100)
|
||||
require.NoError(t, c.Add(1, []byte("old")))
|
||||
require.NoError(t, c.Add(1, []byte("new")))
|
||||
|
||||
got, err := c.Get(1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("new"), got)
|
||||
assert.Equal(t, 3, c.Size())
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Get(t *testing.T) {
|
||||
t.Run("returns ErrNotFound for missing key", func(t *testing.T) {
|
||||
c := New(100)
|
||||
_, err := c.Get(99)
|
||||
assert.ErrorIs(t, err, ErrNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Has(t *testing.T) {
|
||||
t.Run("returns true for existing key", func(t *testing.T) {
|
||||
c := New(100)
|
||||
require.NoError(t, c.Add(1, []byte("data")))
|
||||
assert.True(t, c.Has(1))
|
||||
})
|
||||
|
||||
t.Run("returns false for missing key", func(t *testing.T) {
|
||||
c := New(100)
|
||||
assert.False(t, c.Has(1))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Size(t *testing.T) {
|
||||
t.Run("tracks byte usage", func(t *testing.T) {
|
||||
c := New(1000)
|
||||
assert.Equal(t, 0, c.Size())
|
||||
|
||||
require.NoError(t, c.Add(1, make([]byte, 100)))
|
||||
assert.Equal(t, 100, c.Size())
|
||||
|
||||
require.NoError(t, c.Add(2, make([]byte, 200)))
|
||||
assert.Equal(t, 300, c.Size())
|
||||
})
|
||||
|
||||
t.Run("updates on eviction", func(t *testing.T) {
|
||||
c := New(150)
|
||||
require.NoError(t, c.Add(1, make([]byte, 100)))
|
||||
require.NoError(t, c.Add(2, make([]byte, 100)))
|
||||
|
||||
// Item 1 should be evicted, size = 100
|
||||
assert.Equal(t, 100, c.Size())
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Clear(t *testing.T) {
|
||||
t.Run("removes all items and resets size", func(t *testing.T) {
|
||||
c := New(1000)
|
||||
require.NoError(t, c.Add(1, []byte("a")))
|
||||
require.NoError(t, c.Add(2, []byte("b")))
|
||||
|
||||
c.Clear()
|
||||
|
||||
assert.Equal(t, 0, c.Size())
|
||||
assert.False(t, c.Has(1))
|
||||
assert.False(t, c.Has(2))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_Concurrent(t *testing.T) {
|
||||
t.Run("concurrent operations are safe", func(t *testing.T) {
|
||||
c := New(10000)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 100; j++ {
|
||||
key := id*100 + j
|
||||
_ = c.Add(key, []byte("data"))
|
||||
_, _ = c.Get(key)
|
||||
_ = c.Has(key)
|
||||
_ = c.Size()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
+1
-1
@@ -6,7 +6,7 @@ const ProcessStateChangeEventID = 0x01
|
||||
const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
const ActivityLogEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
|
||||
+181
-130
@@ -12,9 +12,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy/cache"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -42,37 +44,53 @@ var zstdDecPool = &sync.Pool{
|
||||
},
|
||||
}
|
||||
|
||||
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd.
|
||||
// Returns compressed bytes and the original JSON byte count for logging.
|
||||
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
|
||||
// Returns compressed bytes and the original CBOR byte count for logging.
|
||||
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||
jsonBytes, err := json.Marshal(c)
|
||||
cborBytes, err := cbor.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||
}
|
||||
enc := zstdEncPool.Get().(*zstd.Encoder)
|
||||
defer zstdEncPool.Put(enc)
|
||||
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil
|
||||
zenc := zstdEncPool.Get().(*zstd.Encoder)
|
||||
defer zstdEncPool.Put(zenc)
|
||||
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
|
||||
}
|
||||
|
||||
// decompressCapture decompresses zstd-compressed JSON and returns it.
|
||||
func decompressCapture(data []byte) ([]byte, error) {
|
||||
// decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture.
|
||||
func decompressCapture(data []byte) (*ReqRespCapture, error) {
|
||||
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||
defer zstdDecPool.Put(dec)
|
||||
return dec.DecodeAll(data, nil)
|
||||
cborBytes, err := dec.DecodeAll(data, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress capture: %w", err)
|
||||
}
|
||||
var capture ReqRespCapture
|
||||
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal capture: %w", err)
|
||||
}
|
||||
return &capture, nil
|
||||
}
|
||||
|
||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||
// TokenMetrics holds token usage and performance metrics
|
||||
type TokenMetrics struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
}
|
||||
|
||||
// ActivityLogEntry represents parsed token statistics from llama-server logs
|
||||
type ActivityLogEntry struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
ReqPath string `json:"req_path"`
|
||||
RespContentType string `json:"resp_content_type"`
|
||||
RespStatusCode int `json:"resp_status_code"`
|
||||
Tokens TokenMetrics `json:"tokens"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
type ReqRespCapture struct {
|
||||
@@ -84,48 +102,45 @@ type ReqRespCapture struct {
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// TokenMetricsEvent represents a token metrics event
|
||||
type TokenMetricsEvent struct {
|
||||
Metrics TokenMetrics
|
||||
// ActivityLogEvent represents a token metrics event
|
||||
type ActivityLogEvent struct {
|
||||
Metrics ActivityLogEntry
|
||||
}
|
||||
|
||||
func (e TokenMetricsEvent) Type() uint32 {
|
||||
return TokenMetricsEventID // defined in events.go
|
||||
func (e ActivityLogEvent) Type() uint32 {
|
||||
return ActivityLogEventID // defined in events.go
|
||||
}
|
||||
|
||||
// metricsMonitor parses llama-server output for token statistics
|
||||
type metricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics []TokenMetrics
|
||||
metrics []ActivityLogEntry
|
||||
maxMetrics int
|
||||
nextID int
|
||||
logger *LogMonitor
|
||||
|
||||
// capture fields
|
||||
enableCaptures bool
|
||||
captures map[int][]byte // zstd-compressed JSON of ReqRespCapture
|
||||
captureOrder []int // track insertion order for FIFO eviction
|
||||
captureSize int // current total compressed size in bytes
|
||||
maxCaptureSize int // max bytes for captures (uncompressed)
|
||||
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
|
||||
}
|
||||
|
||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||
// capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
return &metricsMonitor{
|
||||
mm := &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
captures: make(map[int][]byte),
|
||||
captureOrder: make([]int, 0),
|
||||
captureSize: 0,
|
||||
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
||||
}
|
||||
if captureBufferMB > 0 {
|
||||
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
|
||||
}
|
||||
return mm
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event.
|
||||
// Returns the assigned metric ID.
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
||||
// queueMetrics adds a new metric to the collection without emitting an event.
|
||||
// Returns the assigned metric ID. Call emitMetric after capture setup.
|
||||
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
@@ -135,93 +150,75 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
||||
if len(mp.metrics) > mp.maxMetrics {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// addCapture adds a new capture to the buffer with size-based eviction.
|
||||
// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
||||
// emitMetric publishes an ActivityLogEvent for the given metric.
|
||||
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
|
||||
event.Emit(ActivityLogEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// addCapture compresses and stores a capture in the cache.
|
||||
// Returns true if the capture was stored, false otherwise.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
|
||||
if !mp.enableCaptures {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
captureSize := len(compressed)
|
||||
if captureSize > mp.maxCaptureSize {
|
||||
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
||||
return
|
||||
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
|
||||
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
|
||||
return false
|
||||
}
|
||||
|
||||
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100
|
||||
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Evict oldest (FIFO) until room available for the compressed data
|
||||
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
||||
oldestID := mp.captureOrder[0]
|
||||
mp.captureOrder = mp.captureOrder[1:]
|
||||
if evicted, exists := mp.captures[oldestID]; exists {
|
||||
l := len(evicted)
|
||||
mp.captureSize -= l
|
||||
delete(mp.captures, oldestID)
|
||||
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
|
||||
}
|
||||
}
|
||||
|
||||
mp.captures[capture.ID] = compressed
|
||||
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
||||
mp.captureSize += captureSize
|
||||
|
||||
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
|
||||
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||
return true
|
||||
}
|
||||
|
||||
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
|
||||
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
data, exists := mp.captures[id]
|
||||
return data, exists
|
||||
if mp.captureCache == nil {
|
||||
return nil, false
|
||||
}
|
||||
data, err := mp.captureCache.Get(id)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return data, true
|
||||
}
|
||||
|
||||
// getCaptureByID returns decompressed capture bytes if found and decompress=true.
|
||||
// If decompress=false, returns the raw zstd-compressed bytes.
|
||||
// Returns nil if the capture is not found.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
data, exists := mp.captures[id]
|
||||
// getCaptureByID decompresses and unmarshals a capture by ID.
|
||||
// Returns nil if the capture is not found or decompression fails.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
if mp.captureCache == nil {
|
||||
return nil
|
||||
}
|
||||
data, exists := mp.getCompressedBytes(id)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !decompress {
|
||||
return data
|
||||
}
|
||||
|
||||
decompressed, err := decompressCapture(data)
|
||||
capture, err := decompressCapture(data)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return decompressed
|
||||
return capture
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics
|
||||
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
||||
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := make([]TokenMetrics, len(mp.metrics))
|
||||
result := make([]ActivityLogEntry, len(mp.metrics))
|
||||
copy(result, mp.metrics)
|
||||
return result
|
||||
}
|
||||
@@ -230,22 +227,52 @@ func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
return json.Marshal(mp.metrics)
|
||||
|
||||
if mp.captureCache == nil {
|
||||
return json.Marshal(mp.metrics)
|
||||
}
|
||||
|
||||
// Make a copy with up-to-date has_capture from cache
|
||||
result := make([]ActivityLogEntry, len(mp.metrics))
|
||||
for i, m := range mp.metrics {
|
||||
m.HasCapture = mp.captureCache.Has(m.ID)
|
||||
result[i] = m
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// wrapHandler wraps the proxy handler to extract token metrics
|
||||
// Capture field flags for controlling what is saved in ReqRespCapture.
|
||||
type captureFields uint
|
||||
|
||||
const (
|
||||
captureNone captureFields = 1 << iota
|
||||
captureReqHeaders
|
||||
captureReqBody
|
||||
captureRespHeaders
|
||||
captureRespBody
|
||||
)
|
||||
|
||||
const (
|
||||
captureReqAll = captureReqHeaders | captureReqBody
|
||||
captureRespAll = captureRespHeaders | captureRespBody
|
||||
captureAll = captureReqAll | captureRespAll
|
||||
)
|
||||
|
||||
// wrapHandler wraps the proxy handler to extract token metrics.
|
||||
// captureFields controls what is saved in the ReqRespCapture using bitwise flags.
|
||||
// if wrapHandler returns an error it is safe to assume that no
|
||||
// data was sent to the client
|
||||
func (mp *metricsMonitor) wrapHandler(
|
||||
modelID string,
|
||||
writer gin.ResponseWriter,
|
||||
request *http.Request,
|
||||
captureFields captureFields,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
// Capture request body and headers if captures enabled
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mp.enableCaptures {
|
||||
if mp.enableCaptures && (captureFields&captureReqBody) != 0 {
|
||||
if request.Body != nil {
|
||||
var err error
|
||||
reqBody, err = io.ReadAll(request.Body)
|
||||
@@ -255,6 +282,8 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
request.Body.Close()
|
||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||
}
|
||||
}
|
||||
if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 {
|
||||
reqHeaders = make(map[string]string)
|
||||
for key, values := range request.Header {
|
||||
if len(values) > 0 {
|
||||
@@ -278,22 +307,28 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
// after this point we have to assume that data was sent to the client
|
||||
// and we can only log errors but not send them to clients
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||
return nil
|
||||
// Initialize default metrics - recorded for every request
|
||||
tm := ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
ReqPath: request.URL.Path,
|
||||
RespContentType: recorder.Header().Get("Content-Type"),
|
||||
RespStatusCode: recorder.Status(),
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
// Initialize default metrics - these will always be recorded
|
||||
tm := TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
mp.addMetrics(tm)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -303,7 +338,8 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
mp.addMetrics(tm)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -311,7 +347,8 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsed
|
||||
tm.Tokens = parsed.Tokens
|
||||
tm.DurationMs = parsed.DurationMs
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
@@ -331,7 +368,8 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsedMetrics
|
||||
tm.Tokens = parsedMetrics.Tokens
|
||||
tm.DurationMs = parsedMetrics.DurationMs
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -342,39 +380,50 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
// Build capture if enabled and determine if it will be stored
|
||||
var capture *ReqRespCapture
|
||||
if mp.enableCaptures {
|
||||
respHeaders := make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
var respHeaders map[string]string
|
||||
var respBody []byte
|
||||
if (captureFields & captureRespHeaders) != 0 {
|
||||
respHeaders = make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
}
|
||||
if (captureFields & captureRespBody) != 0 {
|
||||
respBody = body
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
capture = &ReqRespCapture{
|
||||
ReqPath: request.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqBody: reqBody,
|
||||
RespHeaders: respHeaders,
|
||||
RespBody: body,
|
||||
}
|
||||
compressed, _, err := compressCapture(capture)
|
||||
if err == nil && len(compressed) <= mp.maxCaptureSize {
|
||||
tm.HasCapture = true
|
||||
RespBody: respBody,
|
||||
}
|
||||
}
|
||||
|
||||
metricID := mp.addMetrics(tm)
|
||||
metricID := mp.queueMetrics(tm)
|
||||
tm.ID = metricID
|
||||
|
||||
// Store capture if enabled
|
||||
if capture != nil {
|
||||
capture.ID = metricID
|
||||
mp.addCapture(*capture)
|
||||
if mp.addCapture(*capture) {
|
||||
tm.HasCapture = true
|
||||
mp.mu.Lock()
|
||||
mp.metrics[len(mp.metrics)-1].HasCapture = true
|
||||
mp.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
mp.emitMetric(tm)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
|
||||
// Iterate **backwards** through the body looking for the data payload with
|
||||
// usage data. This avoids allocating a slice of all lines via bytes.Split.
|
||||
|
||||
@@ -428,10 +477,10 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
|
||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||
|
||||
// default values
|
||||
@@ -481,15 +530,17 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
CachedTokens: cachedTokens,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
DurationMs: durationMs,
|
||||
return ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: cachedTokens,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
},
|
||||
DurationMs: durationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
+343
-170
@@ -12,8 +12,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -22,27 +24,29 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
metric := ActivityLogEntry{
|
||||
Model: "test-model",
|
||||
Tokens: TokenMetrics{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
},
|
||||
}
|
||||
|
||||
mm.addMetrics(metric)
|
||||
mm.queueMetrics(metric)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, 0, metrics[0].ID)
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("increments ID for each metric", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
mm.addMetrics(TokenMetrics{Model: "model"})
|
||||
mm.queueMetrics(ActivityLogEntry{Model: "model"})
|
||||
}
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
@@ -57,9 +61,11 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
|
||||
// Add 5 metrics
|
||||
for i := 0; i < 5; i++ {
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "model",
|
||||
InputTokens: i,
|
||||
mm.queueMetrics(ActivityLogEntry{
|
||||
Model: "model",
|
||||
Tokens: TokenMetrics{
|
||||
InputTokens: i,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -72,29 +78,32 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
assert.Equal(t, 4, metrics[2].ID)
|
||||
})
|
||||
|
||||
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
||||
t.Run("emits ActivityLogEvent", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
receivedEvent := make(chan TokenMetricsEvent, 1)
|
||||
cancel := event.On(func(e TokenMetricsEvent) {
|
||||
receivedEvent := make(chan ActivityLogEvent, 1)
|
||||
cancel := event.On(func(e ActivityLogEvent) {
|
||||
receivedEvent <- e
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
metric := ActivityLogEntry{
|
||||
Model: "test-model",
|
||||
Tokens: TokenMetrics{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
},
|
||||
}
|
||||
|
||||
mm.addMetrics(metric)
|
||||
mm.queueMetrics(metric)
|
||||
mm.emitMetric(metric)
|
||||
|
||||
select {
|
||||
case evt := <-receivedEvent:
|
||||
assert.Equal(t, 0, evt.Metrics.ID)
|
||||
assert.Equal(t, "test-model", evt.Metrics.Model)
|
||||
assert.Equal(t, 100, evt.Metrics.InputTokens)
|
||||
assert.Equal(t, 50, evt.Metrics.OutputTokens)
|
||||
assert.Equal(t, 100, evt.Metrics.Tokens.InputTokens)
|
||||
assert.Equal(t, 50, evt.Metrics.Tokens.OutputTokens)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("timeout waiting for event")
|
||||
}
|
||||
@@ -111,8 +120,8 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||
|
||||
t.Run("returns copy of metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
mm.addMetrics(TokenMetrics{Model: "model1"})
|
||||
mm.addMetrics(TokenMetrics{Model: "model2"})
|
||||
mm.queueMetrics(ActivityLogEntry{Model: "model1"})
|
||||
mm.queueMetrics(ActivityLogEntry{Model: "model2"})
|
||||
|
||||
metrics1 := mm.getMetrics()
|
||||
metrics2 := mm.getMetrics()
|
||||
@@ -135,7 +144,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jsonData)
|
||||
|
||||
var metrics []TokenMetrics
|
||||
var metrics []ActivityLogEntry
|
||||
err = json.Unmarshal(jsonData, &metrics)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
@@ -143,23 +152,27 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
|
||||
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "model1",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
TokensPerSecond: 25.5,
|
||||
mm.queueMetrics(ActivityLogEntry{
|
||||
Model: "model1",
|
||||
Tokens: TokenMetrics{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
TokensPerSecond: 25.5,
|
||||
},
|
||||
})
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "model2",
|
||||
InputTokens: 200,
|
||||
OutputTokens: 100,
|
||||
TokensPerSecond: 30.0,
|
||||
mm.queueMetrics(ActivityLogEntry{
|
||||
Model: "model2",
|
||||
Tokens: TokenMetrics{
|
||||
InputTokens: 200,
|
||||
OutputTokens: 100,
|
||||
TokensPerSecond: 30.0,
|
||||
},
|
||||
})
|
||||
|
||||
jsonData, err := mm.getMetricsJSON()
|
||||
assert.NoError(t, err)
|
||||
|
||||
var metrics []TokenMetrics
|
||||
var metrics []ActivityLogEntry
|
||||
err = json.Unmarshal(jsonData, &metrics)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(metrics))
|
||||
@@ -190,14 +203,14 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("successful request with timings data", func(t *testing.T) {
|
||||
@@ -226,17 +239,17 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 20, metrics[0].CachedTokens)
|
||||
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
|
||||
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
|
||||
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||
assert.Equal(t, 20, metrics[0].Tokens.CachedTokens)
|
||||
assert.Equal(t, 150.5, metrics[0].Tokens.PromptPerSecond)
|
||||
assert.Equal(t, 25.5, metrics[0].Tokens.TokensPerSecond)
|
||||
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
|
||||
})
|
||||
|
||||
@@ -265,18 +278,18 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
// When timings data is present, it takes precedence
|
||||
assert.Equal(t, 10, metrics[0].InputTokens)
|
||||
assert.Equal(t, 20, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 10, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 20, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
||||
t.Run("non-OK status code records partial metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
@@ -289,11 +302,16 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, "/test", metrics[0].ReqPath)
|
||||
assert.Equal(t, http.StatusBadRequest, metrics[0].RespStatusCode)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("empty response body records minimal metrics", func(t *testing.T) {
|
||||
@@ -308,14 +326,14 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
|
||||
@@ -332,14 +350,14 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||
@@ -354,7 +372,7 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.Equal(t, expectedErr, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
@@ -377,14 +395,14 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("infill request extracts timings from last array element", func(t *testing.T) {
|
||||
@@ -416,17 +434,17 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 150, metrics[0].InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 30, metrics[0].CachedTokens)
|
||||
assert.Equal(t, 200.5, metrics[0].PromptPerSecond)
|
||||
assert.Equal(t, 35.5, metrics[0].TokensPerSecond)
|
||||
assert.Equal(t, 150, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
|
||||
assert.Equal(t, 30, metrics[0].Tokens.CachedTokens)
|
||||
assert.Equal(t, 200.5, metrics[0].Tokens.PromptPerSecond)
|
||||
assert.Equal(t, 35.5, metrics[0].Tokens.TokensPerSecond)
|
||||
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
|
||||
})
|
||||
|
||||
@@ -446,14 +464,14 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -507,7 +525,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
||||
t.Run("concurrent queueMetrics is safe", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
@@ -519,10 +537,12 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < metricsPerGoroutine; j++ {
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "test-model",
|
||||
InputTokens: id*1000 + j,
|
||||
OutputTokens: j,
|
||||
mm.queueMetrics(ActivityLogEntry{
|
||||
Model: "test-model",
|
||||
Tokens: TokenMetrics{
|
||||
InputTokens: id*1000 + j,
|
||||
OutputTokens: j,
|
||||
},
|
||||
})
|
||||
}
|
||||
}(i)
|
||||
@@ -542,7 +562,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
// Writer goroutine
|
||||
go func() {
|
||||
for i := 0; i < 50; i++ {
|
||||
mm.addMetrics(TokenMetrics{Model: "test-model"})
|
||||
mm.queueMetrics(ActivityLogEntry{Model: "test-model"})
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
@@ -586,10 +606,10 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
|
||||
metrics, err := parseMetrics("test-model", start, usage, timings)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, metrics.InputTokens)
|
||||
assert.Equal(t, 1, metrics.OutputTokens)
|
||||
assert.Equal(t, 10.0, metrics.PromptPerSecond)
|
||||
assert.Equal(t, 2.0, metrics.TokensPerSecond)
|
||||
assert.Equal(t, 5, metrics.Tokens.InputTokens)
|
||||
assert.Equal(t, 1, metrics.Tokens.OutputTokens)
|
||||
assert.Equal(t, 10.0, metrics.Tokens.PromptPerSecond)
|
||||
assert.Equal(t, 2.0, metrics.Tokens.TokensPerSecond)
|
||||
assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
|
||||
})
|
||||
|
||||
@@ -623,14 +643,14 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
// Should use timings values, not usage values
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
||||
@@ -658,12 +678,12 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
|
||||
assert.Equal(t, -1, metrics[0].Tokens.CachedTokens) // Default value when not present
|
||||
})
|
||||
}
|
||||
|
||||
@@ -693,13 +713,13 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
|
||||
@@ -722,14 +742,14 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
|
||||
@@ -751,14 +771,14 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 17, metrics[0].InputTokens)
|
||||
assert.Equal(t, 23, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 17, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 23, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
|
||||
@@ -777,14 +797,14 @@ data: [DONE]
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -792,20 +812,22 @@ data: [DONE]
|
||||
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
CachedTokens: 100,
|
||||
InputTokens: 500,
|
||||
OutputTokens: 250,
|
||||
PromptPerSecond: 1200.5,
|
||||
TokensPerSecond: 45.8,
|
||||
DurationMs: 5000,
|
||||
Timestamp: time.Now(),
|
||||
metric := ActivityLogEntry{
|
||||
Model: "test-model",
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: 100,
|
||||
InputTokens: 500,
|
||||
OutputTokens: 250,
|
||||
PromptPerSecond: 1200.5,
|
||||
TokensPerSecond: 45.8,
|
||||
},
|
||||
DurationMs: 5000,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
mm.addMetrics(metric)
|
||||
mm.queueMetrics(metric)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -813,20 +835,22 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||
// Test performance with a smaller buffer where wrapping occurs more frequently
|
||||
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
CachedTokens: 100,
|
||||
InputTokens: 500,
|
||||
OutputTokens: 250,
|
||||
PromptPerSecond: 1200.5,
|
||||
TokensPerSecond: 45.8,
|
||||
DurationMs: 5000,
|
||||
Timestamp: time.Now(),
|
||||
metric := ActivityLogEntry{
|
||||
Model: "test-model",
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: 100,
|
||||
InputTokens: 500,
|
||||
OutputTokens: 250,
|
||||
PromptPerSecond: 1200.5,
|
||||
TokensPerSecond: 45.8,
|
||||
},
|
||||
DurationMs: 5000,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
mm.addMetrics(metric)
|
||||
mm.queueMetrics(metric)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -855,14 +879,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("deflate encoded response", func(t *testing.T) {
|
||||
@@ -889,14 +913,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 200, metrics[0].InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 200, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
|
||||
@@ -917,14 +941,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err) // Should not return error, just log warning
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
|
||||
@@ -944,13 +968,13 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, 300, metrics[0].InputTokens)
|
||||
assert.Equal(t, 100, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 300, metrics[0].Tokens.InputTokens)
|
||||
assert.Equal(t, 100, metrics[0].Tokens.OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -989,7 +1013,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
||||
mm.addCapture(capture)
|
||||
|
||||
// Should not store capture
|
||||
assert.Nil(t, mm.getCaptureByID(0, false))
|
||||
assert.Nil(t, mm.getCaptureByID(0))
|
||||
})
|
||||
|
||||
t.Run("adds capture when enabled", func(t *testing.T) {
|
||||
@@ -1002,22 +1026,18 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
retrieved := mm.getCaptureByID(0, true)
|
||||
assert.NotNil(t, retrieved)
|
||||
|
||||
var decoded ReqRespCapture
|
||||
err := json.Unmarshal(retrieved, &decoded)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, decoded.ID)
|
||||
assert.Equal(t, []byte("test request"), decoded.ReqBody)
|
||||
assert.Equal(t, []byte("test response"), decoded.RespBody)
|
||||
captured := mm.getCaptureByID(0)
|
||||
assert.NotNil(t, captured)
|
||||
assert.Equal(t, 0, captured.ID)
|
||||
assert.Equal(t, []byte("test request"), captured.ReqBody)
|
||||
assert.Equal(t, []byte("test response"), captured.RespBody)
|
||||
})
|
||||
|
||||
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
// Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
|
||||
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
|
||||
mm.maxCaptureSize = 450
|
||||
mm.captureCache = cache.New(450)
|
||||
|
||||
// Use random-looking data that doesn't compress well with zstd
|
||||
rng := rand.New(rand.NewSource(42))
|
||||
@@ -1033,16 +1053,14 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
||||
// Adding capture3 should evict capture1
|
||||
mm.addCapture(capture3)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted")
|
||||
retrieved := mm.getCaptureByID(1, true)
|
||||
assert.NotNil(t, retrieved, "capture 1 should exist")
|
||||
retrieved = mm.getCaptureByID(2, true)
|
||||
assert.NotNil(t, retrieved, "capture 2 should exist")
|
||||
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
|
||||
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
|
||||
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
|
||||
})
|
||||
|
||||
t.Run("skips capture larger than max size", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
mm.maxCaptureSize = 100
|
||||
mm.captureCache = cache.New(100)
|
||||
|
||||
// Use random data that doesn't compress well to create an oversized capture
|
||||
rng := rand.New(rand.NewSource(99))
|
||||
@@ -1050,7 +1068,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
||||
rng.Read(largeCapture.ReqBody)
|
||||
mm.addCapture(largeCapture)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored")
|
||||
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1058,7 +1076,7 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
||||
t.Run("returns nil for non-existent ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(999, false))
|
||||
assert.Nil(t, mm.getCaptureByID(999))
|
||||
})
|
||||
|
||||
t.Run("returns decompressed capture by ID", func(t *testing.T) {
|
||||
@@ -1071,18 +1089,14 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
retrieved := mm.getCaptureByID(42, true)
|
||||
assert.NotNil(t, retrieved)
|
||||
|
||||
var decoded ReqRespCapture
|
||||
err := json.Unmarshal(retrieved, &decoded)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 42, decoded.ID)
|
||||
assert.Equal(t, []byte("test request"), decoded.ReqBody)
|
||||
assert.Equal(t, []byte("test response"), decoded.RespBody)
|
||||
captured := mm.getCaptureByID(42)
|
||||
assert.NotNil(t, captured)
|
||||
assert.Equal(t, 42, captured.ID)
|
||||
assert.Equal(t, []byte("test request"), captured.ReqBody)
|
||||
assert.Equal(t, []byte("test response"), captured.RespBody)
|
||||
})
|
||||
|
||||
t.Run("returns compressed bytes when decompress=false", func(t *testing.T) {
|
||||
t.Run("stores data as compressed bytes", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
@@ -1092,10 +1106,12 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
compressed := mm.getCaptureByID(42, false)
|
||||
compressed, exists := mm.getCompressedBytes(42)
|
||||
assert.True(t, exists)
|
||||
assert.NotNil(t, compressed)
|
||||
// Compressed data should not be valid JSON (it's zstd-compressed)
|
||||
assert.False(t, gjson.ValidBytes(compressed))
|
||||
// Compressed data should not be valid CBOR (it's zstd-compressed)
|
||||
var decoded ReqRespCapture
|
||||
assert.Error(t, cbor.Unmarshal(compressed, &decoded))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1164,7 +1180,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check metric was recorded
|
||||
@@ -1173,12 +1189,8 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
||||
metricID := metrics[0].ID
|
||||
|
||||
// Check capture was stored with same ID (decompressed)
|
||||
captureData := mm.getCaptureByID(metricID, true)
|
||||
assert.NotNil(t, captureData)
|
||||
|
||||
var capture ReqRespCapture
|
||||
err = json.Unmarshal(captureData, &capture)
|
||||
assert.NoError(t, err)
|
||||
capture := mm.getCaptureByID(metricID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Equal(t, metricID, capture.ID)
|
||||
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||
@@ -1206,7 +1218,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Metrics should still be recorded
|
||||
@@ -1214,7 +1226,168 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
|
||||
// But no capture
|
||||
capture := mm.getCaptureByID(metrics[0].ID, false)
|
||||
assert.Nil(t, capture)
|
||||
assert.Nil(t, mm.getCaptureByID(metrics[0].ID))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_WrapHandler_PartialCaptures(t *testing.T) {
|
||||
requestBody := `{"model": "test"}`
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Custom", "header-value")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("only request headers", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer secret")
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||
assert.Nil(t, capture.ReqBody)
|
||||
assert.Nil(t, capture.RespHeaders)
|
||||
assert.Nil(t, capture.RespBody)
|
||||
})
|
||||
|
||||
t.Run("only request body", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqBody, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Nil(t, capture.ReqHeaders)
|
||||
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||
assert.Nil(t, capture.RespHeaders)
|
||||
assert.Nil(t, capture.RespBody)
|
||||
})
|
||||
|
||||
t.Run("only response headers", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespHeaders, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Nil(t, capture.ReqHeaders)
|
||||
assert.Nil(t, capture.ReqBody)
|
||||
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
|
||||
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
|
||||
assert.Nil(t, capture.RespBody)
|
||||
})
|
||||
|
||||
t.Run("only response body", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespBody, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Nil(t, capture.ReqHeaders)
|
||||
assert.Nil(t, capture.ReqBody)
|
||||
assert.Nil(t, capture.RespHeaders)
|
||||
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||
})
|
||||
|
||||
t.Run("captureReqAll", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer secret")
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||
assert.Nil(t, capture.RespHeaders)
|
||||
assert.Nil(t, capture.RespBody)
|
||||
})
|
||||
|
||||
t.Run("captureRespAll", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespAll, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Nil(t, capture.ReqHeaders)
|
||||
assert.Nil(t, capture.ReqBody)
|
||||
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
|
||||
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
|
||||
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||
})
|
||||
|
||||
t.Run("no flags", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureFields(0), nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Nil(t, capture.ReqHeaders)
|
||||
assert.Nil(t, capture.ReqBody)
|
||||
assert.Nil(t, capture.RespHeaders)
|
||||
assert.Nil(t, capture.RespBody)
|
||||
})
|
||||
|
||||
t.Run("mixed flags req headers and resp body", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer secret")
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders|captureRespBody, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||
assert.Nil(t, capture.ReqBody)
|
||||
assert.Nil(t, capture.RespHeaders)
|
||||
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||
})
|
||||
}
|
||||
|
||||
+324
-274
@@ -332,41 +332,77 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
// Protected routes use pm.apiKeyAuth() middleware
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
llmHandler := pm.mkProxyJSONHandler(captureAll)
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
// Support anthropic count_tokens API (Also added in the above PR)
|
||||
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
|
||||
// Support embeddings and reranking
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
|
||||
// llama-server's /reranking endpoint + aliases
|
||||
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
|
||||
// llama-server's /infill endpoint for code infilling
|
||||
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
|
||||
// llama-server's /completion endpoint
|
||||
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST(
|
||||
"/v1/audio/speech",
|
||||
pm.apiKeyAuth(),
|
||||
pm.trackInflight(),
|
||||
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
|
||||
)
|
||||
pm.ginEngine.POST(
|
||||
"/v1/audio/voices",
|
||||
pm.apiKeyAuth(),
|
||||
pm.trackInflight(),
|
||||
pm.mkProxyJSONHandler(captureReqHeaders|captureRespAll),
|
||||
)
|
||||
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.POST(
|
||||
"/v1/audio/transcriptions",
|
||||
pm.apiKeyAuth(),
|
||||
pm.trackInflight(),
|
||||
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders|captureRespBody),
|
||||
)
|
||||
pm.ginEngine.POST(
|
||||
"/v1/images/generations",
|
||||
pm.apiKeyAuth(),
|
||||
pm.trackInflight(),
|
||||
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
|
||||
)
|
||||
|
||||
pm.ginEngine.POST(
|
||||
"/v1/images/edits",
|
||||
pm.apiKeyAuth(),
|
||||
pm.trackInflight(),
|
||||
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders),
|
||||
)
|
||||
|
||||
// sd.cpp /sdapi/v1 endpoints
|
||||
pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/sdapi/v1/txt2img",
|
||||
pm.apiKeyAuth(),
|
||||
pm.trackInflight(),
|
||||
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
|
||||
)
|
||||
pm.ginEngine.POST("/sdapi/v1/img2img",
|
||||
pm.apiKeyAuth(),
|
||||
pm.trackInflight(),
|
||||
pm.mkProxyJSONHandler(captureReqHeaders|captureRespHeaders),
|
||||
)
|
||||
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||
@@ -686,7 +722,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
|
||||
// attempt to record metrics if it is a POST request
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, captureNone, handler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
@@ -700,280 +736,294 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
return
|
||||
}
|
||||
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
// Look for a matching local model first
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
var localHandler func(string, http.ResponseWriter, *http.Request) error
|
||||
if pm.matrix != nil {
|
||||
localHandler = pm.matrix.ProxyRequest
|
||||
} else {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
localHandler = processGroup.ProxyRequest
|
||||
func (pm *ProxyManager) mkProxyJSONHandler(cf captureFields) func(*gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
return
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[modelID].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
// Look for a matching local model first
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
var localHandler func(string, http.ResponseWriter, *http.Request) error
|
||||
if pm.matrix != nil {
|
||||
localHandler = pm.matrix.ProxyRequest
|
||||
} else {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
localHandler = processGroup.ProxyRequest
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[modelID].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// issue #453 set/override parameters in the JSON body
|
||||
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
||||
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
|
||||
for _, key := range setParamsByIDKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = localHandler
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
|
||||
// issue #453 apply filters for peer requests
|
||||
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
||||
|
||||
// Apply stripParams - remove specified parameters from request
|
||||
stripParams := peerFilters.SanitizedStripParams()
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Apply setParams - set/override specified parameters in request
|
||||
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
c.Request.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
// issue #366 extract values that downstream handlers may need
|
||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, cf, nextHandler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mkPostFormHandler creates a POST form handler for inference backends
|
||||
// with a custom captureFields to filter out large binary requests or responses.
|
||||
func (pm *ProxyManager) mkPostFormHandler(cf captureFields) func(*gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
// Parse multipart form
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get model parameter from the form
|
||||
requestedModel := c.Request.FormValue("model")
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
||||
return
|
||||
}
|
||||
|
||||
// Look for a matching local model first, then check peers
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
var useModelName string
|
||||
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
if pm.matrix != nil {
|
||||
nextHandler = pm.matrix.ProxyRequest
|
||||
} else {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
}
|
||||
|
||||
useModelName = pm.config.Models[modelID].UseModelName
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||
|
||||
// Copy all form values
|
||||
for key, values := range c.Request.MultipartForm.Value {
|
||||
for _, value := range values {
|
||||
fieldValue := value
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
} else {
|
||||
fieldValue = requestedModel
|
||||
}
|
||||
}
|
||||
field, err := multipartWriter.CreateFormField(key)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
||||
return
|
||||
}
|
||||
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// issue #453 set/override parameters in the JSON body
|
||||
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
||||
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
|
||||
for _, key := range setParamsByIDKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = localHandler
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
|
||||
// issue #453 apply filters for peer requests
|
||||
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
||||
|
||||
// Apply stripParams - remove specified parameters from request
|
||||
stripParams := peerFilters.SanitizedStripParams()
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Apply setParams - set/override specified parameters in request
|
||||
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
c.Request.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
// issue #366 extract values that downstream handlers may need
|
||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// Parse multipart form
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get model parameter from the form
|
||||
requestedModel := c.Request.FormValue("model")
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
||||
return
|
||||
}
|
||||
|
||||
// Look for a matching local model first, then check peers
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
var useModelName string
|
||||
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
if pm.matrix != nil {
|
||||
nextHandler = pm.matrix.ProxyRequest
|
||||
} else {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
}
|
||||
|
||||
useModelName = pm.config.Models[modelID].UseModelName
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||
|
||||
// Copy all form values
|
||||
for key, values := range c.Request.MultipartForm.Value {
|
||||
for _, value := range values {
|
||||
fieldValue := value
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
} else {
|
||||
fieldValue = requestedModel
|
||||
// Copy all files from the original request
|
||||
for key, fileHeaders := range c.Request.MultipartForm.File {
|
||||
for _, fileHeader := range fileHeaders {
|
||||
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
||||
return
|
||||
}
|
||||
}
|
||||
field, err := multipartWriter.CreateFormField(key)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
||||
return
|
||||
}
|
||||
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy all files from the original request
|
||||
for key, fileHeaders := range c.Request.MultipartForm.File {
|
||||
for _, fileHeader := range fileHeaders {
|
||||
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
||||
return
|
||||
}
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = io.Copy(formFile, file); err != nil {
|
||||
if _, err = io.Copy(formFile, file); err != nil {
|
||||
file.Close()
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
||||
return
|
||||
}
|
||||
file.Close()
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multipart writer to finalize the form
|
||||
if err := multipartWriter.Close(); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new request with the reconstructed form data
|
||||
modifiedReq, err := http.NewRequestWithContext(
|
||||
c.Request.Context(),
|
||||
c.Request.Method,
|
||||
c.Request.URL.String(),
|
||||
&requestBuffer,
|
||||
)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the headers from the original request
|
||||
modifiedReq.Header = c.Request.Header.Clone()
|
||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||
|
||||
// set the content length of the body
|
||||
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
if pm.metricsMonitor != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, modifiedReq, cf, nextHandler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multipart writer to finalize the form
|
||||
if err := multipartWriter.Close(); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new request with the reconstructed form data
|
||||
modifiedReq, err := http.NewRequestWithContext(
|
||||
c.Request.Context(),
|
||||
c.Request.Method,
|
||||
c.Request.URL.String(),
|
||||
&requestBuffer,
|
||||
)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the headers from the original request
|
||||
modifiedReq.Header = c.Request.Header.Clone()
|
||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||
|
||||
// set the content length of the body
|
||||
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
||||
|
||||
+10
-20
@@ -158,7 +158,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
sendMetrics := func(metrics []TokenMetrics) {
|
||||
sendMetrics := func(metrics []ActivityLogEntry) {
|
||||
jsonData, err := json.Marshal(metrics)
|
||||
if err == nil {
|
||||
select {
|
||||
@@ -205,8 +205,8 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
/**
|
||||
* Send Metrics data
|
||||
*/
|
||||
defer event.On(func(e TokenMetricsEvent) {
|
||||
sendMetrics([]TokenMetrics{e.Metrics})
|
||||
defer event.On(func(e ActivityLogEvent) {
|
||||
sendMetrics([]ActivityLogEntry{e.Metrics})
|
||||
})()
|
||||
|
||||
/**
|
||||
@@ -290,26 +290,16 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
data, exists := pm.metricsMonitor.getCompressedBytes(id)
|
||||
if !exists {
|
||||
capture := pm.metricsMonitor.getCaptureByID(id)
|
||||
if capture == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Vary", "Accept-Encoding")
|
||||
|
||||
// ¯\_(ツ)_/¯ quality weights are too fancy for us anyway
|
||||
hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd")
|
||||
|
||||
if hasZstd {
|
||||
c.Header("Content-Encoding", "zstd")
|
||||
c.Data(http.StatusOK, "application/json", data)
|
||||
} else {
|
||||
decompressed, err := decompressCapture(data)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", decompressed)
|
||||
jsonBytes, err := json.Marshal(capture)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal capture"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonBytes)
|
||||
}
|
||||
|
||||
@@ -1721,3 +1721,61 @@ models:
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioTranscriptionCapture(t *testing.T) {
|
||||
cfg := testConfigFromYAML(t, `
|
||||
healthCheckTimeout: 15
|
||||
logLevel: error
|
||||
captureBuffer: 5
|
||||
models:
|
||||
TheExpectedModel:
|
||||
cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel
|
||||
`)
|
||||
|
||||
proxy := New(cfg)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
injectTestHandlers(proxy, nil)
|
||||
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("TheExpectedModel"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("test audio content"))
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer mysecret")
|
||||
req.Header.Set("X-Custom-Req", "req-value")
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Verify capture exists
|
||||
metrics := proxy.metricsMonitor.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.True(t, metrics[0].HasCapture)
|
||||
|
||||
capture := proxy.metricsMonitor.getCaptureByID(metrics[0].ID)
|
||||
assert.NotNil(t, capture)
|
||||
|
||||
// Should capture request headers (sensitive ones redacted)
|
||||
assert.NotEmpty(t, capture.ReqHeaders)
|
||||
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||
assert.Equal(t, "req-value", capture.ReqHeaders["X-Custom-Req"])
|
||||
|
||||
// Should capture response headers
|
||||
assert.NotNil(t, capture.RespHeaders)
|
||||
|
||||
// Should NOT capture request bodies but get response bodies (text
|
||||
assert.Nil(t, capture.ReqBody)
|
||||
assert.NotNil(t, capture.RespBody)
|
||||
}
|
||||
|
||||
Generated
+3
-3
@@ -2788,9 +2788,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/postcss": {
|
||||
"version": "8.5.8",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz",
|
||||
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==",
|
||||
"version": "8.5.12",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz",
|
||||
"integrity": "sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==",
|
||||
"dev": true,
|
||||
"funding": [
|
||||
{
|
||||
|
||||
@@ -9,13 +9,13 @@
|
||||
|
||||
let stats = $derived.by(() => {
|
||||
const totalRequests = $metrics.length;
|
||||
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||
const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.cache_tokens, 0);
|
||||
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.tokens.input_tokens, 0);
|
||||
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.tokens.output_tokens, 0);
|
||||
const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.tokens.cache_tokens, 0);
|
||||
|
||||
const promptPerSecond = $metrics.filter((m) => m.prompt_per_second > 0).map((m) => m.prompt_per_second);
|
||||
const promptPerSecond = $metrics.filter((m) => m.tokens.prompt_per_second > 0).map((m) => m.tokens.prompt_per_second);
|
||||
|
||||
const tokensPerSecond = $metrics.filter((m) => m.tokens_per_second > 0).map((m) => m.tokens_per_second);
|
||||
const tokensPerSecond = $metrics.filter((m) => m.tokens.tokens_per_second > 0).map((m) => m.tokens.tokens_per_second);
|
||||
|
||||
const promptHistogramData =
|
||||
promptPerSecond.length > 0 ? calculateHistogramData(promptPerSecond) : null;
|
||||
|
||||
@@ -12,15 +12,22 @@ export interface Model {
|
||||
aliases?: string[];
|
||||
}
|
||||
|
||||
export interface Metrics {
|
||||
id: number;
|
||||
timestamp: string;
|
||||
model: string;
|
||||
export interface TokenMetrics {
|
||||
cache_tokens: number;
|
||||
input_tokens: number;
|
||||
output_tokens: number;
|
||||
prompt_per_second: number;
|
||||
tokens_per_second: number;
|
||||
}
|
||||
|
||||
export interface ActivityLogEntry {
|
||||
id: number;
|
||||
timestamp: string;
|
||||
model: string;
|
||||
req_path: string;
|
||||
resp_content_type: string;
|
||||
resp_status_code: number;
|
||||
tokens: TokenMetrics;
|
||||
duration_ms: number;
|
||||
has_capture: boolean;
|
||||
}
|
||||
|
||||
@@ -3,8 +3,87 @@
|
||||
import ActivityStats from "../components/ActivityStats.svelte";
|
||||
import Tooltip from "../components/Tooltip.svelte";
|
||||
import CaptureDialog from "../components/CaptureDialog.svelte";
|
||||
import { persistentStore } from "../stores/persistent";
|
||||
import { onMount } from "svelte";
|
||||
import type { ReqRespCapture } from "../lib/types";
|
||||
|
||||
type ColumnKey =
|
||||
| "id"
|
||||
| "time"
|
||||
| "model"
|
||||
| "req_path"
|
||||
| "resp_status_code"
|
||||
| "resp_content_type"
|
||||
| "cached"
|
||||
| "prompt"
|
||||
| "generated"
|
||||
| "prompt_speed"
|
||||
| "gen_speed"
|
||||
| "duration"
|
||||
| "capture";
|
||||
|
||||
interface ColumnDef {
|
||||
key: ColumnKey;
|
||||
label: string;
|
||||
defaultVisible: boolean;
|
||||
}
|
||||
|
||||
const columns: ColumnDef[] = [
|
||||
{ key: "id", label: "ID", defaultVisible: true },
|
||||
{ key: "time", label: "Time", defaultVisible: true },
|
||||
{ key: "model", label: "Model", defaultVisible: true },
|
||||
{ key: "req_path", label: "Path", defaultVisible: false },
|
||||
{ key: "resp_status_code", label: "Status", defaultVisible: false },
|
||||
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
|
||||
{ key: "cached", label: "Cached", defaultVisible: true },
|
||||
{ key: "prompt", label: "Prompt", defaultVisible: true },
|
||||
{ key: "generated", label: "Generated", defaultVisible: true },
|
||||
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
|
||||
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
|
||||
{ key: "duration", label: "Duration", defaultVisible: true },
|
||||
{ key: "capture", label: "Capture", defaultVisible: true },
|
||||
];
|
||||
|
||||
const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key);
|
||||
|
||||
const visibleColumns = persistentStore<ColumnKey[]>(
|
||||
"activity-columns",
|
||||
defaultVisibleKeys
|
||||
);
|
||||
|
||||
let columnsMenuOpen = $state(false);
|
||||
let dropdownContainer: HTMLDivElement | null = null;
|
||||
|
||||
onMount(() => {
|
||||
function handleKeydown(e: KeyboardEvent) {
|
||||
if (e.key === "Escape" && columnsMenuOpen) {
|
||||
columnsMenuOpen = false;
|
||||
}
|
||||
}
|
||||
function handleClick(e: MouseEvent) {
|
||||
if (columnsMenuOpen && dropdownContainer && !dropdownContainer.contains(e.target as Node)) {
|
||||
columnsMenuOpen = false;
|
||||
}
|
||||
}
|
||||
document.addEventListener("keydown", handleKeydown);
|
||||
document.addEventListener("click", handleClick);
|
||||
return () => {
|
||||
document.removeEventListener("keydown", handleKeydown);
|
||||
document.removeEventListener("click", handleClick);
|
||||
};
|
||||
});
|
||||
|
||||
function toggleColumn(key: ColumnKey) {
|
||||
const current = $visibleColumns;
|
||||
if (current.includes(key)) {
|
||||
if (current.length > 1) {
|
||||
visibleColumns.set(current.filter((k) => k !== key));
|
||||
}
|
||||
} else {
|
||||
visibleColumns.set([...current, key]);
|
||||
}
|
||||
}
|
||||
|
||||
function formatSpeed(speed: number): string {
|
||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||
}
|
||||
@@ -67,58 +146,150 @@
|
||||
<ActivityStats />
|
||||
</div>
|
||||
|
||||
<div class="card overflow-auto">
|
||||
<div class="card overflow-auto relative min-h-[30rem]">
|
||||
<div class="flex justify-end px-4" bind:this={dropdownContainer}>
|
||||
<div class="relative">
|
||||
<button
|
||||
class="w-8 h-8 flex items-center justify-center rounded hover:bg-secondary-hover transition-colors"
|
||||
onclick={() => (columnsMenuOpen = !columnsMenuOpen)}
|
||||
title="Select columns"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 6V4m0 2a2 2 0 100 4m0-4a2 2 0 110 4m-6 8a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4m6 6v10m6-2a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4"></path>
|
||||
</svg>
|
||||
</button>
|
||||
{#if columnsMenuOpen}
|
||||
<div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]">
|
||||
<div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10">
|
||||
Columns
|
||||
</div>
|
||||
{#each columns as col (col.key)}
|
||||
<label
|
||||
class="flex items-center gap-2 px-3 py-1.5 text-sm cursor-pointer hover:bg-secondary-hover transition-colors"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={$visibleColumns.includes(col.key)}
|
||||
onchange={() => toggleColumn(col.key)}
|
||||
class="rounded"
|
||||
/>
|
||||
{col.label}
|
||||
</label>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<table class="min-w-full divide-y">
|
||||
<thead class="border-gray-200 dark:border-white/10">
|
||||
<tr class="text-left text-xs uppercase tracking-wider">
|
||||
<th class="px-6 py-3">ID</th>
|
||||
<th class="px-6 py-3">Time</th>
|
||||
<th class="px-6 py-3">Model</th>
|
||||
<th class="px-6 py-3">
|
||||
Cached <Tooltip content="prompt tokens from cache" />
|
||||
</th>
|
||||
<th class="px-6 py-3">
|
||||
Prompt <Tooltip content="new prompt tokens processed" />
|
||||
</th>
|
||||
<th class="px-6 py-3">Generated</th>
|
||||
<th class="px-6 py-3">Prompt Processing</th>
|
||||
<th class="px-6 py-3">Generation Speed</th>
|
||||
<th class="px-6 py-3">Duration</th>
|
||||
<th class="px-6 py-3">Capture</th>
|
||||
{#if $visibleColumns.includes("id")}
|
||||
<th class="px-6 py-3">ID</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("time")}
|
||||
<th class="px-6 py-3">Time</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("model")}
|
||||
<th class="px-6 py-3">Model</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("req_path")}
|
||||
<th class="px-6 py-3">Path</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("resp_status_code")}
|
||||
<th class="px-6 py-3">Status</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("resp_content_type")}
|
||||
<th class="px-6 py-3">Content-Type</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("cached")}
|
||||
<th class="px-6 py-3">
|
||||
Cached <Tooltip content="prompt tokens from cache" />
|
||||
</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("prompt")}
|
||||
<th class="px-6 py-3">
|
||||
Prompt <Tooltip content="new prompt tokens processed" />
|
||||
</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("generated")}
|
||||
<th class="px-6 py-3">Generated</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("prompt_speed")}
|
||||
<th class="px-6 py-3">Prompt Speed</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("gen_speed")}
|
||||
<th class="px-6 py-3">Gen Speed</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("duration")}
|
||||
<th class="px-6 py-3">Duration</th>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("capture")}
|
||||
<th class="px-6 py-3">Capture</th>
|
||||
{/if}
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody class="divide-y">
|
||||
{#if sortedMetrics.length === 0}
|
||||
<tr>
|
||||
<td colspan="10" class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
|
||||
<td colspan={$visibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
|
||||
No activity recorded
|
||||
</td>
|
||||
</tr>
|
||||
{:else}
|
||||
{#each sortedMetrics as metric (metric.id)}
|
||||
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
||||
<td class="px-4 py-4">{metric.id + 1}</td>
|
||||
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
|
||||
<td class="px-6 py-4">{metric.model}</td>
|
||||
<td class="px-6 py-4">{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}</td>
|
||||
<td class="px-6 py-4">{metric.input_tokens.toLocaleString()}</td>
|
||||
<td class="px-6 py-4">{metric.output_tokens.toLocaleString()}</td>
|
||||
<td class="px-6 py-4">{formatSpeed(metric.prompt_per_second)}</td>
|
||||
<td class="px-6 py-4">{formatSpeed(metric.tokens_per_second)}</td>
|
||||
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
|
||||
<td class="px-6 py-4">
|
||||
{#if metric.has_capture}
|
||||
<button
|
||||
onclick={() => viewCapture(metric.id)}
|
||||
disabled={loadingCaptureId === metric.id}
|
||||
class="btn btn--sm"
|
||||
>
|
||||
{loadingCaptureId === metric.id ? "..." : "View"}
|
||||
</button>
|
||||
{:else}
|
||||
<span class="text-txtsecondary">-</span>
|
||||
{/if}
|
||||
</td>
|
||||
{#if $visibleColumns.includes("id")}
|
||||
<td class="px-4 py-4">{metric.id + 1}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("time")}
|
||||
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("model")}
|
||||
<td class="px-6 py-4">{metric.model}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("req_path")}
|
||||
<td class="px-6 py-4">{metric.req_path || "-"}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("resp_status_code")}
|
||||
<td class="px-6 py-4">{metric.resp_status_code || "-"}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("resp_content_type")}
|
||||
<td class="px-6 py-4">{metric.resp_content_type || "-"}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("cached")}
|
||||
<td class="px-6 py-4">{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("prompt")}
|
||||
<td class="px-6 py-4">{metric.tokens.input_tokens.toLocaleString()}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("generated")}
|
||||
<td class="px-6 py-4">{metric.tokens.output_tokens.toLocaleString()}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("prompt_speed")}
|
||||
<td class="px-6 py-4">{formatSpeed(metric.tokens.prompt_per_second)}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("gen_speed")}
|
||||
<td class="px-6 py-4">{formatSpeed(metric.tokens.tokens_per_second)}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("duration")}
|
||||
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
|
||||
{/if}
|
||||
{#if $visibleColumns.includes("capture")}
|
||||
<td class="px-6 py-4">
|
||||
{#if metric.has_capture}
|
||||
<button
|
||||
onclick={() => viewCapture(metric.id)}
|
||||
disabled={loadingCaptureId === metric.id}
|
||||
class="btn btn--sm"
|
||||
>
|
||||
{loadingCaptureId === metric.id ? "..." : "View"}
|
||||
</button>
|
||||
{:else}
|
||||
<span class="text-txtsecondary">-</span>
|
||||
{/if}
|
||||
</td>
|
||||
{/if}
|
||||
</tr>
|
||||
{/each}
|
||||
{/if}
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
const viewModeStore = persistentStore<ViewMode>("logviewer-view-mode", "panels");
|
||||
|
||||
let direction = $derived<"horizontal" | "vertical">(
|
||||
$screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal"
|
||||
$screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal",
|
||||
);
|
||||
</script>
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
class:bg-primary={$viewModeStore === "proxy"}
|
||||
class:text-btn-primary-text={$viewModeStore === "proxy"}
|
||||
>
|
||||
Panel
|
||||
Proxy
|
||||
</button>
|
||||
<button
|
||||
onclick={() => viewModeStore.set("upstream")}
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
import { writable } from "svelte/store";
|
||||
import type { Model, Metrics, VersionInfo, LogData, APIEventEnvelope, ReqRespCapture, InFlightStats } from "../lib/types";
|
||||
import type {
|
||||
Model,
|
||||
ActivityLogEntry,
|
||||
VersionInfo,
|
||||
LogData,
|
||||
APIEventEnvelope,
|
||||
ReqRespCapture,
|
||||
InFlightStats,
|
||||
} from "../lib/types";
|
||||
import { connectionState } from "./theme";
|
||||
|
||||
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||
@@ -8,7 +16,7 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||
export const models = writable<Model[]>([]);
|
||||
export const proxyLogs = writable<string>("");
|
||||
export const upstreamLogs = writable<string>("");
|
||||
export const metrics = writable<Metrics[]>([]);
|
||||
export const metrics = writable<ActivityLogEntry[]>([]);
|
||||
export const inFlightRequests = writable<number>(0);
|
||||
export const versionInfo = writable<VersionInfo>({
|
||||
build_date: "unknown",
|
||||
@@ -62,7 +70,7 @@ export function enableAPIEvents(enabled: boolean): void {
|
||||
const newModels = JSON.parse(message.data) as Model[];
|
||||
// Sort models by name and id
|
||||
newModels.sort((a, b) => {
|
||||
return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric : true} );
|
||||
return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric: true });
|
||||
});
|
||||
models.set(newModels);
|
||||
break;
|
||||
@@ -82,7 +90,7 @@ export function enableAPIEvents(enabled: boolean): void {
|
||||
}
|
||||
|
||||
case "metrics": {
|
||||
const newMetrics = JSON.parse(message.data) as Metrics[];
|
||||
const newMetrics = JSON.parse(message.data) as ActivityLogEntry[];
|
||||
metrics.update((prevMetrics) => [...newMetrics, ...prevMetrics]);
|
||||
break;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user