mirror of
https://github.com/mostlygeek/llama-swap.git
synced 2026-06-09 06:46:34 +02:00
@@ -1,249 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/watcher"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
version string = "0"
|
||||
commit string = "abcd1234"
|
||||
date string = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", "", "listen ip/port")
|
||||
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
keyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||
mainLogger := logmon.New()
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("version: %s (%s), built at %s", version, commit, date)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
conf, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("Error loading config: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(conf.Profiles) > 0 {
|
||||
mainLogger.Warn("Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(conf.LogLevel)) {
|
||||
case "debug":
|
||||
mainLogger.SetLogLevel(logmon.LevelDebug)
|
||||
case "info":
|
||||
mainLogger.SetLogLevel(logmon.LevelInfo)
|
||||
case "warn":
|
||||
mainLogger.SetLogLevel(logmon.LevelWarn)
|
||||
case "error":
|
||||
mainLogger.SetLogLevel(logmon.LevelError)
|
||||
default:
|
||||
mainLogger.SetLogLevel(logmon.LevelInfo)
|
||||
}
|
||||
|
||||
mainLogger.Debugf("PID: %d", os.Getpid())
|
||||
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
// Validate TLS flags.
|
||||
var useTLS = (*certFile != "" && *keyFile != "")
|
||||
if (*certFile != "" && *keyFile == "") ||
|
||||
(*certFile == "" && *keyFile != "") {
|
||||
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default ports.
|
||||
if *listenStr == "" {
|
||||
defaultPort := ":8080"
|
||||
if useTLS {
|
||||
defaultPort = ":8443"
|
||||
}
|
||||
listenStr = &defaultPort
|
||||
}
|
||||
|
||||
var mon *perf.Monitor
|
||||
if !conf.Performance.Disabled {
|
||||
mon, err = perf.New(conf.Performance, mainLogger)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("failed to create monitor: %s", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
mon.Start()
|
||||
} else {
|
||||
mainLogger.Info("performance monitoring is disabled")
|
||||
}
|
||||
|
||||
// Setup channels for server management
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
|
||||
// Context that bounds the lifetime of background watcher goroutines.
|
||||
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create server with initial handlergit
|
||||
srv := &http.Server{
|
||||
Addr: *listenStr,
|
||||
}
|
||||
|
||||
// Support for watching config and reloading when it changes
|
||||
reloading := false
|
||||
var reloadMutex sync.Mutex
|
||||
reloadProxyManager := func() {
|
||||
reloadMutex.Lock()
|
||||
if reloading {
|
||||
reloadMutex.Unlock()
|
||||
return
|
||||
}
|
||||
reloading = true
|
||||
reloadMutex.Unlock()
|
||||
defer func() {
|
||||
reloadMutex.Lock()
|
||||
reloading = false
|
||||
reloadMutex.Unlock()
|
||||
}()
|
||||
|
||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
mainLogger.Info("Reloading Configuration")
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Warnf("Unable to reload configuration: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
mainLogger.Debug("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
if mon != nil {
|
||||
mon.UpdateConfig(conf.Performance)
|
||||
}
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
newPM.SetPerfMonitor(mon)
|
||||
srv.Handler = newPM
|
||||
mainLogger.Debug("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateEnd,
|
||||
})
|
||||
})
|
||||
} else {
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("Unable to load configuration: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
newPM.SetPerfMonitor(mon)
|
||||
srv.Handler = newPM
|
||||
}
|
||||
}
|
||||
|
||||
// load the initial proxy manager
|
||||
reloadProxyManager()
|
||||
|
||||
if *watchConfig {
|
||||
go func() {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
mainLogger.Errorf("watch-config unable to determine absolute path for watching config file: %v", err)
|
||||
return
|
||||
}
|
||||
mainLogger.Info("Watching configuration for changes (poll-based, 2s interval)")
|
||||
(&configwatcher.Watcher{
|
||||
Path: absConfigPath,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: func() {
|
||||
reloadProxyManager()
|
||||
},
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
}
|
||||
|
||||
// Signal handling
|
||||
go func() {
|
||||
for {
|
||||
sig := <-sigChan
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
mainLogger.Debug("Received SIGHUP")
|
||||
reloadProxyManager()
|
||||
case syscall.SIGINT, syscall.SIGTERM:
|
||||
mainLogger.Debugf("Received signal %v, shutting down...", sig)
|
||||
if mon != nil {
|
||||
mon.Stop()
|
||||
}
|
||||
watcherCancel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
pm.Shutdown()
|
||||
} else {
|
||||
mainLogger.Error("srv.Handler is not of type *proxy.ProxyManager")
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
mainLogger.Errorf("Server shutdown: %v", err)
|
||||
}
|
||||
close(exitChan)
|
||||
return
|
||||
default:
|
||||
// do nothing on other signals
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start server
|
||||
go func() {
|
||||
var err error
|
||||
if useTLS {
|
||||
mainLogger.Infof("llama-swap listening with TLS on https://%s", *listenStr)
|
||||
err = srv.ListenAndServeTLS(*certFile, *keyFile)
|
||||
} else {
|
||||
mainLogger.Infof("llama-swap listening on http://%s", *listenStr)
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
mainLogger.Errorf("Fatal server error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for exit signal
|
||||
<-exitChan
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
ui_dist/*
|
||||
@@ -1,27 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||
type DiscardWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
}
|
||||
|
||||
// Satisfy the http.Flusher interface for streaming responses
|
||||
func (w *DiscardWriter) Flush() {}
|
||||
@@ -1,60 +0,0 @@
|
||||
package proxy
|
||||
|
||||
// package level registry of the different event types
|
||||
|
||||
const ProcessStateChangeEventID = 0x01
|
||||
const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const ActivityLogEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
NewState ProcessState
|
||||
OldState ProcessState
|
||||
}
|
||||
|
||||
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||
return ProcessStateChangeEventID
|
||||
}
|
||||
|
||||
type ChatCompletionStats struct {
|
||||
TokensGenerated int
|
||||
}
|
||||
|
||||
func (e ChatCompletionStats) Type() uint32 {
|
||||
return ChatCompletionStatsEventID
|
||||
}
|
||||
|
||||
type ReloadingState int
|
||||
|
||||
const (
|
||||
ReloadingStateStart ReloadingState = iota
|
||||
ReloadingStateEnd
|
||||
)
|
||||
|
||||
type ConfigFileChangedEvent struct {
|
||||
ReloadingState ReloadingState
|
||||
}
|
||||
|
||||
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||
return ConfigFileChangedEventID
|
||||
}
|
||||
|
||||
type ModelPreloadedEvent struct {
|
||||
ModelName string
|
||||
Success bool
|
||||
}
|
||||
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
type InFlightRequestsEvent struct {
|
||||
Total int
|
||||
}
|
||||
|
||||
func (e InFlightRequestsEvent) Type() uint32 {
|
||||
return InFlightRequestsEventID
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = logmon.NewWriter(os.Stdout)
|
||||
simpleResponderPath = getSimpleResponderPath()
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
func TestMain(m *testing.M) {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
|
||||
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
switch os.Getenv("LOG_LEVEL") {
|
||||
case "debug":
|
||||
testLogger.SetLogLevel(logmon.LevelDebug)
|
||||
case "warn":
|
||||
testLogger.SetLogLevel(logmon.LevelWarn)
|
||||
case "info":
|
||||
testLogger.SetLogLevel(logmon.LevelInfo)
|
||||
default:
|
||||
testLogger.SetLogLevel(logmon.LevelWarn)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
// Helper function to get the binary path
|
||||
func getSimpleResponderPath() string {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
|
||||
if goos == "windows" {
|
||||
return filepath.Join("..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
}
|
||||
|
||||
func getTestPort() int {
|
||||
portMutex.Lock()
|
||||
defer portMutex.Unlock()
|
||||
|
||||
port := nextTestPort
|
||||
nextTestPort++
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and
|
||||
// loads through the real config pipeline (env vars, macros, port assignment, etc.)
|
||||
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
|
||||
t.Helper()
|
||||
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
|
||||
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
|
||||
require.NoError(t, err)
|
||||
return cfg
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Convert path to forward slashes for cross-platform compatibility
|
||||
// Windows handles forward slashes in paths correctly
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, cmdPath, port, expectedMessage, port)
|
||||
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// injectTestHandlers sets a testHandler on every Process in every ProcessGroup
|
||||
// of the given ProxyManager, bypassing subprocess launches. modelResponses maps
|
||||
// model IDs to their respond strings; if a model ID is not in the map, the model
|
||||
// ID itself is used.
|
||||
func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
|
||||
for _, pg := range pm.processGroups {
|
||||
for modelID, process := range pg.processes {
|
||||
respond := modelID
|
||||
if r, ok := modelResponses[modelID]; ok {
|
||||
respond = r
|
||||
}
|
||||
process.testHandler = newTestHandler(respond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newTestHandler returns an http.Handler that mimics simple-responder's API.
|
||||
// It supports the endpoints that routing tests depend on, without launching
|
||||
// any subprocess or binding any port.
|
||||
func respondJSON(w http.ResponseWriter, respond string, bodyBytes []byte) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"h_content_length": strconv.Itoa(len(bodyBytes)),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func newTestHandler(respond string) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
isStreaming := r.URL.Query().Get("stream") == "true"
|
||||
|
||||
if wait := r.URL.Query().Get("wait"); wait != "" {
|
||||
if d, err := time.ParseDuration(wait); err == nil {
|
||||
time.Sleep(d)
|
||||
}
|
||||
}
|
||||
|
||||
if isStreaming {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
flusher := w.(http.Flusher)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
data, _ := json.Marshal(map[string]any{
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []map[string]any{
|
||||
{"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil},
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
finalData, _ := json.Marshal(map[string]any{
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData)
|
||||
flusher.Flush()
|
||||
|
||||
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
} else {
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != respond {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)})
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "ok"})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
})
|
||||
|
||||
for _, path := range []string{
|
||||
"/chat/completions", "/completions",
|
||||
"/responses", "/messages", "/messages/count_tokens",
|
||||
"/embeddings", "/rerank", "/reranking",
|
||||
} {
|
||||
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
})
|
||||
}
|
||||
|
||||
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||
return
|
||||
}
|
||||
model := r.FormValue("model")
|
||||
if model == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"})
|
||||
return
|
||||
}
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||
return
|
||||
}
|
||||
fileBytes, _ := io.ReadAll(file)
|
||||
file.Close()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)),
|
||||
"model": model,
|
||||
"h_content_type": r.Header.Get("Content-Type"),
|
||||
"h_content_length": r.Header.Get("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) {
|
||||
model := r.URL.Query().Get("model")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"voices": []string{"voice1"}, "model": model,
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprint(w, respond)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"model": modelName, "images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"model": modelName, "images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"loras": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
return mux
|
||||
}
|
||||
-330
@@ -1,330 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
// MatrixSolver contains pure swap-decision logic with no Process dependencies.
|
||||
// It is safe for concurrent reads after construction.
|
||||
type MatrixSolver struct {
|
||||
expandedSets []config.ExpandedSet // all valid model combinations
|
||||
evictCosts map[string]int // real model name -> eviction cost (default 1)
|
||||
modelToSets map[string][]int // model name -> indices into expandedSets
|
||||
}
|
||||
|
||||
// NewMatrixSolver builds a solver from expanded sets and eviction costs.
|
||||
func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
|
||||
modelToSets := make(map[string][]int)
|
||||
for i, es := range expandedSets {
|
||||
for _, model := range es.Models {
|
||||
modelToSets[model] = append(modelToSets[model], i)
|
||||
}
|
||||
}
|
||||
|
||||
return &MatrixSolver{
|
||||
expandedSets: expandedSets,
|
||||
evictCosts: evictCosts,
|
||||
modelToSets: modelToSets,
|
||||
}
|
||||
}
|
||||
|
||||
// SolveResult describes what the solver decided.
|
||||
type SolveResult struct {
|
||||
Evict []string // running models that must be stopped
|
||||
TargetSet []string // the chosen set of models (for informational purposes)
|
||||
SetName string // name of the chosen set
|
||||
DSL string // original DSL expression for the chosen set
|
||||
TotalCost int // total eviction cost
|
||||
}
|
||||
|
||||
// Solve determines which models to evict when a model is requested.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. If requestedModel is already running, no eviction needed.
|
||||
// 2. Find all sets containing requestedModel.
|
||||
// 3. If no sets found, the model runs alone; evict all running models.
|
||||
// 4. For each candidate set, compute cost = sum of evict_costs for running
|
||||
// models NOT in that set.
|
||||
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
|
||||
// 6. Return models to evict and the chosen set.
|
||||
func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
|
||||
// If already running, nothing to do (but fill in set info for logging)
|
||||
if slices.Contains(runningModels, requestedModel) {
|
||||
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
|
||||
return SolveResult{
|
||||
TargetSet: runningModels,
|
||||
SetName: setName,
|
||||
DSL: dsl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
candidateIndices := s.modelToSets[requestedModel]
|
||||
|
||||
// Model not in any set: runs alone, evict everything
|
||||
if len(candidateIndices) == 0 {
|
||||
evict := make([]string, len(runningModels))
|
||||
copy(evict, runningModels)
|
||||
return SolveResult{
|
||||
Evict: evict,
|
||||
TargetSet: []string{requestedModel},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Find the cheapest candidate set
|
||||
bestCost := -1
|
||||
bestIdx := -1
|
||||
|
||||
for _, idx := range candidateIndices {
|
||||
setModels := s.expandedSets[idx].Models
|
||||
cost := 0
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(setModels, running) {
|
||||
cost += s.evictCost(running)
|
||||
}
|
||||
}
|
||||
|
||||
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
|
||||
bestCost = cost
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
// Determine which running models to evict
|
||||
chosen := s.expandedSets[bestIdx]
|
||||
var evict []string
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(chosen.Models, running) {
|
||||
evict = append(evict, running)
|
||||
}
|
||||
}
|
||||
|
||||
return SolveResult{
|
||||
Evict: evict,
|
||||
TargetSet: chosen.Models,
|
||||
SetName: chosen.SetName,
|
||||
DSL: chosen.DSL,
|
||||
TotalCost: bestCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// findMatchingSet finds the expanded set that contains all running models.
|
||||
// Returns the set name and DSL, or empty strings if no match.
|
||||
func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
|
||||
for _, idx := range s.modelToSets[requestedModel] {
|
||||
set := s.expandedSets[idx]
|
||||
allInSet := true
|
||||
for _, m := range runningModels {
|
||||
if !slices.Contains(set.Models, m) {
|
||||
allInSet = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allInSet {
|
||||
return set.SetName, set.DSL
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func (s *MatrixSolver) evictCost(model string) int {
|
||||
if cost, ok := s.evictCosts[model]; ok {
|
||||
return cost
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// Matrix manages processes using solver-based swap logic.
|
||||
type Matrix struct {
|
||||
sync.Mutex
|
||||
solver *MatrixSolver
|
||||
processes map[string]*Process // all processes keyed by real model name
|
||||
config config.Config
|
||||
proxyLogger *logmon.Monitor
|
||||
upstreamLogger *logmon.Monitor
|
||||
|
||||
// inflight tracks ProxyRequest calls that have released m.Lock but may
|
||||
// not yet have incremented Process.inFlightRequests. A concurrent
|
||||
// request that needs to evict models waits for inflight to drain under
|
||||
// m.Lock before stopping anything. Without this, a request that
|
||||
// released m.Lock but has not yet reached Process.inFlightRequests.Add(1)
|
||||
// races with Stop()'s Wait() and can be killed mid-request.
|
||||
inflight sync.WaitGroup
|
||||
|
||||
// testDelayFastPath is a test-only hook invoked in the no-eviction path
|
||||
// after m.Lock is released but before the request is dispatched to
|
||||
// Process.ProxyRequest. Tests use it to park a request at the exact
|
||||
// race window to deterministically reproduce the race.
|
||||
testDelayFastPath func()
|
||||
}
|
||||
|
||||
// NewMatrix creates a Matrix from config. It creates a Process for every
|
||||
// model defined in the config (any model can run alone even if not in a set).
|
||||
func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *logmon.Monitor) *Matrix {
|
||||
processes := make(map[string]*Process)
|
||||
for modelID, modelConfig := range cfg.Models {
|
||||
processLogger := logmon.NewWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
|
||||
processes[modelID] = process
|
||||
}
|
||||
|
||||
evictCosts := cfg.Matrix.ResolvedEvictCosts()
|
||||
|
||||
return &Matrix{
|
||||
solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
|
||||
processes: processes,
|
||||
config: cfg,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyRequest handles the swap logic and proxies the request to the model.
|
||||
func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
process, ok := m.processes[modelID]
|
||||
if !ok {
|
||||
return fmt.Errorf("model %s not found in matrix", modelID)
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
running := m.runningModels()
|
||||
result, err := m.solver.Solve(modelID, running)
|
||||
if err != nil {
|
||||
m.Unlock()
|
||||
return fmt.Errorf("matrix solver error: %w", err)
|
||||
}
|
||||
|
||||
// Log solver decision
|
||||
if len(result.Evict) > 0 {
|
||||
m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
|
||||
modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
|
||||
} else if len(running) == 0 {
|
||||
m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID)
|
||||
} else {
|
||||
m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
|
||||
}
|
||||
|
||||
// Evict models that need to be stopped
|
||||
if len(result.Evict) > 0 {
|
||||
// Wait for any in-flight ProxyRequest calls to register on their
|
||||
// Process before stopping anything. Without this, a request that
|
||||
// released m.Lock but has not yet incremented
|
||||
// Process.inFlightRequests races with Stop() and can be killed
|
||||
// mid-request.
|
||||
m.inflight.Wait()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, evictModel := range result.Evict {
|
||||
if p, exists := m.processes[evictModel]; exists {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
p.Stop()
|
||||
}(p)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Register this request in inflight before releasing m.Lock so a
|
||||
// concurrent eviction will wait for it to complete.
|
||||
m.inflight.Add(1)
|
||||
defer m.inflight.Done()
|
||||
isFastPath := len(result.Evict) == 0
|
||||
m.Unlock()
|
||||
|
||||
if isFastPath && m.testDelayFastPath != nil {
|
||||
m.testDelayFastPath()
|
||||
}
|
||||
|
||||
// Proxy the request (Process handles on-demand start)
|
||||
process.ProxyRequest(w, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProcesses stops all running processes.
|
||||
func (m *Matrix) StopProcesses(strategy StopStrategy) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range m.processes {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
p.StopImmediately()
|
||||
default:
|
||||
p.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// StopProcess stops a single process by model ID.
|
||||
func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
process, ok := m.processes[modelID]
|
||||
if !ok {
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down all processes.
|
||||
func (m *Matrix) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range m.processes {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
p.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// RunningModels returns model names currently in an active (non-stopped) state.
|
||||
func (m *Matrix) RunningModels() []string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.runningModels()
|
||||
}
|
||||
|
||||
// runningModels returns running model names (caller must hold lock).
|
||||
func (m *Matrix) runningModels() []string {
|
||||
var running []string
|
||||
for id, process := range m.processes {
|
||||
if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown {
|
||||
running = append(running, id)
|
||||
}
|
||||
}
|
||||
sort.Strings(running)
|
||||
return running
|
||||
}
|
||||
|
||||
// GetProcess returns the Process for a model.
|
||||
func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
|
||||
p, ok := m.processes[modelID]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// HasModel returns true if the model is managed by this matrix.
|
||||
func (m *Matrix) HasModel(modelID string) bool {
|
||||
_, ok := m.processes[modelID]
|
||||
return ok
|
||||
}
|
||||
@@ -1,349 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper to build expanded sets for solver tests
|
||||
func makeExpandedSets(sets ...struct {
|
||||
name string
|
||||
models []string
|
||||
}) []config.ExpandedSet {
|
||||
var result []config.ExpandedSet
|
||||
for _, s := range sets {
|
||||
result = append(result, config.ExpandedSet{
|
||||
SetName: s.name,
|
||||
Models: s.models,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func es(name string, models ...string) struct {
|
||||
name string
|
||||
models []string
|
||||
} {
|
||||
return struct {
|
||||
name string
|
||||
models []string
|
||||
}{name, models}
|
||||
}
|
||||
|
||||
func TestMatrixSolver_AlreadyRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("a", []string{"a"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a"}, result.TargetSet)
|
||||
assert.Equal(t, "s1", result.SetName)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Model "c" not in any set
|
||||
result, err := solver.Solve("c", []string{"a", "b"})
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
|
||||
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("c", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
|
||||
// Set: [a, b]. Request a when b and c are running.
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("a", []string{"b", "c"})
|
||||
require.NoError(t, err)
|
||||
// c is not in the set, so it gets evicted. b is in the set, so it stays.
|
||||
assert.Equal(t, []string{"c"}, result.Evict)
|
||||
assert.Equal(t, []string{"a", "b"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_PicksLowestCost(t *testing.T) {
|
||||
// Two sets containing model "a":
|
||||
// s1: [a, v] — if v is running, cost=0; if L is running, cost=30
|
||||
// s2: [a, L] — if L is running, cost=0; if v is running, cost=50
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "a", "v"),
|
||||
es("s2", "a", "L"),
|
||||
),
|
||||
map[string]int{"v": 50, "L": 30},
|
||||
)
|
||||
|
||||
// v is running. Switching to a:
|
||||
// s1 cost: v is in s1, so 0
|
||||
// s2 cost: v is NOT in s2, so 50
|
||||
// => pick s1
|
||||
result, err := solver.Solve("a", []string{"v"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "v"}, result.TargetSet)
|
||||
|
||||
// L is running. Switching to a:
|
||||
// s1 cost: L is NOT in s1, so 30
|
||||
// s2 cost: L is in s2, so 0
|
||||
// => pick s2
|
||||
result, err = solver.Solve("a", []string{"L"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "L"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
|
||||
// Two sets with identical cost. Definition order should win.
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "a", "x"),
|
||||
es("s2", "a", "y"),
|
||||
),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Nothing running, both sets cost 0. s1 is first.
|
||||
result, err := solver.Solve("a", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "x"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
|
||||
// Model "v" costs 50 to evict, "m" costs 1 (default).
|
||||
// Sets: [g,v], [g,m]
|
||||
// Running: v, m. Request g.
|
||||
// s1=[g,v]: evict m (cost 1), keep v
|
||||
// s2=[g,m]: evict v (cost 50), keep m
|
||||
// => pick s1
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "g", "v"),
|
||||
es("s2", "g", "m"),
|
||||
),
|
||||
map[string]int{"v": 50},
|
||||
)
|
||||
|
||||
result, err := solver.Solve("g", []string{"v", "m"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"m"}, result.Evict)
|
||||
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NothingRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "g", "v"),
|
||||
es("s2", "q", "v"),
|
||||
),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("g", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||
}
|
||||
|
||||
// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction
|
||||
// cannot stop a process while an in-flight ProxyRequest for that process is
|
||||
// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without
|
||||
// matrix-level inflight tracking, the eviction's Stop() races with the
|
||||
// pending request and kills it mid-start.
|
||||
func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
ExpandedSets: []config.ExpandedSet{
|
||||
{SetName: "s1", Models: []string{"model1"}},
|
||||
{SetName: "s2", Models: []string{"model2"}},
|
||||
},
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
|
||||
m := NewMatrix(cfg, testLogger, testLogger)
|
||||
defer m.StopProcesses(StopImmediately)
|
||||
|
||||
// Bypass real subprocesses so the test is fast and deterministic.
|
||||
m.processes["model1"].testHandler = newTestHandler("model1")
|
||||
m.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: run a request through model1 so it reaches StateReady and
|
||||
// subsequent requests take the no-eviction path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, m.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, m.processes["model1"].CurrentState())
|
||||
require.Equal(t, StateStopped, m.processes["model2"].CurrentState())
|
||||
|
||||
// Install fast-path hook that signals arrival and waits for release.
|
||||
// This parks R2 at the race window — after m.Lock is released but
|
||||
// before Process.inFlightRequests.Add(1).
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
m.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
// R2: no-eviction request for model1. Will pause at the hook.
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, m.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
// Deterministically wait for R2 to reach the race window.
|
||||
<-r2Reached
|
||||
|
||||
// R3: request for model2 which requires evicting model1. Must wait for
|
||||
// R2 to finish before touching model1.
|
||||
r3Done := make(chan struct{})
|
||||
w3 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, m.ProxyRequest("model2", w3, req))
|
||||
}()
|
||||
|
||||
// Spin until R3 has acquired m.Lock and entered the eviction path. In
|
||||
// the fixed code, R3 then blocks on m.inflight.Wait() while still
|
||||
// holding the lock, so TryLock keeps failing.
|
||||
for m.TryLock() {
|
||||
m.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||
// state. In the fixed code R3 is blocked and nothing changes; in the
|
||||
// buggy code R3 will Stop() model1 and start model2 within microseconds.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if m.processes["model1"].CurrentState() != StateReady ||
|
||||
m.processes["model2"].CurrentState() != StateStopped {
|
||||
break
|
||||
}
|
||||
done := false
|
||||
select {
|
||||
case <-r3Done:
|
||||
done = true
|
||||
default:
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("eviction completed while in-flight request was still pending — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, m.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while an in-flight request is pending")
|
||||
assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(),
|
||||
"model2 must not be started until R2 finishes and model1 is evicted")
|
||||
|
||||
// Release R2 and let both requests finish.
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
assert.Equal(t, http.StatusOK, w3.Code)
|
||||
assert.Contains(t, w3.Body.String(), "model2")
|
||||
}
|
||||
|
||||
func TestMatrixSolver_FullScenario(t *testing.T) {
|
||||
// Simulates the example config:
|
||||
// standard: [g,v], [q,v], [m,v]
|
||||
// with_rerank: [g,v,e], [q,v,e]
|
||||
// creative: [g,sd], [q,sd]
|
||||
// full: [L]
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("standard", "g", "v"),
|
||||
es("standard", "q", "v"),
|
||||
es("standard", "m", "v"),
|
||||
es("with_rerank", "e", "g", "v"),
|
||||
es("with_rerank", "e", "q", "v"),
|
||||
es("creative", "g", "sd"),
|
||||
es("creative", "q", "sd"),
|
||||
es("full", "L"),
|
||||
),
|
||||
map[string]int{"v": 50, "L": 30, "whisper": 10},
|
||||
)
|
||||
|
||||
// Running: g, v. Request q.
|
||||
// standard[q,v]: evict g (cost 1), keep v. Total: 1.
|
||||
// with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
|
||||
// => tie, pick first by definition order = standard[q,v]
|
||||
result, err := solver.Solve("q", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"g"}, result.Evict)
|
||||
assert.Equal(t, []string{"q", "v"}, result.TargetSet)
|
||||
|
||||
// Running: g, v. Request L.
|
||||
// full[L]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||
// Only one set contains L, so pick it.
|
||||
result, err = solver.Solve("L", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
|
||||
assert.Equal(t, []string{"L"}, result.TargetSet)
|
||||
|
||||
// Running: g, v. Request sd.
|
||||
// creative[g,sd]: evict v (cost 50). Total: 50.
|
||||
// creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||
// => pick creative[g,sd]
|
||||
result, err = solver.Solve("sd", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"v"}, result.Evict)
|
||||
assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
|
||||
|
||||
// Running: q, v, e. Request g.
|
||||
// standard[g,v]: evict q (1) + e (1). Total: 2.
|
||||
// with_rerank[g,v,e]: evict q (1). Total: 1.
|
||||
// creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
|
||||
// => pick with_rerank[g,v,e]
|
||||
result, err = solver.Solve("g", []string{"e", "q", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"q"}, result.Evict)
|
||||
assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
|
||||
}
|
||||
@@ -1,689 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/mostlygeek/llama-swap/internal/cache"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/ring"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// zstdEncOptions are the shared zstd encoder options for maximum compression.
|
||||
var zstdEncOptions = []zstd.EOption{
|
||||
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
|
||||
}
|
||||
|
||||
// zstdDecOptions are the shared zstd decoder options.
|
||||
var zstdDecOptions = []zstd.DOption{}
|
||||
|
||||
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
|
||||
var zstdEncPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
|
||||
return enc
|
||||
},
|
||||
}
|
||||
|
||||
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
|
||||
var zstdDecPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
|
||||
return dec
|
||||
},
|
||||
}
|
||||
|
||||
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
|
||||
// Returns compressed bytes and the original CBOR byte count for logging.
|
||||
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||
cborBytes, err := cbor.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||
}
|
||||
zenc := zstdEncPool.Get().(*zstd.Encoder)
|
||||
defer zstdEncPool.Put(zenc)
|
||||
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
|
||||
}
|
||||
|
||||
// decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture.
|
||||
func decompressCapture(data []byte) (*ReqRespCapture, error) {
|
||||
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||
defer zstdDecPool.Put(dec)
|
||||
cborBytes, err := dec.DecodeAll(data, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress capture: %w", err)
|
||||
}
|
||||
var capture ReqRespCapture
|
||||
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal capture: %w", err)
|
||||
}
|
||||
return &capture, nil
|
||||
}
|
||||
|
||||
// TokenMetrics holds token usage and performance metrics
|
||||
type TokenMetrics struct {
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
}
|
||||
|
||||
// ActivityLogEntry represents parsed token statistics from llama-server logs
|
||||
type ActivityLogEntry struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
ReqPath string `json:"req_path"`
|
||||
RespContentType string `json:"resp_content_type"`
|
||||
RespStatusCode int `json:"resp_status_code"`
|
||||
Tokens TokenMetrics `json:"tokens"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// ActivityLogEvent represents a token metrics event
|
||||
type ActivityLogEvent struct {
|
||||
Metrics ActivityLogEntry
|
||||
}
|
||||
|
||||
func (e ActivityLogEvent) Type() uint32 {
|
||||
return ActivityLogEventID // defined in events.go
|
||||
}
|
||||
|
||||
// metricsMonitor parses llama-server output for token statistics
|
||||
type metricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics ring.Buffer[ActivityLogEntry]
|
||||
nextID int
|
||||
logger *logmon.Monitor
|
||||
|
||||
// capture fields
|
||||
enableCaptures bool
|
||||
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
|
||||
}
|
||||
|
||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||
// capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *logmon.Monitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
mm := &metricsMonitor{
|
||||
logger: logger,
|
||||
metrics: ring.NewBuffer[ActivityLogEntry](maxMetrics),
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
}
|
||||
if captureBufferMB > 0 {
|
||||
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
|
||||
}
|
||||
return mm
|
||||
}
|
||||
|
||||
// queueMetrics adds a new metric to the collection without emitting an event.
|
||||
// Returns the assigned metric ID. Call emitMetric after capture setup.
|
||||
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics.Push(metric)
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// emitMetric publishes an ActivityLogEvent for the given metric.
|
||||
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
|
||||
event.Emit(ActivityLogEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// addCapture compresses and stores a capture in the cache.
|
||||
// Returns true if the capture was stored, false otherwise.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
|
||||
if !mp.enableCaptures {
|
||||
return false
|
||||
}
|
||||
|
||||
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
|
||||
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
|
||||
return false
|
||||
}
|
||||
|
||||
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
|
||||
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||
return true
|
||||
}
|
||||
|
||||
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
|
||||
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
|
||||
if mp.captureCache == nil {
|
||||
return nil, false
|
||||
}
|
||||
data, err := mp.captureCache.Get(id)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return data, true
|
||||
}
|
||||
|
||||
// getCaptureByID decompresses and unmarshals a capture by ID.
|
||||
// Returns nil if the capture is not found or decompression fails.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
if mp.captureCache == nil {
|
||||
return nil
|
||||
}
|
||||
data, exists := mp.getCompressedBytes(id)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
capture, err := decompressCapture(data)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return capture
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics with HasCapture resolved from cache.
|
||||
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := mp.metrics.Slice()
|
||||
if result == nil {
|
||||
return []ActivityLogEntry{}
|
||||
}
|
||||
if mp.captureCache != nil {
|
||||
for i := range result {
|
||||
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getMetricsJSON returns metrics as JSON with HasCapture resolved from cache.
|
||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := mp.metrics.Slice()
|
||||
if result == nil {
|
||||
return json.Marshal([]ActivityLogEntry{})
|
||||
}
|
||||
if mp.captureCache != nil {
|
||||
for i := range result {
|
||||
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
|
||||
}
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// Capture field flags for controlling what is saved in ReqRespCapture.
|
||||
type captureFields uint
|
||||
|
||||
const (
|
||||
captureNone captureFields = 1 << iota
|
||||
captureReqHeaders
|
||||
captureReqBody
|
||||
captureRespHeaders
|
||||
captureRespBody
|
||||
)
|
||||
|
||||
const (
|
||||
captureReqAll = captureReqHeaders | captureReqBody
|
||||
captureRespAll = captureRespHeaders | captureRespBody
|
||||
captureAll = captureReqAll | captureRespAll
|
||||
)
|
||||
|
||||
// wrapHandler wraps the proxy handler to extract token metrics.
|
||||
// captureFields controls what is saved in the ReqRespCapture using bitwise flags.
|
||||
// if wrapHandler returns an error it is safe to assume that no
|
||||
// data was sent to the client
|
||||
func (mp *metricsMonitor) wrapHandler(
|
||||
modelID string,
|
||||
writer gin.ResponseWriter,
|
||||
request *http.Request,
|
||||
captureFields captureFields,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
// Capture request body and headers if captures enabled
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mp.enableCaptures && (captureFields&captureReqBody) != 0 {
|
||||
if request.Body != nil {
|
||||
var err error
|
||||
reqBody, err = io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||
}
|
||||
request.Body.Close()
|
||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||
}
|
||||
}
|
||||
if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 {
|
||||
reqHeaders = make(map[string]string)
|
||||
for key, values := range request.Header {
|
||||
if len(values) > 0 {
|
||||
reqHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// after this point we have to assume that data was sent to the client
|
||||
// and we can only log errors but not send them to clients
|
||||
|
||||
// Initialize default metrics - recorded for every request
|
||||
tm := ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
ReqPath: request.URL.Path,
|
||||
RespContentType: recorder.Header().Get("Content-Type"),
|
||||
RespStatusCode: recorder.Status(),
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsed.Tokens
|
||||
tm.DurationMs = parsed.DurationMs
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// extract timings for infill - response is an array, timings are in the last element
|
||||
// see #463
|
||||
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsedMetrics.Tokens
|
||||
tm.DurationMs = parsedMetrics.DurationMs
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Build capture if enabled and determine if it will be stored
|
||||
var capture *ReqRespCapture
|
||||
if mp.enableCaptures {
|
||||
var respHeaders map[string]string
|
||||
var respBody []byte
|
||||
if (captureFields & captureRespHeaders) != 0 {
|
||||
respHeaders = make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
}
|
||||
if (captureFields & captureRespBody) != 0 {
|
||||
respBody = body
|
||||
}
|
||||
capture = &ReqRespCapture{
|
||||
ReqPath: request.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqBody: reqBody,
|
||||
RespHeaders: respHeaders,
|
||||
RespBody: respBody,
|
||||
}
|
||||
}
|
||||
|
||||
metricID := mp.queueMetrics(tm)
|
||||
tm.ID = metricID
|
||||
|
||||
// Store capture if enabled
|
||||
if capture != nil {
|
||||
capture.ID = metricID
|
||||
if mp.addCapture(*capture) {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
|
||||
mp.emitMetric(tm)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||
// v1/chat/completions puts it at top-level "usage"; v1/responses nests under
|
||||
// "response.usage"; v1/messages emits it at "message.usage" on message_start
|
||||
// and at "usage" on message_delta.
|
||||
var usagePaths = []string{"usage", "response.usage", "message.usage"}
|
||||
|
||||
// extractUsageTokens reads input/output/cached token counts from a usage
|
||||
// gjson.Result, handling the field-name differences across endpoints.
|
||||
// cached returns -1 when the field is absent. ok is true when at least one
|
||||
// field was present.
|
||||
func extractUsageTokens(usage gjson.Result) (input, output, cached int64, ok bool) {
|
||||
cached = -1
|
||||
if !usage.Exists() {
|
||||
return
|
||||
}
|
||||
|
||||
if v := usage.Get("prompt_tokens"); v.Exists() {
|
||||
// v1/chat/completions
|
||||
input = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens"); v.Exists() {
|
||||
// v1/messages, v1/responses
|
||||
input = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("completion_tokens"); v.Exists() {
|
||||
// v1/chat/completions
|
||||
output = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("output_tokens"); v.Exists() {
|
||||
// v1/messages, v1/responses
|
||||
output = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("cache_read_input_tokens"); v.Exists() {
|
||||
// v1/messages (Anthropic)
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens_details.cached_tokens"); v.Exists() {
|
||||
// v1/responses (OpenAI Responses API)
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() {
|
||||
// v1/chat/completions (OpenAI cache hits)
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
|
||||
// Walk SSE "data:" lines forward, merging usage info from every event.
|
||||
// Different endpoints split usage across events:
|
||||
// - v1/chat/completions: usage on the final chunk before [DONE]
|
||||
// - v1/responses: usage on response.completed (response.usage)
|
||||
// - v1/messages: input + cache on message_start (message.usage),
|
||||
// output_tokens on message_delta (usage)
|
||||
// We take the latest informative value per field so all three are covered.
|
||||
|
||||
var (
|
||||
inputTokens, outputTokens int64
|
||||
cachedTokens int64 = -1
|
||||
hasAny bool
|
||||
timings gjson.Result
|
||||
)
|
||||
|
||||
prefix := []byte("data:")
|
||||
for offset := 0; offset < len(body); {
|
||||
nl := bytes.IndexByte(body[offset:], '\n')
|
||||
var line []byte
|
||||
if nl == -1 {
|
||||
line = body[offset:]
|
||||
offset = len(body)
|
||||
} else {
|
||||
line = body[offset : offset+nl]
|
||||
offset += nl + 1
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if !gjson.ValidBytes(data) {
|
||||
continue
|
||||
}
|
||||
parsed := gjson.ParseBytes(data)
|
||||
|
||||
for _, path := range usagePaths {
|
||||
u := parsed.Get(path)
|
||||
if !u.Exists() {
|
||||
continue
|
||||
}
|
||||
i, o, c, ok := extractUsageTokens(u)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hasAny = true
|
||||
// Take the latest non-zero value so message_start's input_tokens
|
||||
// is preserved when message_delta's usage omits it, and vice versa
|
||||
// for output_tokens.
|
||||
if i > 0 {
|
||||
inputTokens = i
|
||||
}
|
||||
if o > 0 {
|
||||
outputTokens = o
|
||||
}
|
||||
if c >= 0 {
|
||||
cachedTokens = c
|
||||
}
|
||||
}
|
||||
if t := parsed.Get("timings"); t.Exists() {
|
||||
timings = t
|
||||
hasAny = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAny {
|
||||
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
return buildMetrics(modelID, start, inputTokens, outputTokens, cachedTokens, timings), nil
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
|
||||
input, output, cached, _ := extractUsageTokens(usage)
|
||||
return buildMetrics(modelID, start, input, output, cached, timings), nil
|
||||
}
|
||||
|
||||
// buildMetrics composes an ActivityLogEntry from accumulated token counts and
|
||||
// optional llama-server timings (which override input/output and provide rates).
|
||||
func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, cachedTokens int64, timings gjson.Result) ActivityLogEntry {
|
||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||
durationMs := wallDurationMs
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
|
||||
if timings.Exists() {
|
||||
inputTokens = timings.Get("prompt_n").Int()
|
||||
outputTokens = timings.Get("predicted_n").Int()
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
if timingsDurationMs > durationMs {
|
||||
durationMs = timingsDurationMs
|
||||
}
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = cachedValue.Int()
|
||||
}
|
||||
}
|
||||
|
||||
return ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: int(cachedTokens),
|
||||
InputTokens: int(inputTokens),
|
||||
OutputTokens: int(outputTokens),
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
},
|
||||
DurationMs: durationMs,
|
||||
}
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
tee io.Writer
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
||||
bodyBuffer := &bytes.Buffer{}
|
||||
return &responseBodyCopier{
|
||||
ResponseWriter: w,
|
||||
body: bodyBuffer,
|
||||
tee: io.MultiWriter(w, bodyBuffer),
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that should be redacted in captures
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,144 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type peerProxyMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type PeerProxy struct {
|
||||
peers config.PeerDictionaryConfig
|
||||
proxyMap map[string]*peerProxyMember
|
||||
}
|
||||
|
||||
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *logmon.Monitor) (*PeerProxy, error) {
|
||||
proxyMap := make(map[string]*peerProxyMember)
|
||||
|
||||
// Sort peer IDs for consistent iteration order
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
|
||||
// Create a transport with per-peer timeout configuration
|
||||
peerTransport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
// Create reverse proxy for this peer
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Ensure Host header matches target URL for remote proxying
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerProxyMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
// Map each model to this peer's proxy
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := proxyMap[modelID]; found {
|
||||
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
proxyMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
return &PeerProxy{
|
||||
peers: peers,
|
||||
proxyMap: proxyMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||
_, found := p.proxyMap[modelID]
|
||||
return found
|
||||
}
|
||||
|
||||
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||
pp, found := p.proxyMap[modelID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
// Get the peer config using the peerID
|
||||
peer, found := p.peers[pp.peerID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
return peer.Filters
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||
pp, found := p.proxyMap[model_id]
|
||||
if !found {
|
||||
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||
}
|
||||
|
||||
// Inject API key if configured for this peer
|
||||
if pp.apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
request.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
pp.reverseProxy.ServeHTTP(writer, request)
|
||||
return nil
|
||||
}
|
||||
@@ -1,311 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pm)
|
||||
assert.Empty(t, pm.proxyMap)
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 2)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.False(t, pm.HasPeerModel("model-c"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 4)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.True(t, pm.HasPeerModel("model-c"))
|
||||
assert.True(t, pm.HasPeerModel("model-d"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||
// should be mapped, and a warning should be logged
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
// Should only have one entry for the duplicate model
|
||||
assert.Len(t, pm.proxyMap, 1)
|
||||
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||
}
|
||||
|
||||
func TestHasPeerModel(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"existing-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||
}
|
||||
|
||||
func TestProxyRequest_Success(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "response from peer", w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "", // No API key
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||
// Create a test server that checks the Host header
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The Host header should be set to the target URL's host
|
||||
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||
// Create a test server that returns SSE content type
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_CustomTimeouts(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"test-peer": config.PeerConfig{
|
||||
Proxy: "http://localhost:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"model1"},
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 300,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(peers, testLogger)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, peerProxy)
|
||||
assert.True(t, peerProxy.HasPeerModel("model1"))
|
||||
|
||||
// Verify the timeout values are actually applied to the transport
|
||||
member, found := peerProxy.proxyMap["model1"]
|
||||
require.True(t, found, "model1 should exist in proxyMap")
|
||||
assert.NotNil(t, member.reverseProxy)
|
||||
assert.NotNil(t, member.reverseProxy.Transport)
|
||||
|
||||
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||
require.True(t, ok, "Transport should be *http.Transport")
|
||||
|
||||
// Verify all timeout values are correctly applied
|
||||
assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
// ForceAttemptHTTP2 should be enabled
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
@@ -1,956 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type StopStrategy int
|
||||
|
||||
const (
|
||||
StopImmediately StopStrategy = iota
|
||||
StopWaitForInflightRequest
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
ID string
|
||||
config config.ModelConfig
|
||||
cmd *exec.Cmd
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
|
||||
// PR #155 called to cancel the upstream process
|
||||
cmdMutex sync.RWMutex
|
||||
cancelUpstream context.CancelFunc
|
||||
|
||||
// closed when command exits
|
||||
cmdWaitChan chan struct{}
|
||||
|
||||
processLogger *logmon.Monitor
|
||||
proxyLogger *logmon.Monitor
|
||||
|
||||
healthCheckTimeout int
|
||||
healthCheckLoopInterval time.Duration
|
||||
|
||||
lastRequestHandledMutex sync.RWMutex
|
||||
lastRequestHandled time.Time
|
||||
|
||||
stateMutex sync.RWMutex
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
inFlightRequestsCount atomic.Int32
|
||||
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
|
||||
// for managing concurrency limits
|
||||
concurrencyLimitSemaphore chan struct{}
|
||||
|
||||
// used for testing to override the default value
|
||||
gracefulStopTimeout time.Duration
|
||||
|
||||
// used for testing to bypass subprocess and reverse proxy
|
||||
testHandler http.Handler
|
||||
|
||||
// track the number of failed starts
|
||||
failedStartCount int
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *logmon.Monitor, proxyLogger *logmon.Monitor) *Process {
|
||||
concurrentLimit := 10
|
||||
if config.ConcurrencyLimit > 0 {
|
||||
concurrentLimit = config.ConcurrencyLimit
|
||||
}
|
||||
|
||||
// Setup the reverse proxy.
|
||||
proxyURL, err := url.Parse(config.Proxy)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
|
||||
}
|
||||
|
||||
var reverseProxy *httputil.ReverseProxy
|
||||
if proxyURL != nil {
|
||||
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
|
||||
// Create custom transport with configured timeouts
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(config.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(config.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(config.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(config.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
reverseProxy.Transport = transport
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
reverseProxy: reverseProxy,
|
||||
cancelUpstream: nil,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||
state: StateStopped,
|
||||
|
||||
// concurrency limit
|
||||
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||
|
||||
// To be removed when migration over exec.CommandContext is complete
|
||||
// stop timeout
|
||||
gracefulStopTimeout: 10 * time.Second,
|
||||
cmdWaitChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// LogMonitor returns the log monitor associated with the process.
|
||||
func (p *Process) LogMonitor() *logmon.Monitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) setLastRequestHandled(t time.Time) {
|
||||
p.lastRequestHandledMutex.Lock()
|
||||
defer p.lastRequestHandledMutex.Unlock()
|
||||
p.lastRequestHandled = t
|
||||
}
|
||||
|
||||
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) getLastRequestHandled() time.Time {
|
||||
p.lastRequestHandledMutex.RLock()
|
||||
defer p.lastRequestHandledMutex.RUnlock()
|
||||
return p.lastRequestHandled
|
||||
}
|
||||
|
||||
// custom error types for swapping state
|
||||
var (
|
||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||
)
|
||||
|
||||
// swapState performs a compare and swap of the state atomically. It returns the current state
|
||||
// and an error if the swap failed.
|
||||
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != expectedState {
|
||||
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
|
||||
return p.state, ErrExpectedStateMismatch
|
||||
}
|
||||
|
||||
if !isValidTransition(p.state, newState) {
|
||||
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
|
||||
return p.state, ErrInvalidStateTransition
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
|
||||
// Atomically increment waitStarting when entering StateStarting
|
||||
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
|
||||
if newState == StateStarting {
|
||||
p.waitStarting.Add(1)
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
// Helper function to encapsulate transition rules
|
||||
func isValidTransition(from, to ProcessState) bool {
|
||||
switch from {
|
||||
case StateStopped:
|
||||
return to == StateStarting
|
||||
case StateStarting:
|
||||
return to == StateReady || to == StateStopping || to == StateStopped
|
||||
case StateReady:
|
||||
return to == StateStopping
|
||||
case StateStopping:
|
||||
return to == StateStopped || to == StateShutdown
|
||||
case StateShutdown:
|
||||
return false // No transitions allowed from these states
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
// forceState forces the process state to the new state with mutex protection.
|
||||
// This should only be used in exceptional cases where the normal state transition
|
||||
// validation via swapState() cannot be used.
|
||||
func (p *Process) forceState(newState ProcessState) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
p.state = newState
|
||||
}
|
||||
|
||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||
// it is a private method because starting is automatic but stopping can be called
|
||||
// at any time.
|
||||
func (p *Process) start() error {
|
||||
|
||||
// test-only fast path: skip subprocess, health check, and TTL goroutine
|
||||
if p.testHandler != nil {
|
||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||
if err == ErrExpectedStateMismatch {
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
curState = p.CurrentState()
|
||||
if curState == StateReady {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("process was already starting but wound up in state %v", curState)
|
||||
}
|
||||
return fmt.Errorf("process was in state %v when start() was called", curState)
|
||||
}
|
||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
defer p.waitStarting.Done()
|
||||
|
||||
// Mimic the real stop path: cancelUpstream transitions
|
||||
// StateStopping -> StateStopped and closes cmdWaitChan,
|
||||
// matching what waitForCmd does for real subprocesses.
|
||||
ch := make(chan struct{})
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = func() {
|
||||
if curState := p.CurrentState(); curState == StateStopping {
|
||||
if _, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
} else {
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
close(ch)
|
||||
}
|
||||
p.cmdWaitChan = ch
|
||||
p.cmdMutex.Unlock()
|
||||
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
p.failedStartCount = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
|
||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||
if err == ErrExpectedStateMismatch {
|
||||
// already starting, just wait for it to complete and expect
|
||||
// it to be be in the Ready start after. If not, return an error
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
if state := p.CurrentState(); state == StateReady {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("process was already starting but wound up in state %v", state)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("process was in state %v when start() was called", curState)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
}
|
||||
|
||||
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
|
||||
defer p.waitStarting.Done()
|
||||
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||
|
||||
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.processLogger
|
||||
p.cmd.Stderr = p.processLogger
|
||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
setProcAttributes(p.cmd)
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
p.cmdWaitChan = make(chan struct{})
|
||||
p.cmdMutex.Unlock()
|
||||
|
||||
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||
err = p.cmd.Start()
|
||||
|
||||
// Set process state to failed
|
||||
if err != nil {
|
||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||
p.forceState(StateStopped) // force it into a stopped state
|
||||
return fmt.Errorf(
|
||||
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
strings.Join(args, " "), err, curState, swapErr,
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
|
||||
}
|
||||
|
||||
// Capture the exit error for later signalling
|
||||
go p.waitForCmd()
|
||||
|
||||
// One of three things can happen at this stage:
|
||||
// 1. The command exits unexpectedly
|
||||
// 2. The health check fails
|
||||
// 3. The health check passes
|
||||
//
|
||||
// only in the third case will the process be considered Ready to accept
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
|
||||
checkStartTime := time.Now()
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||
if checkEndpoint != "none" {
|
||||
proxyTo := p.config.Proxy
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
// Ready Check loop
|
||||
for {
|
||||
currentState := p.CurrentState()
|
||||
if currentState != StateStarting {
|
||||
if currentState == StateStopped {
|
||||
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||
}
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
}
|
||||
|
||||
if time.Since(checkStartTime) > maxDuration {
|
||||
p.stopCommand()
|
||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||
}
|
||||
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||
break
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
ttl := time.Until(checkStartTime.Add(maxDuration))
|
||||
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
||||
}
|
||||
}
|
||||
<-time.After(p.healthCheckLoopInterval)
|
||||
}
|
||||
}
|
||||
|
||||
if p.config.UnloadAfter > 0 {
|
||||
// start a goroutine to check every second if
|
||||
// the process should be stopped
|
||||
go func() {
|
||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
|
||||
for range time.Tick(time.Second) {
|
||||
if p.CurrentState() != StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// skip the TTL check if there are inflight requests
|
||||
if p.inFlightRequestsCount.Load() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(p.getLastRequestHandled()) > maxDuration {
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
} else {
|
||||
p.failedStartCount = 0
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Stop will wait for inflight requests to complete before stopping the process.
|
||||
func (p *Process) Stop() {
|
||||
|
||||
// guard to prevent multiple goroutines from stopping
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
p.proxyLogger.Debugf("<%s> Stop() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||
return
|
||||
}
|
||||
|
||||
// wait for any inflight requests before proceeding
|
||||
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
|
||||
p.inFlightRequests.Wait()
|
||||
p.StopImmediately()
|
||||
}
|
||||
|
||||
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||
func (p *Process) StopImmediately() {
|
||||
|
||||
// guard to prevent multiple goroutines from stopping the process
|
||||
enterState := p.CurrentState()
|
||||
if !isValidTransition(enterState, StateStopping) {
|
||||
p.proxyLogger.Debugf("<%s> StopImmediate() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||
return
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Stopping process, enter state: %s", p.ID, enterState)
|
||||
if curState, err := p.swapState(enterState, StateStopping); err != nil {
|
||||
p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v, current state: %v", p.ID, enterState, err, curState)
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
}
|
||||
|
||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||
// of time for any inflight requests to complete before shutting down. If the Process
|
||||
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||
// the StateShutdown state, it can not be started again.
|
||||
func (p *Process) Shutdown() {
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
// just force it to this state since there is no recovery from shutdown
|
||||
p.forceState(StateShutdown)
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||
func (p *Process) stopCommand() {
|
||||
stopStartTime := time.Now()
|
||||
defer func() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
|
||||
// free the buffer in processLogger so the memory can be recovered
|
||||
p.processLogger.Clear()
|
||||
}()
|
||||
|
||||
p.cmdMutex.RLock()
|
||||
cancelUpstream := p.cancelUpstream
|
||||
cmdWaitChan := p.cmdWaitChan
|
||||
p.cmdMutex.RUnlock()
|
||||
|
||||
if cancelUpstream == nil {
|
||||
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
cancelUpstream()
|
||||
<-cmdWaitChan
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
|
||||
client := &http.Client{
|
||||
// wait a short time for a tcp connection to be established
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}).DialContext,
|
||||
},
|
||||
|
||||
// give a long time to respond to the health check endpoint
|
||||
// after the connection is established. See issue: 276
|
||||
Timeout: 5000 * time.Millisecond,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// got a response but it was not an OK
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if p.reverseProxy == nil {
|
||||
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
requestBeginTime := time.Now()
|
||||
var startDuration time.Duration
|
||||
|
||||
// prevent new requests from being made while stopping or irrecoverable
|
||||
currentState := p.CurrentState()
|
||||
if currentState == StateShutdown || currentState == StateStopping {
|
||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case p.concurrencyLimitSemaphore <- struct{}{}:
|
||||
defer func() { <-p.concurrencyLimitSemaphore }()
|
||||
default:
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
p.inFlightRequestsCount.Add(1)
|
||||
defer func() {
|
||||
p.setLastRequestHandled(time.Now())
|
||||
p.inFlightRequestsCount.Add(-1)
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
// for #366
|
||||
// - extract streaming param from request context, should have been set by proxymanager
|
||||
var srw *statusResponseWriter
|
||||
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
// start a goroutine to stream loading status messages into the response writer
|
||||
// add a sync so the streaming client only runs when the goroutine has exited
|
||||
|
||||
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
|
||||
|
||||
// PR #417 (no support for anthropic v1/messages yet)
|
||||
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
|
||||
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
|
||||
srw = newStatusResponseWriter(p, w)
|
||||
go srw.statusUpdates(swapCtx)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
|
||||
}
|
||||
|
||||
beginStartTime := time.Now()
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
cancelLoadCtx()
|
||||
if srw != nil {
|
||||
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
|
||||
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
|
||||
// before closing the connection. Without this, the connection would close before
|
||||
// the goroutine can write its cleanup messages, causing incomplete SSE output.
|
||||
srw.waitForCompletion(100 * time.Millisecond)
|
||||
} else {
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
}
|
||||
return
|
||||
}
|
||||
startDuration = time.Since(beginStartTime)
|
||||
}
|
||||
|
||||
// should trigger srw to stop sending loading events ...
|
||||
cancelLoadCtx()
|
||||
|
||||
// recover from http.ErrAbortHandler panics that can occur when the client
|
||||
// disconnects before the response is sent
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if r == http.ErrAbortHandler {
|
||||
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if srw != nil {
|
||||
// Wait for the goroutine to finish writing its final messages
|
||||
const completionTimeout = 1 * time.Second
|
||||
if !srw.waitForCompletion(completionTimeout) {
|
||||
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
if p.testHandler != nil {
|
||||
p.testHandler.ServeHTTP(w, r)
|
||||
} else if srw != nil {
|
||||
p.reverseProxy.ServeHTTP(srw, r)
|
||||
} else {
|
||||
p.reverseProxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
totalTime := time.Since(requestBeginTime)
|
||||
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||
p.ID, r.RequestURI, startDuration, totalTime)
|
||||
}
|
||||
|
||||
// waitForCmd waits for the command to exit and handles exit conditions depending on current state
|
||||
func (p *Process) waitForCmd() {
|
||||
exitErr := p.cmd.Wait()
|
||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||
|
||||
if exitErr != nil {
|
||||
if errno, ok := exitErr.(syscall.Errno); ok {
|
||||
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||
} else if exitError, ok := exitErr.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
} else {
|
||||
if exitErr.Error() != "context canceled" /* this is normal */ {
|
||||
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentState := p.CurrentState()
|
||||
switch currentState {
|
||||
case StateStopping:
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
default:
|
||||
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||
p.forceState(StateStopped) // force it to be in this state
|
||||
}
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
close(p.cmdWaitChan)
|
||||
p.cmdMutex.Unlock()
|
||||
}
|
||||
|
||||
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||
func (p *Process) cmdStopUpstreamProcess() error {
|
||||
p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID)
|
||||
|
||||
// this should never happen ...
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
||||
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
|
||||
}
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
// replace ${PID} with the pid of the process
|
||||
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
if err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Env = p.cmd.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger returns the logger for this process.
|
||||
func (p *Process) Logger() *logmon.Monitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting...",
|
||||
"Reticulating splines...",
|
||||
"Waking up the hamsters...",
|
||||
"Teaching the model manners...",
|
||||
"Convincing the GPU to participate...",
|
||||
"Loading weights (they're heavy)...",
|
||||
"Herding electrons...",
|
||||
"Compiling excuses for the delay...",
|
||||
"Downloading more RAM...",
|
||||
"Asking the model nicely to boot up...",
|
||||
"Bribing CUDA with cookies...",
|
||||
"Still loading (blame VRAM)...",
|
||||
"The model is fashionably late...",
|
||||
"Warming up those tensors...",
|
||||
"Making the neural net do push-ups...",
|
||||
"Your patience is appreciated (really)...",
|
||||
"Almost there (probably)...",
|
||||
"Loading like it's 1999...",
|
||||
"The model forgot where it put its keys...",
|
||||
"Quantum tunneling through layers...",
|
||||
"Negotiating with the PCIe bus...",
|
||||
"Defrosting frozen parameters...",
|
||||
"Teaching attention heads to focus...",
|
||||
"Running the matrix (slowly)...",
|
||||
"Untangling transformer blocks...",
|
||||
"Calibrating the flux capacitor...",
|
||||
"Spinning up the probability wheels...",
|
||||
"Waiting for the GPU to wake from its nap...",
|
||||
"Converting caffeine to compute...",
|
||||
"Allocating virtual patience...",
|
||||
"Performing arcane CUDA rituals...",
|
||||
"The model is stuck in traffic...",
|
||||
"Inflating embeddings...",
|
||||
"Summoning computational demons...",
|
||||
"Pleading with the OOM killer...",
|
||||
"Calculating the meaning of life (still at 42)...",
|
||||
"Training the training wheels...",
|
||||
"Optimizing the optimizer...",
|
||||
"Bootstrapping the bootstrapper...",
|
||||
"Loading loading screen...",
|
||||
"Processing processing logs...",
|
||||
"Buffering buffer overflow jokes...",
|
||||
"The model hit snooze...",
|
||||
"Debugging the debugger...",
|
||||
"Compiling the compiler...",
|
||||
"Parsing the parser (meta)...",
|
||||
"Tokenizing tokens...",
|
||||
"Encoding the encoder...",
|
||||
"Hashing hash browns...",
|
||||
"Forking spoons (not forks)...",
|
||||
"The model is contemplating existence...",
|
||||
"Transcending dimensional barriers...",
|
||||
"Invoking elder tensor gods...",
|
||||
"Unfurling probability clouds...",
|
||||
"Synchronizing parallel universes...",
|
||||
"The GPU is having second thoughts...",
|
||||
"Recalibrating reality matrices...",
|
||||
"Time is an illusion, loading doubly so...",
|
||||
"Convincing bits to flip themselves...",
|
||||
"The model is reading its own documentation...",
|
||||
}
|
||||
|
||||
type statusResponseWriter struct {
|
||||
hasWritten bool
|
||||
writer http.ResponseWriter
|
||||
process *Process
|
||||
wg sync.WaitGroup // Track goroutine completion
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
|
||||
s := &statusResponseWriter{
|
||||
writer: w,
|
||||
process: p,
|
||||
start: time.Now(),
|
||||
}
|
||||
|
||||
s.Header().Set("Content-Type", "text/event-stream") // SSE
|
||||
s.Header().Set("Cache-Control", "no-cache") // no-cache
|
||||
s.Header().Set("Connection", "keep-alive") // keep-alive
|
||||
s.WriteHeader(http.StatusOK) // send status code 200
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
|
||||
return s
|
||||
}
|
||||
|
||||
// statusUpdates sends status updates to the client while the model is loading
|
||||
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
|
||||
s.wg.Add(1)
|
||||
defer s.wg.Done()
|
||||
|
||||
// Recover from panics caused by client disconnection
|
||||
// Note: recover() only works within the same goroutine, so we need it here
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
duration := time.Since(s.start)
|
||||
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(" ")
|
||||
}()
|
||||
|
||||
// Create a shuffled copy of loadingRemarks
|
||||
remarks := make([]string, len(loadingRemarks))
|
||||
copy(remarks, loadingRemarks)
|
||||
rand.Shuffle(len(remarks), func(i, j int) {
|
||||
remarks[i], remarks[j] = remarks[j], remarks[i]
|
||||
})
|
||||
ri := 0
|
||||
|
||||
// Pick a random duration to send a remark
|
||||
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
|
||||
lastRemarkTime := time.Now()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if s.process.CurrentState() == StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it's time for a snarky remark
|
||||
if time.Since(lastRemarkTime) >= nextRemarkIn {
|
||||
remark := remarks[ri%len(remarks)]
|
||||
ri++
|
||||
s.sendLine(fmt.Sprintf("\n%s", remark))
|
||||
lastRemarkTime = time.Now()
|
||||
// Pick a new random duration for the next remark
|
||||
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||
} else {
|
||||
s.sendData(".")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCompletion waits for the statusUpdates goroutine to finish
|
||||
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendLine(line string) {
|
||||
s.sendData(line + "\n")
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendData(data string) {
|
||||
// Create the proper SSE JSON structure
|
||||
type Delta struct {
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
}
|
||||
type Choice struct {
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
type SSEMessage struct {
|
||||
Choices []Choice `json:"choices"`
|
||||
}
|
||||
|
||||
msg := SSEMessage{
|
||||
Choices: []Choice{
|
||||
{
|
||||
Delta: Delta{
|
||||
ReasoningContent: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write SSE formatted data, panic if not able to write
|
||||
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
|
||||
}
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Header() http.Header {
|
||||
return s.writer.Header()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Write(data []byte) (int, error) {
|
||||
return s.writer.Write(data)
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
||||
if s.hasWritten {
|
||||
return
|
||||
}
|
||||
s.hasWritten = true
|
||||
s.writer.WriteHeader(statusCode)
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
@@ -1,609 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
debugLogger = logmon.NewWriter(os.Stdout)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// flip to help with debugging tests
|
||||
if false {
|
||||
debugLogger.SetLogLevel(logmon.LevelDebug)
|
||||
} else {
|
||||
debugLogger.SetLogLevel(logmon.LevelError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// Create a process
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// process is automatically started
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
|
||||
// Stop the process
|
||||
process.Stop()
|
||||
|
||||
req = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
// Proxy the request
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// should have automatically started the process again
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||
// are all handled successfully, even though they all may ask for the process to .start()
|
||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
config := config.ModelConfig{
|
||||
Cmd: "nonexistent-command",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
|
||||
}
|
||||
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long auto unload TTL test")
|
||||
}
|
||||
|
||||
expectedMessage := "I_sense_imminent_danger"
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter
|
||||
process.ProxyRequest(w, req1)
|
||||
|
||||
t.Log("sending slow first request (4 seconds)")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "1234")
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// ensure the TTL timeout does not race slow requests (see issue #25)
|
||||
t.Log("sending second request (1 second)")
|
||||
time.Sleep(time.Second)
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req2)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// wait 5 seconds
|
||||
t.Log("sleep 5 seconds and check if unloaded")
|
||||
time.Sleep(5 * time.Second)
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_LowTTLValue(t *testing.T) {
|
||||
if true { // change this code to run this ...
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
conf := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Logf("Waiting before sending request %d", i)
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
expected := fmt.Sprintf("echo=test_%d", i)
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// issue #19
|
||||
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||
// requests. All HTTP requests in this test should complete successfully.
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
expectedMessage := "12345"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
process := NewProcess("t", 10, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
results := map[string]string{
|
||||
"12345": "",
|
||||
"abcde": "",
|
||||
"fghij": "",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
for key := range results {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
// send a request where simple-responder is will wait 300ms before responding
|
||||
// this will simulate an in-progress request.
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
results[key] = w.Body.String()
|
||||
mu.Unlock()
|
||||
|
||||
}(key)
|
||||
}
|
||||
|
||||
// Stop the process while requests are still being processed
|
||||
go func() {
|
||||
<-time.After(150 * time.Millisecond)
|
||||
process.Stop()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for key, result := range results {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_SwapState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentState ProcessState
|
||||
expectedState ProcessState
|
||||
newState ProcessState
|
||||
expectedError error
|
||||
expectedResult ProcessState
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
|
||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
|
||||
p.state = test.currentState
|
||||
|
||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||
if err != nil && test.expectedError == nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && test.expectedError != nil {
|
||||
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||
} else if err != nil && test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||
}
|
||||
}
|
||||
|
||||
if resultState != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long shutdown test")
|
||||
}
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
|
||||
// make a config where the healthcheck will always fail because port is wrong
|
||||
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
|
||||
|
||||
// make it a lot faster
|
||||
process.healthCheckLoopInterval = time.Second
|
||||
|
||||
// start a goroutine to simulate a shutdown
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-time.After(time.Millisecond * 500)
|
||||
process.Shutdown()
|
||||
}()
|
||||
wg.Add(1)
|
||||
|
||||
// start the process, this is a blocking call
|
||||
err := process.start()
|
||||
|
||||
wg.Wait()
|
||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping Exit Interrupts Health Check test")
|
||||
}
|
||||
|
||||
// should run and exit but interrupt the long checkHealthTimeout
|
||||
checkHealthTimeout := 5
|
||||
config := config.ModelConfig{
|
||||
Cmd: "sleep 1",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||
process.healthCheckLoopInterval = time.Second // make it faster
|
||||
err := process.start()
|
||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long concurrency limit test")
|
||||
}
|
||||
|
||||
expectedMessage := "concurrency_limit_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// only allow 1 concurrent request at a time
|
||||
config.ConcurrencyLimit = 1
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
|
||||
defer process.Stop()
|
||||
|
||||
// launch a goroutine first to take up the semaphore
|
||||
go func() {
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req1)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}()
|
||||
|
||||
// let the goroutine start
|
||||
<-time.After(time.Millisecond * 25)
|
||||
|
||||
denied := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, denied)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
}
|
||||
|
||||
func TestProcess_StopImmediately(t *testing.T) {
|
||||
expectedMessage := "test_stop_immediate"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
}()
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
|
||||
expectedMessage := "test_sigkill"
|
||||
binaryPath := getSimpleResponderPath()
|
||||
port := getTestPort()
|
||||
|
||||
conf := config.ModelConfig{
|
||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||
// to force the process to exit
|
||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// reduce to make testing go faster
|
||||
process.gracefulStopTimeout = time.Second
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
|
||||
waitChan := make(chan struct{})
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// StatusOK because that was already sent before the kill
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||
// then the unexpected EOF is sent after the kill
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||
} else {
|
||||
// Upstream may be killed mid-response.
|
||||
// Assert an incomplete or partial response.
|
||||
assert.NotEqual(t, "12345", w.Body.String())
|
||||
}
|
||||
|
||||
close(waitChan)
|
||||
}()
|
||||
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
|
||||
// the request should have been interrupted by SIGKILL
|
||||
<-waitChan
|
||||
}
|
||||
|
||||
func TestProcess_StopCmd(t *testing.T) {
|
||||
conf := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
conf.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
} else {
|
||||
conf.CmdStop = "kill -TERM ${PID}"
|
||||
}
|
||||
|
||||
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||
expectedMessage := "test_env_not_emptied"
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// ensure that the the default config does not blank out the inherited environment
|
||||
configWEnv := conf
|
||||
|
||||
// ensure the additiona variables are appended to the process' environment
|
||||
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||
|
||||
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
|
||||
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||
|
||||
process1.start()
|
||||
defer process1.Stop()
|
||||
process2.start()
|
||||
defer process2.Stop()
|
||||
|
||||
assert.NotZero(t, len(process1.cmd.Environ()))
|
||||
assert.NotZero(t, len(process2.cmd.Environ()))
|
||||
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||
|
||||
}
|
||||
|
||||
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
|
||||
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
|
||||
// handled appropriately.
|
||||
//
|
||||
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
|
||||
// can't copy the body. This can be caused by a client disconnecting before the full
|
||||
// response is sent from some reason.
|
||||
//
|
||||
// bug: https://github.com/mostlygeek/llama-swap/issues/362
|
||||
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
|
||||
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
|
||||
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
|
||||
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
expectedMessage := "panic_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// Start the process
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// Create a custom ResponseWriter that simulates a client disconnect
|
||||
// by panicking when Write is called after headers are sent
|
||||
panicWriter := &panicOnWriteResponseWriter{
|
||||
ResponseRecorder: httptest.NewRecorder(),
|
||||
shouldPanic: true,
|
||||
}
|
||||
|
||||
// Make a request that will trigger the panic
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
|
||||
|
||||
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
|
||||
// ProxyRequest should catch and handle this panic gracefully.
|
||||
process.ProxyRequest(panicWriter, req)
|
||||
|
||||
// If we get here, the panic was properly recovered in ProxyRequest
|
||||
// The process should still be in a ready state
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
|
||||
// to simulate a client disconnect after headers are sent
|
||||
// used by: TestProcess_ReverseProxyPanicIsHandled
|
||||
type panicOnWriteResponseWriter struct {
|
||||
*httptest.ResponseRecorder
|
||||
shouldPanic bool
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
|
||||
w.headerWritten = true
|
||||
w.ResponseRecorder.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
||||
if w.shouldPanic && w.headerWritten {
|
||||
// Simulate the panic that httputil.ReverseProxy throws
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
return w.ResponseRecorder.Write(b)
|
||||
}
|
||||
|
||||
func TestProcess_CustomTimeouts(t *testing.T) {
|
||||
modelConfig := config.ModelConfig{
|
||||
Cmd: "echo test",
|
||||
Proxy: "http://localhost:8080",
|
||||
CheckEndpoint: "/health",
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 120,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
}
|
||||
|
||||
debugLogger := logmon.NewWriter(io.Discard)
|
||||
process := NewProcess("test-model", 30, modelConfig, debugLogger, debugLogger)
|
||||
|
||||
// Verify the process was created successfully
|
||||
assert.NotNil(t, process)
|
||||
assert.Equal(t, "test-model", process.ID)
|
||||
assert.NotNil(t, process.reverseProxy)
|
||||
assert.NotNil(t, process.reverseProxy.Transport)
|
||||
|
||||
// Verify it's using http.Transport (not some other type)
|
||||
transport, ok := process.reverseProxy.Transport.(*http.Transport)
|
||||
assert.True(t, ok, "Transport should be *http.Transport")
|
||||
assert.NotNil(t, transport)
|
||||
|
||||
// Verify the timeouts are correctly applied
|
||||
assert.Equal(t, 120*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
type ProcessGroup struct {
|
||||
sync.Mutex
|
||||
|
||||
config config.Config
|
||||
id string
|
||||
swap bool
|
||||
exclusive bool
|
||||
persistent bool
|
||||
|
||||
proxyLogger *logmon.Monitor
|
||||
upstreamLogger *logmon.Monitor
|
||||
|
||||
// map of current processes
|
||||
processes map[string]*Process
|
||||
lastUsedProcess string
|
||||
|
||||
// inflight tracks fast-path requests (requests for the already-selected
|
||||
// model in a swap group). Fast-path requests Add(1) while holding pg.Lock
|
||||
// and Done() on completion; a concurrent swap request calls inflight.Wait()
|
||||
// under pg.Lock before stopping the current process. Without this tracking,
|
||||
// a fast-path request that has released pg.Lock but has not yet called
|
||||
// Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be
|
||||
// killed mid-request.
|
||||
inflight sync.WaitGroup
|
||||
|
||||
// testDelayFastPath is a test-only hook that, when non-nil, is invoked in
|
||||
// the fast path after pg.Lock is released but before the request is
|
||||
// dispatched to Process.ProxyRequest. Tests use it to park a fast-path
|
||||
// request at the exact race window to deterministically reproduce the
|
||||
// fast-path vs swap race.
|
||||
testDelayFastPath func()
|
||||
}
|
||||
|
||||
func NewProcessGroup(id string, config config.Config, proxyLogger *logmon.Monitor, upstreamLogger *logmon.Monitor) *ProcessGroup {
|
||||
groupConfig, ok := config.Groups[id]
|
||||
if !ok {
|
||||
panic("Unable to find configuration for group id: " + id)
|
||||
}
|
||||
|
||||
pg := &ProcessGroup{
|
||||
id: id,
|
||||
config: config,
|
||||
swap: groupConfig.Swap,
|
||||
exclusive: groupConfig.Exclusive,
|
||||
persistent: groupConfig.Persistent,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
processes: make(map[string]*Process),
|
||||
}
|
||||
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
processLogger := logmon.NewWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
return pg
|
||||
}
|
||||
|
||||
// ProxyRequest proxies a request to the specified model
|
||||
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
|
||||
if !pg.HasMember(modelID) {
|
||||
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
|
||||
}
|
||||
|
||||
if pg.swap {
|
||||
pg.Lock()
|
||||
if pg.lastUsedProcess != modelID {
|
||||
|
||||
// Wait for in-flight fast-path requests to drain before stopping
|
||||
// the previous process. Without this, a fast-path request that has
|
||||
// released pg.Lock but has not yet incremented
|
||||
// Process.inFlightRequests races with Stop() and can be killed
|
||||
// mid-request.
|
||||
pg.inflight.Wait()
|
||||
|
||||
// is there something already running?
|
||||
if pg.lastUsedProcess != "" {
|
||||
pg.processes[pg.lastUsedProcess].Stop()
|
||||
}
|
||||
|
||||
// wait for the request to the new model to be fully handled
|
||||
// and prevent race conditions see issue #277
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
pg.lastUsedProcess = modelID
|
||||
|
||||
// short circuit and exit
|
||||
pg.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path: register this request in inflight before releasing
|
||||
// pg.Lock so a concurrent swap will wait for it to complete.
|
||||
pg.inflight.Add(1)
|
||||
defer pg.inflight.Done()
|
||||
pg.Unlock()
|
||||
|
||||
if pg.testDelayFastPath != nil {
|
||||
pg.testDelayFastPath()
|
||||
}
|
||||
}
|
||||
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
|
||||
if pg.HasMember(modelName) {
|
||||
return pg.processes[modelName], true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
process, exists := pg.processes[modelID]
|
||||
if !exists {
|
||||
pg.Unlock()
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
if pg.lastUsedProcess == modelID {
|
||||
pg.lastUsedProcess = ""
|
||||
}
|
||||
pg.Unlock()
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
|
||||
if strategy != StopImmediately {
|
||||
pg.inflight.Wait()
|
||||
}
|
||||
|
||||
if len(pg.processes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -1,345 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
"model4": getTestSimpleResponderConfig("model4"),
|
||||
"model5": getTestSimpleResponderConfig("model5"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: false,
|
||||
Exclusive: true,
|
||||
Members: []string{"model3", "model4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model5"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_HasMember(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model1"))
|
||||
assert.True(t, pg.HasMember("model2"))
|
||||
assert.False(t, pg.HasMember("model3"))
|
||||
}
|
||||
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
// use the same listening so if a model is already running, it will fail
|
||||
// this is a way to test that swap isolation is working
|
||||
// properly when there are parallel requests made at the
|
||||
// same time.
|
||||
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
||||
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
||||
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
||||
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(len(tests))
|
||||
for _, modelName := range tests {
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
}(modelName)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap
|
||||
// request cannot stop the current process while a fast-path request (for the
|
||||
// already-selected model) is in flight. Without ProcessGroup-level inflight
|
||||
// tracking, a fast-path request that has released pg.Lock but has not yet
|
||||
// incremented Process.inFlightRequests races with Stop()'s Wait() and the
|
||||
// process is killed mid-request.
|
||||
func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopImmediately)
|
||||
|
||||
// Bypass real subprocesses so the test is fast and deterministic.
|
||||
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: run a request through model1 via the swap path so that
|
||||
// lastUsedProcess == "model1" and subsequent model1 requests take the
|
||||
// fast path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||
require.Equal(t, StateStopped, pg.processes["model2"].CurrentState())
|
||||
|
||||
// Fast-path hook: signal arrival at the race window, then wait for
|
||||
// release. This parks R2 deterministically at the point where pg.Lock
|
||||
// has been released but Process.inFlightRequests has not yet been
|
||||
// incremented — the exact window the race exploits.
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
pg.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
// R2: fast-path request for model1. Will pause at the test hook.
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
// Deterministically wait for R2 to reach the race window.
|
||||
<-r2Reached
|
||||
|
||||
// R3: swap request for model2. Must wait for R2 to finish before touching
|
||||
// model1, otherwise model1 gets killed mid-request.
|
||||
r3Done := make(chan struct{})
|
||||
w3 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, pg.ProxyRequest("model2", w3, req))
|
||||
}()
|
||||
|
||||
// Spin until R3 has acquired pg.Lock and entered the swap critical
|
||||
// section. In the fixed code, R3 then blocks on pg.inflight.Wait() while
|
||||
// still holding the lock, so TryLock keeps failing.
|
||||
for pg.TryLock() {
|
||||
pg.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||
// state. In the fixed code, R3 is blocked on pg.inflight.Wait() and
|
||||
// nothing changes, so we wait the full window. In the buggy code, R3
|
||||
// will Stop() model1 and start serving via model2 within microseconds —
|
||||
// we exit early once the mutation is observable.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if pg.processes["model1"].CurrentState() != StateReady ||
|
||||
pg.processes["model2"].CurrentState() != StateStopped {
|
||||
break
|
||||
}
|
||||
done := false
|
||||
select {
|
||||
case <-r3Done:
|
||||
done = true
|
||||
default:
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("swap completed while fast-path request was still in flight — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while a fast-path request is in flight")
|
||||
assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(),
|
||||
"model2 must not be started until R2 finishes and model1 is swapped out")
|
||||
|
||||
// Release R2 and let both requests finish.
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
assert.Equal(t, http.StatusOK, w3.Code)
|
||||
assert.Contains(t, w3.Body.String(), "model2")
|
||||
}
|
||||
|
||||
// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses
|
||||
// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a
|
||||
// process while a fast-path ProxyRequest is in the [pg.Unlock,
|
||||
// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in
|
||||
// StopProcesses, the external caller bypasses the inflight guard and kills the
|
||||
// process mid-request.
|
||||
func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) {
|
||||
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopImmediately)
|
||||
|
||||
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: model1 is active so subsequent model1 requests take the fast path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||
|
||||
// Park a fast-path request at the race window.
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
pg.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
<-r2Reached
|
||||
|
||||
// Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping
|
||||
// the group while a fast-path request is in flight.
|
||||
r3Done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
pg.StopProcesses(StopWaitForInflightRequest)
|
||||
}()
|
||||
|
||||
// Spin until StopProcesses has acquired pg.Lock.
|
||||
for pg.TryLock() {
|
||||
pg.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait()
|
||||
// and model1 stays Ready. In the buggy code it proceeds immediately and
|
||||
// kills model1.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if pg.processes["model1"].CurrentState() != StateReady {
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-r3Done:
|
||||
goto done
|
||||
default:
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
done:
|
||||
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while a fast-path request is in flight")
|
||||
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model3", "model4"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
})
|
||||
}
|
||||
|
||||
// make sure all the processes are running
|
||||
for _, process := range pg.processes {
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,358 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
// Protected with API key authentication
|
||||
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
apiGroup.GET("/performance", pm.apiGetPerformance)
|
||||
apiGroup.GET("/version", pm.apiGetVersion)
|
||||
apiGroup.GET("/captures/:id", pm.apiGetCapture)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadAllModels(c *gin.Context) {
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.JSON(http.StatusOK, gin.H{"msg": "ok"})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) getModelStatus() []Model {
|
||||
// Extract keys and sort them
|
||||
models := []Model{}
|
||||
|
||||
modelIDs := make([]string, 0, len(pm.config.Models))
|
||||
for modelID := range pm.config.Models {
|
||||
modelIDs = append(modelIDs, modelID)
|
||||
}
|
||||
sort.Strings(modelIDs)
|
||||
|
||||
// Iterate over sorted keys
|
||||
for _, modelID := range modelIDs {
|
||||
// Get process state
|
||||
state := "unknown"
|
||||
var process *Process
|
||||
if pm.matrix != nil {
|
||||
process, _ = pm.matrix.GetProcess(modelID)
|
||||
} else {
|
||||
processGroup := pm.findGroupByModelName(modelID)
|
||||
if processGroup != nil {
|
||||
process = processGroup.processes[modelID]
|
||||
}
|
||||
}
|
||||
if process != nil {
|
||||
switch process.CurrentState() {
|
||||
case StateReady:
|
||||
state = "ready"
|
||||
case StateStarting:
|
||||
state = "starting"
|
||||
case StateStopping:
|
||||
state = "stopping"
|
||||
case StateShutdown:
|
||||
state = "shutdown"
|
||||
case StateStopped:
|
||||
state = "stopped"
|
||||
}
|
||||
}
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
Name: pm.config.Models[modelID].Name,
|
||||
Description: pm.config.Models[modelID].Description,
|
||||
State: state,
|
||||
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||
Aliases: pm.config.Models[modelID].Aliases,
|
||||
})
|
||||
}
|
||||
|
||||
// Iterate over the peer models
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
for _, modelID := range peer.Models {
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
PeerID: peerID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
type messageType string
|
||||
|
||||
const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
msgTypeInFlight messageType = "inflight"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
Type messageType `json:"type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// sends a stream of different message types that happen on the server
|
||||
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering SSE
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
sendBuffer := make(chan messageEnvelope, 25)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
sendModels := func() {
|
||||
data, err := json.Marshal(pm.getModelStatus())
|
||||
if err == nil {
|
||||
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
|
||||
select {
|
||||
case sendBuffer <- msg:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
sendLogData := func(source string, data []byte) {
|
||||
data, err := json.Marshal(gin.H{
|
||||
"source": source,
|
||||
"data": string(data),
|
||||
})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sendMetrics := func(metrics []ActivityLogEntry) {
|
||||
jsonData, err := json.Marshal(metrics)
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sendInFlight := func(total int) {
|
||||
jsonData, err := json.Marshal(gin.H{"total": total})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send updated models list
|
||||
*/
|
||||
defer event.On(func(e ProcessStateChangeEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
defer event.On(func(e ConfigFileChangedEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Log data
|
||||
*/
|
||||
defer pm.proxyLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("proxy", data)
|
||||
})()
|
||||
defer pm.upstreamLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("upstream", data)
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Metrics data
|
||||
*/
|
||||
defer event.On(func(e ActivityLogEvent) {
|
||||
sendMetrics([]ActivityLogEntry{e.Metrics})
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send in-flight request stats related to token stats "Waiting: N" count.
|
||||
*/
|
||||
defer event.On(func(e InFlightRequestsEvent) {
|
||||
sendInFlight(e.Total)
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||
sendInFlight(pm.inFlightCounter.Current())
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case msg := <-sendBuffer:
|
||||
c.SSEvent("message", msg)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||
jsonData, err := pm.metricsMonitor.getMetricsJSON()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonData)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) prometheusMetricsHandler(c *gin.Context) {
|
||||
if pm.perfMonitor == nil {
|
||||
c.String(http.StatusServiceUnavailable, "# performance monitor not available\n")
|
||||
return
|
||||
}
|
||||
pm.perfMonitor.MetricsHandler().ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetPerformance(c *gin.Context) {
|
||||
if pm.perfMonitor == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "performance monitor not available"})
|
||||
return
|
||||
}
|
||||
|
||||
sysStats, gpuStats := pm.perfMonitor.Current()
|
||||
|
||||
var after time.Time
|
||||
if afterStr := c.Query("after"); afterStr != "" {
|
||||
ts, err := time.Parse(time.RFC3339, afterStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid 'after' timestamp, use RFC3339 format"})
|
||||
return
|
||||
}
|
||||
after = ts
|
||||
}
|
||||
|
||||
if !after.IsZero() {
|
||||
filtered := make([]perf.SysStat, 0, len(sysStats))
|
||||
for _, s := range sysStats {
|
||||
if s.Timestamp.After(after) {
|
||||
filtered = append(filtered, s)
|
||||
}
|
||||
}
|
||||
sysStats = filtered
|
||||
|
||||
filteredGpu := make([]perf.GpuStat, 0, len(gpuStats))
|
||||
for _, g := range gpuStats {
|
||||
if g.Timestamp.After(after) {
|
||||
filteredGpu = append(filteredGpu, g)
|
||||
}
|
||||
}
|
||||
gpuStats = filteredGpu
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"sys_stats": sysStats,
|
||||
"gpu_stats": gpuStats,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
|
||||
return
|
||||
}
|
||||
|
||||
var stopErr error
|
||||
if pm.matrix != nil {
|
||||
stopErr = pm.matrix.StopProcess(realModelName, StopImmediately)
|
||||
} else {
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||
return
|
||||
}
|
||||
stopErr = processGroup.StopProcess(realModelName, StopImmediately)
|
||||
}
|
||||
|
||||
if stopErr != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", stopErr.Error()))
|
||||
return
|
||||
}
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, map[string]string{
|
||||
"version": pm.version,
|
||||
"commit": pm.commit,
|
||||
"build_date": pm.buildDate,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
|
||||
return
|
||||
}
|
||||
|
||||
capture := pm.metricsMonitor.getCaptureByID(id)
|
||||
if capture == nil || (capture.ReqPath == "" && capture.ReqHeaders == nil && capture.ReqBody == nil && capture.RespHeaders == nil && capture.RespBody == nil) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||
return
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(capture)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal capture"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonBytes)
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "text/html") {
|
||||
c.Redirect(http.StatusFound, "/ui/")
|
||||
} else {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
history := pm.muxLogger.GetHistory()
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering streamed logs
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
||||
|
||||
// Handle case where query string might be included in the parameter
|
||||
// (can happen with catch-all routes on some versions/setups)
|
||||
if idx := strings.Index(logMonitorId, "?"); idx != -1 {
|
||||
logMonitorId = logMonitorId[:idx]
|
||||
}
|
||||
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||
return
|
||||
}
|
||||
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
// Send history first if not skipped
|
||||
|
||||
if !skipHistory {
|
||||
history := logger.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.Writer.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
sendChan := make(chan []byte, 10)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer logger.OnLogData(func(data []byte) {
|
||||
select {
|
||||
case sendChan <- data:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
})()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case data := <-sendChan:
|
||||
c.Writer.Write(data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*logmon.Monitor, error) {
|
||||
switch logMonitorId {
|
||||
case "":
|
||||
// maintain the default
|
||||
return pm.muxLogger, nil
|
||||
case "proxy":
|
||||
return pm.proxyLogger, nil
|
||||
case "upstream":
|
||||
return pm.upstreamLogger, nil
|
||||
default:
|
||||
// search for a models specific logger using findModelInPath
|
||||
// to handle model names with slashes (e.g., "author/model")
|
||||
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
|
||||
for _, group := range pm.processGroups {
|
||||
if process, found := group.GetMember(name); found {
|
||||
return process.Logger(), nil
|
||||
}
|
||||
}
|
||||
// also check the matrix when processGroups doesn't contain the model
|
||||
if pm.matrix != nil {
|
||||
if process, found := pm.matrix.GetProcess(name); found {
|
||||
return process.Logger(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||
}
|
||||
}
|
||||
@@ -1,173 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLogMonitorIdQueryParameterStripping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "upstream without query param",
|
||||
input: "upstream",
|
||||
expected: "upstream",
|
||||
},
|
||||
{
|
||||
name: "upstream with query param",
|
||||
input: "upstream?no-history",
|
||||
expected: "upstream",
|
||||
},
|
||||
{
|
||||
name: "proxy with multiple query params",
|
||||
input: "proxy?no-history&foo=bar",
|
||||
expected: "proxy",
|
||||
},
|
||||
{
|
||||
name: "model with slash and query param",
|
||||
input: "author/model?no-history",
|
||||
expected: "author/model",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the query parameter stripping logic
|
||||
logMonitorId := tt.input
|
||||
if idx := strings.Index(logMonitorId, "?"); idx != -1 {
|
||||
logMonitorId = logMonitorId[:idx]
|
||||
}
|
||||
|
||||
if logMonitorId != tt.expected {
|
||||
t.Errorf("Query parameter stripping failed: got %q, want %q", logMonitorId, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyManager_GetLogger_ProcessGroups verifies getLogger resolves the
|
||||
// well-known "proxy"/"upstream" loggers and a model ID managed by processGroups.
|
||||
func TestProxyManager_GetLogger_ProcessGroups(t *testing.T) {
|
||||
cfg := testConfigFromYAML(t, `
|
||||
healthCheckTimeout: 15
|
||||
logLevel: error
|
||||
models:
|
||||
model1:
|
||||
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
||||
`)
|
||||
pm := New(cfg)
|
||||
defer pm.StopProcesses(StopImmediately)
|
||||
|
||||
tests := []struct {
|
||||
id string
|
||||
wantErr bool
|
||||
}{
|
||||
{"proxy", false},
|
||||
{"upstream", false},
|
||||
{"model1", false},
|
||||
{"does-not-exist", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.id, func(t *testing.T) {
|
||||
logger, err := pm.getLogger(tt.id)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid logger")
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, logger)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyManager_GetLogger_Matrix verifies that getLogger can resolve a model
|
||||
// ID when the proxy is configured with a swap matrix (pm.processGroups is empty
|
||||
// for matrix-managed models).
|
||||
func TestProxyManager_GetLogger_Matrix(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
ExpandedSets: []config.ExpandedSet{
|
||||
{SetName: "s1", Models: []string{"model1", "model2"}},
|
||||
},
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
|
||||
pm := New(cfg)
|
||||
defer pm.StopProcesses(StopImmediately)
|
||||
|
||||
tests := []struct {
|
||||
id string
|
||||
wantErr bool
|
||||
}{
|
||||
{"proxy", false},
|
||||
{"upstream", false},
|
||||
{"model1", false},
|
||||
{"model2", false},
|
||||
{"does-not-exist", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.id, func(t *testing.T) {
|
||||
logger, err := pm.getLogger(tt.id)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid logger")
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, logger)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyManager_StreamLogs_Matrix verifies that /logs/stream/<modelID>
|
||||
// returns 200 (not 400) for a model managed by the swap matrix.
|
||||
func TestProxyManager_StreamLogs_Matrix(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"matrix-model": getTestSimpleResponderConfig("matrix-model"),
|
||||
},
|
||||
ExpandedSets: []config.ExpandedSet{
|
||||
{SetName: "s1", Models: []string{"matrix-model"}},
|
||||
},
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
|
||||
pm := New(cfg)
|
||||
defer pm.StopProcesses(StopImmediately)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/logs/stream/matrix-model", nil)
|
||||
req = req.WithContext(ctx)
|
||||
rec := CreateTestResponseRecorder()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
pm.ServeHTTP(rec, req)
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
<-done
|
||||
|
||||
assert.Equal(t, 200, rec.Code)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,43 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isTokenChar(r rune) bool {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||
parts := strings.Split(headerValues, ",")
|
||||
valid := make([]string, 0, len(parts))
|
||||
|
||||
for _, p := range parts {
|
||||
v := strings.TrimSpace(p)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
validPart := true
|
||||
for _, c := range v {
|
||||
if !isTokenChar(c) {
|
||||
validPart = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if validPart {
|
||||
valid = append(valid, v)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(valid, ", ")
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single valid value",
|
||||
input: "content-type",
|
||||
expected: "content-type",
|
||||
},
|
||||
{
|
||||
name: "multiple valid values",
|
||||
input: "content-type, authorization, x-requested-with",
|
||||
expected: "content-type, authorization, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "values with extra spaces",
|
||||
input: " content-type , authorization ",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with tabs",
|
||||
input: "content-type,\tauthorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with invalid characters",
|
||||
input: "content-type, auth\n, x-requested-with\r",
|
||||
expected: "content-type, auth, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "empty values in list",
|
||||
input: "content-type,,authorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "leading and trailing commas",
|
||||
input: ",content-type,authorization,",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid values",
|
||||
input: "content-type, \x00invalid, x-requested-with",
|
||||
expected: "content-type, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "mixed case values",
|
||||
input: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
expected: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SanitizeAccessControlRequestHeaderValues(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
|
||||
tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// selectEncoding chooses the best encoding based on Accept-Encoding header
|
||||
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
|
||||
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||
if acceptEncoding == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "br" {
|
||||
return "br", ".br"
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "gzip" {
|
||||
return "gzip", ".gz"
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// ServeCompressedFile serves a file with compression support.
|
||||
// It checks for pre-compressed versions and serves them with proper headers.
|
||||
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
|
||||
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
|
||||
|
||||
// Try to serve compressed version if client supports it
|
||||
if encoding != "" {
|
||||
if cf, err := fs.Open(name + ext); err == nil {
|
||||
defer cf.Close()
|
||||
|
||||
// Verify it's a regular file (not a directory)
|
||||
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||
// Set the content encoding header
|
||||
w.Header().Set("Content-Encoding", encoding)
|
||||
w.Header().Add("Vary", "Accept-Encoding")
|
||||
|
||||
// Get original file info for content type detection
|
||||
origFile, err := fs.Open(name)
|
||||
if err == nil {
|
||||
origFile.Close()
|
||||
}
|
||||
|
||||
// Serve the compressed file
|
||||
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to serving the uncompressed file
|
||||
file, err := fs.Open(name)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
http.Error(w, "is a directory", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||
}
|
||||
@@ -1,283 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServeCompressedFile_Brotli(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with brotli")
|
||||
brContent := []byte("fake-brotli-compressed-data")
|
||||
|
||||
// Create a test filesystem
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: brContent, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check that brotli is used (preferred over gzip)
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
|
||||
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, brContent) {
|
||||
t.Errorf("Expected brotli content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_Gzip(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with gzip")
|
||||
gzContent := []byte("fake-gzip-compressed-data")
|
||||
|
||||
// Create a test filesystem without brotli
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, gzContent) {
|
||||
t.Errorf("Expected gzip content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is uncompressed test content")
|
||||
|
||||
// Create a test filesystem without compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should not have Content-Encoding header since we're serving uncompressed
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content")
|
||||
|
||||
// Create a test filesystem with compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
// No Accept-Encoding header
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should serve uncompressed content
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NotFound(t *testing.T) {
|
||||
mapFS := fstest.MapFS{}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "nonexistent.js")
|
||||
|
||||
resp := w.Result()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectEncoding(t *testing.T) {
|
||||
tests := []struct {
|
||||
acceptEncoding string
|
||||
wantEncoding string
|
||||
wantExt string
|
||||
}{
|
||||
{"br, gzip", "br", ".br"},
|
||||
{"gzip, deflate", "gzip", ".gz"},
|
||||
{"gzip", "gzip", ".gz"},
|
||||
{"br", "br", ".br"},
|
||||
{"", "", ""},
|
||||
{"deflate", "", ""},
|
||||
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
|
||||
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
|
||||
{"browser", "", ""},
|
||||
{"compress, deflate", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
|
||||
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
|
||||
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
|
||||
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test with actual pre-compressed files from ui_dist
|
||||
func TestServeCompressedFile_RealFiles(t *testing.T) {
|
||||
// Check if ui_dist exists
|
||||
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
|
||||
t.Skip("ui_dist not found, skipping real file test")
|
||||
}
|
||||
|
||||
// Find a .js or .css file that has compressed versions
|
||||
entries, err := os.ReadDir("./ui_dist/assets")
|
||||
if err != nil {
|
||||
t.Skipf("Could not read ui_dist/assets: %v", err)
|
||||
}
|
||||
|
||||
var testFile string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
|
||||
// Check if compressed versions exist
|
||||
base := strings.TrimSuffix(name, ".js")
|
||||
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
|
||||
testFile = "assets/" + name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if testFile == "" {
|
||||
t.Skip("No suitable test file found with compressed versions")
|
||||
}
|
||||
|
||||
fs := http.FS(os.DirFS("./ui_dist"))
|
||||
|
||||
// Test brotli
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
})
|
||||
|
||||
// Test gzip
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
// Verify it's valid gzip
|
||||
reader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid gzip content: %v", err)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Just read to verify it's valid
|
||||
_, err = io.Copy(io.Discard, reader)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decompress gzip: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
//go:embed ui_dist
|
||||
var reactStaticFS embed.FS
|
||||
|
||||
// GetReactFS returns the embedded React filesystem
|
||||
func GetReactFS() (http.FileSystem, error) {
|
||||
subFS, err := fs.Sub(reactStaticFS, "ui_dist")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return http.FS(subFS), nil
|
||||
}
|
||||
|
||||
// GetReactIndexHTML returns the main index.html for the React app
|
||||
func GetReactIndexHTML() ([]byte, error) {
|
||||
return reactStaticFS.ReadFile("ui_dist/index.html")
|
||||
}
|
||||
Reference in New Issue
Block a user