diff --git a/cmd/legacy/llama-swap.go b/cmd/legacy/llama-swap.go deleted file mode 100644 index 2b2d26f..0000000 --- a/cmd/legacy/llama-swap.go +++ /dev/null @@ -1,249 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "net/http" - "os" - "os/signal" - "path/filepath" - "strings" - "sync" - "syscall" - "time" - - "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/event" - "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/internal/perf" - "github.com/mostlygeek/llama-swap/internal/watcher" - "github.com/mostlygeek/llama-swap/proxy" -) - -var ( - version string = "0" - commit string = "abcd1234" - date string = "unknown" -) - -func main() { - // Define a command-line flag for the port - configPath := flag.String("config", "config.yaml", "config file name") - listenStr := flag.String("listen", "", "listen ip/port") - certFile := flag.String("tls-cert-file", "", "TLS certificate file") - keyFile := flag.String("tls-key-file", "", "TLS key file") - showVersion := flag.Bool("version", false, "show version of build") - watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change") - mainLogger := logmon.New() - - flag.Parse() // Parse the command-line flags - - if *showVersion { - fmt.Printf("version: %s (%s), built at %s", version, commit, date) - os.Exit(0) - } - - conf, err := config.LoadConfig(*configPath) - if err != nil { - mainLogger.Errorf("Error loading config: %v", err) - os.Exit(1) - } - - if len(conf.Profiles) > 0 { - mainLogger.Warn("Profile functionality has been removed in favor of Groups. See the README for more information.") - } - - switch strings.ToLower(strings.TrimSpace(conf.LogLevel)) { - case "debug": - mainLogger.SetLogLevel(logmon.LevelDebug) - case "info": - mainLogger.SetLogLevel(logmon.LevelInfo) - case "warn": - mainLogger.SetLogLevel(logmon.LevelWarn) - case "error": - mainLogger.SetLogLevel(logmon.LevelError) - default: - mainLogger.SetLogLevel(logmon.LevelInfo) - } - - mainLogger.Debugf("PID: %d", os.Getpid()) - - if mode := os.Getenv("GIN_MODE"); mode != "" { - gin.SetMode(mode) - } else { - gin.SetMode(gin.ReleaseMode) - } - - // Validate TLS flags. - var useTLS = (*certFile != "" && *keyFile != "") - if (*certFile != "" && *keyFile == "") || - (*certFile == "" && *keyFile != "") { - fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.") - os.Exit(1) - } - - // Set default ports. - if *listenStr == "" { - defaultPort := ":8080" - if useTLS { - defaultPort = ":8443" - } - listenStr = &defaultPort - } - - var mon *perf.Monitor - if !conf.Performance.Disabled { - mon, err = perf.New(conf.Performance, mainLogger) - if err != nil { - mainLogger.Errorf("failed to create monitor: %s", err.Error()) - os.Exit(1) - } - mon.Start() - } else { - mainLogger.Info("performance monitoring is disabled") - } - - // Setup channels for server management - exitChan := make(chan struct{}) - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - - // Context that bounds the lifetime of background watcher goroutines. - watcherCtx, watcherCancel := context.WithCancel(context.Background()) - - // Create server with initial handlergit - srv := &http.Server{ - Addr: *listenStr, - } - - // Support for watching config and reloading when it changes - reloading := false - var reloadMutex sync.Mutex - reloadProxyManager := func() { - reloadMutex.Lock() - if reloading { - reloadMutex.Unlock() - return - } - reloading = true - reloadMutex.Unlock() - defer func() { - reloadMutex.Lock() - reloading = false - reloadMutex.Unlock() - }() - - if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok { - mainLogger.Info("Reloading Configuration") - conf, err = config.LoadConfig(*configPath) - if err != nil { - mainLogger.Warnf("Unable to reload configuration: %v", err) - return - } - - mainLogger.Debug("Configuration Changed") - currentPM.Shutdown() - if mon != nil { - mon.UpdateConfig(conf.Performance) - } - newPM := proxy.New(conf) - newPM.SetVersion(date, commit, version) - newPM.SetPerfMonitor(mon) - srv.Handler = newPM - mainLogger.Debug("Configuration Reloaded") - - // wait a few seconds and tell any UI to reload - time.AfterFunc(3*time.Second, func() { - event.Emit(proxy.ConfigFileChangedEvent{ - ReloadingState: proxy.ReloadingStateEnd, - }) - }) - } else { - conf, err = config.LoadConfig(*configPath) - if err != nil { - mainLogger.Errorf("Unable to load configuration: %v", err) - os.Exit(1) - } - newPM := proxy.New(conf) - newPM.SetVersion(date, commit, version) - newPM.SetPerfMonitor(mon) - srv.Handler = newPM - } - } - - // load the initial proxy manager - reloadProxyManager() - - if *watchConfig { - go func() { - absConfigPath, err := filepath.Abs(*configPath) - if err != nil { - mainLogger.Errorf("watch-config unable to determine absolute path for watching config file: %v", err) - return - } - mainLogger.Info("Watching configuration for changes (poll-based, 2s interval)") - (&configwatcher.Watcher{ - Path: absConfigPath, - Interval: configwatcher.DefaultInterval, - OnChange: func() { - reloadProxyManager() - }, - }).Run(watcherCtx) - }() - } - - // Signal handling - go func() { - for { - sig := <-sigChan - switch sig { - case syscall.SIGHUP: - mainLogger.Debug("Received SIGHUP") - reloadProxyManager() - case syscall.SIGINT, syscall.SIGTERM: - mainLogger.Debugf("Received signal %v, shutting down...", sig) - if mon != nil { - mon.Stop() - } - watcherCancel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - if pm, ok := srv.Handler.(*proxy.ProxyManager); ok { - pm.Shutdown() - } else { - mainLogger.Error("srv.Handler is not of type *proxy.ProxyManager") - } - - if err := srv.Shutdown(ctx); err != nil { - mainLogger.Errorf("Server shutdown: %v", err) - } - close(exitChan) - return - default: - // do nothing on other signals - } - } - }() - - // Start server - go func() { - var err error - if useTLS { - mainLogger.Infof("llama-swap listening with TLS on https://%s", *listenStr) - err = srv.ListenAndServeTLS(*certFile, *keyFile) - } else { - mainLogger.Infof("llama-swap listening on http://%s", *listenStr) - err = srv.ListenAndServe() - } - if err != nil && err != http.ErrServerClosed { - mainLogger.Errorf("Fatal server error: %v", err) - os.Exit(1) - } - }() - - // Wait for exit signal - <-exitChan -} diff --git a/proxy/.gitignore b/proxy/.gitignore deleted file mode 100644 index 94bc667..0000000 --- a/proxy/.gitignore +++ /dev/null @@ -1 +0,0 @@ -ui_dist/* \ No newline at end of file diff --git a/proxy/discardWriter.go b/proxy/discardWriter.go deleted file mode 100644 index 8af8c04..0000000 --- a/proxy/discardWriter.go +++ /dev/null @@ -1,27 +0,0 @@ -package proxy - -import "net/http" - -// Custom discard writer that implements http.ResponseWriter but just discards everything -type DiscardWriter struct { - header http.Header - status int -} - -func (w *DiscardWriter) Header() http.Header { - if w.header == nil { - w.header = make(http.Header) - } - return w.header -} - -func (w *DiscardWriter) Write(data []byte) (int, error) { - return len(data), nil -} - -func (w *DiscardWriter) WriteHeader(code int) { - w.status = code -} - -// Satisfy the http.Flusher interface for streaming responses -func (w *DiscardWriter) Flush() {} diff --git a/proxy/events.go b/proxy/events.go deleted file mode 100644 index 3b0b2a4..0000000 --- a/proxy/events.go +++ /dev/null @@ -1,60 +0,0 @@ -package proxy - -// package level registry of the different event types - -const ProcessStateChangeEventID = 0x01 -const ChatCompletionStatsEventID = 0x02 -const ConfigFileChangedEventID = 0x03 -const ActivityLogEventID = 0x05 -const ModelPreloadedEventID = 0x06 -const InFlightRequestsEventID = 0x07 - -type ProcessStateChangeEvent struct { - ProcessName string - NewState ProcessState - OldState ProcessState -} - -func (e ProcessStateChangeEvent) Type() uint32 { - return ProcessStateChangeEventID -} - -type ChatCompletionStats struct { - TokensGenerated int -} - -func (e ChatCompletionStats) Type() uint32 { - return ChatCompletionStatsEventID -} - -type ReloadingState int - -const ( - ReloadingStateStart ReloadingState = iota - ReloadingStateEnd -) - -type ConfigFileChangedEvent struct { - ReloadingState ReloadingState -} - -func (e ConfigFileChangedEvent) Type() uint32 { - return ConfigFileChangedEventID -} - -type ModelPreloadedEvent struct { - ModelName string - Success bool -} - -func (e ModelPreloadedEvent) Type() uint32 { - return ModelPreloadedEventID -} - -type InFlightRequestsEvent struct { - Total int -} - -func (e InFlightRequestsEvent) Type() uint32 { - return InFlightRequestsEventID -} diff --git a/proxy/helpers_test.go b/proxy/helpers_test.go deleted file mode 100644 index 185cb5b..0000000 --- a/proxy/helpers_test.go +++ /dev/null @@ -1,304 +0,0 @@ -package proxy - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "runtime" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" - "gopkg.in/yaml.v3" -) - -var ( - nextTestPort int = 12000 - portMutex sync.Mutex - testLogger = logmon.NewWriter(os.Stdout) - simpleResponderPath = getSimpleResponderPath() -) - -// Check if the binary exists -func TestMain(m *testing.M) { - binaryPath := getSimpleResponderPath() - if _, err := os.Stat(binaryPath); os.IsNotExist(err) { - fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath) - os.Exit(1) - } - - gin.SetMode(gin.TestMode) - - switch os.Getenv("LOG_LEVEL") { - case "debug": - testLogger.SetLogLevel(logmon.LevelDebug) - case "warn": - testLogger.SetLogLevel(logmon.LevelWarn) - case "info": - testLogger.SetLogLevel(logmon.LevelInfo) - default: - testLogger.SetLogLevel(logmon.LevelWarn) - } - - m.Run() -} - -// Helper function to get the binary path -func getSimpleResponderPath() string { - goos := runtime.GOOS - goarch := runtime.GOARCH - - if goos == "windows" { - return filepath.Join("..", "build", "simple-responder.exe") - } else { - return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch)) - } -} - -func getTestPort() int { - portMutex.Lock() - defer portMutex.Unlock() - - port := nextTestPort - nextTestPort++ - - return port -} - -// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and -// loads through the real config pipeline (env vars, macros, port assignment, etc.) -func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config { - t.Helper() - yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath)) - cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr)) - require.NoError(t, err) - return cfg -} - -func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig { - return getTestSimpleResponderConfigPort(expectedMessage, getTestPort()) -} - -func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig { - // Convert path to forward slashes for cross-platform compatibility - // Windows handles forward slashes in paths correctly - cmdPath := filepath.ToSlash(simpleResponderPath) - - // Create a YAML string with just the values we want to set - yamlStr := fmt.Sprintf(` -cmd: '%s --port %d --silent --respond %s' -proxy: "http://127.0.0.1:%d" -`, cmdPath, port, expectedMessage, port) - - var cfg config.ModelConfig - if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil { - panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr)) - } - - return cfg -} - -// injectTestHandlers sets a testHandler on every Process in every ProcessGroup -// of the given ProxyManager, bypassing subprocess launches. modelResponses maps -// model IDs to their respond strings; if a model ID is not in the map, the model -// ID itself is used. -func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) { - for _, pg := range pm.processGroups { - for modelID, process := range pg.processes { - respond := modelID - if r, ok := modelResponses[modelID]; ok { - respond = r - } - process.testHandler = newTestHandler(respond) - } - } -} - -// newTestHandler returns an http.Handler that mimics simple-responder's API. -// It supports the endpoints that routing tests depend on, without launching -// any subprocess or binding any port. -func respondJSON(w http.ResponseWriter, respond string, bodyBytes []byte) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "responseMessage": respond, - "h_content_length": strconv.Itoa(len(bodyBytes)), - "request_body": string(bodyBytes), - "usage": map[string]any{ - "completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35, - }, - "timings": map[string]any{ - "prompt_n": 25, "prompt_ms": 13, "predicted_n": 10, - "predicted_ms": 17, "predicted_per_second": 10, - }, - }) -} - -func newTestHandler(respond string) http.Handler { - mux := http.NewServeMux() - - mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - bodyBytes, _ := io.ReadAll(r.Body) - isStreaming := r.URL.Query().Get("stream") == "true" - - if wait := r.URL.Query().Get("wait"); wait != "" { - if d, err := time.ParseDuration(wait); err == nil { - time.Sleep(d) - } - } - - if isStreaming { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - flusher := w.(http.Flusher) - - for i := 0; i < 10; i++ { - data, _ := json.Marshal(map[string]any{ - "created": time.Now().Unix(), - "choices": []map[string]any{ - {"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil}, - }, - }) - fmt.Fprintf(w, "event: message\ndata: %s\n\n", data) - flusher.Flush() - } - - finalData, _ := json.Marshal(map[string]any{ - "usage": map[string]any{ - "completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35, - }, - "timings": map[string]any{ - "prompt_n": 25, "prompt_ms": 13, "predicted_n": 10, - "predicted_ms": 17, "predicted_per_second": 10, - }, - }) - fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData) - flusher.Flush() - - fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n") - flusher.Flush() - } else { - respondJSON(w, respond, bodyBytes) - } - }) - - mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - modelName := gjson.GetBytes(body, "model").String() - if modelName != respond { - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)}) - return - } - json.NewEncoder(w).Encode(map[string]string{"message": "ok"}) - }) - - mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) { - bodyBytes, _ := io.ReadAll(r.Body) - respondJSON(w, respond, bodyBytes) - }) - - for _, path := range []string{ - "/chat/completions", "/completions", - "/responses", "/messages", "/messages/count_tokens", - "/embeddings", "/rerank", "/reranking", - } { - mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { - bodyBytes, _ := io.ReadAll(r.Body) - respondJSON(w, respond, bodyBytes) - }) - } - - mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "responseMessage": respond, - "usage": map[string]any{ - "completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35, - }, - }) - }) - - mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) { - if err := r.ParseMultipartForm(10 << 20); err != nil { - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)}) - return - } - model := r.FormValue("model") - if model == "" { - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"}) - return - } - file, _, err := r.FormFile("file") - if err != nil { - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)}) - return - } - fileBytes, _ := io.ReadAll(file) - file.Close() - json.NewEncoder(w).Encode(map[string]any{ - "text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)), - "model": model, - "h_content_type": r.Header.Get("Content-Type"), - "h_content_length": r.Header.Get("Content-Length"), - }) - }) - - mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) { - model := r.URL.Query().Get("model") - json.NewEncoder(w).Encode(map[string]any{ - "voices": []string{"voice1"}, "model": model, - }) - }) - - mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - fmt.Fprint(w, respond) - }) - - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/" { - http.NotFound(w, r) - return - } - w.Header().Set("Content-Type", "text/plain") - fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path) - }) - - mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - modelName := gjson.GetBytes(body, "model").String() - json.NewEncoder(w).Encode(map[string]any{ - "model": modelName, "images": []string{}, - }) - }) - - mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - modelName := gjson.GetBytes(body, "model").String() - json.NewEncoder(w).Encode(map[string]any{ - "model": modelName, "images": []string{}, - }) - }) - - mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(map[string]any{ - "loras": []string{}, - }) - }) - - return mux -} diff --git a/proxy/matrix.go b/proxy/matrix.go deleted file mode 100644 index f699436..0000000 --- a/proxy/matrix.go +++ /dev/null @@ -1,330 +0,0 @@ -package proxy - -import ( - "fmt" - "net/http" - "slices" - "sort" - "sync" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/logmon" -) - -// MatrixSolver contains pure swap-decision logic with no Process dependencies. -// It is safe for concurrent reads after construction. -type MatrixSolver struct { - expandedSets []config.ExpandedSet // all valid model combinations - evictCosts map[string]int // real model name -> eviction cost (default 1) - modelToSets map[string][]int // model name -> indices into expandedSets -} - -// NewMatrixSolver builds a solver from expanded sets and eviction costs. -func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver { - modelToSets := make(map[string][]int) - for i, es := range expandedSets { - for _, model := range es.Models { - modelToSets[model] = append(modelToSets[model], i) - } - } - - return &MatrixSolver{ - expandedSets: expandedSets, - evictCosts: evictCosts, - modelToSets: modelToSets, - } -} - -// SolveResult describes what the solver decided. -type SolveResult struct { - Evict []string // running models that must be stopped - TargetSet []string // the chosen set of models (for informational purposes) - SetName string // name of the chosen set - DSL string // original DSL expression for the chosen set - TotalCost int // total eviction cost -} - -// Solve determines which models to evict when a model is requested. -// -// Algorithm: -// 1. If requestedModel is already running, no eviction needed. -// 2. Find all sets containing requestedModel. -// 3. If no sets found, the model runs alone; evict all running models. -// 4. For each candidate set, compute cost = sum of evict_costs for running -// models NOT in that set. -// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets). -// 6. Return models to evict and the chosen set. -func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) { - // If already running, nothing to do (but fill in set info for logging) - if slices.Contains(runningModels, requestedModel) { - setName, dsl := s.findMatchingSet(requestedModel, runningModels) - return SolveResult{ - TargetSet: runningModels, - SetName: setName, - DSL: dsl, - }, nil - } - - candidateIndices := s.modelToSets[requestedModel] - - // Model not in any set: runs alone, evict everything - if len(candidateIndices) == 0 { - evict := make([]string, len(runningModels)) - copy(evict, runningModels) - return SolveResult{ - Evict: evict, - TargetSet: []string{requestedModel}, - }, nil - } - - // Find the cheapest candidate set - bestCost := -1 - bestIdx := -1 - - for _, idx := range candidateIndices { - setModels := s.expandedSets[idx].Models - cost := 0 - for _, running := range runningModels { - if !slices.Contains(setModels, running) { - cost += s.evictCost(running) - } - } - - if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) { - bestCost = cost - bestIdx = idx - } - } - - // Determine which running models to evict - chosen := s.expandedSets[bestIdx] - var evict []string - for _, running := range runningModels { - if !slices.Contains(chosen.Models, running) { - evict = append(evict, running) - } - } - - return SolveResult{ - Evict: evict, - TargetSet: chosen.Models, - SetName: chosen.SetName, - DSL: chosen.DSL, - TotalCost: bestCost, - }, nil -} - -// findMatchingSet finds the expanded set that contains all running models. -// Returns the set name and DSL, or empty strings if no match. -func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) { - for _, idx := range s.modelToSets[requestedModel] { - set := s.expandedSets[idx] - allInSet := true - for _, m := range runningModels { - if !slices.Contains(set.Models, m) { - allInSet = false - break - } - } - if allInSet { - return set.SetName, set.DSL - } - } - return "", "" -} - -func (s *MatrixSolver) evictCost(model string) int { - if cost, ok := s.evictCosts[model]; ok { - return cost - } - return 1 -} - -// Matrix manages processes using solver-based swap logic. -type Matrix struct { - sync.Mutex - solver *MatrixSolver - processes map[string]*Process // all processes keyed by real model name - config config.Config - proxyLogger *logmon.Monitor - upstreamLogger *logmon.Monitor - - // inflight tracks ProxyRequest calls that have released m.Lock but may - // not yet have incremented Process.inFlightRequests. A concurrent - // request that needs to evict models waits for inflight to drain under - // m.Lock before stopping anything. Without this, a request that - // released m.Lock but has not yet reached Process.inFlightRequests.Add(1) - // races with Stop()'s Wait() and can be killed mid-request. - inflight sync.WaitGroup - - // testDelayFastPath is a test-only hook invoked in the no-eviction path - // after m.Lock is released but before the request is dispatched to - // Process.ProxyRequest. Tests use it to park a request at the exact - // race window to deterministically reproduce the race. - testDelayFastPath func() -} - -// NewMatrix creates a Matrix from config. It creates a Process for every -// model defined in the config (any model can run alone even if not in a set). -func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *logmon.Monitor) *Matrix { - processes := make(map[string]*Process) - for modelID, modelConfig := range cfg.Models { - processLogger := logmon.NewWriter(upstreamLogger) - process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger) - processes[modelID] = process - } - - evictCosts := cfg.Matrix.ResolvedEvictCosts() - - return &Matrix{ - solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts), - processes: processes, - config: cfg, - proxyLogger: proxyLogger, - upstreamLogger: upstreamLogger, - } -} - -// ProxyRequest handles the swap logic and proxies the request to the model. -func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error { - process, ok := m.processes[modelID] - if !ok { - return fmt.Errorf("model %s not found in matrix", modelID) - } - - m.Lock() - running := m.runningModels() - result, err := m.solver.Solve(modelID, running) - if err != nil { - m.Unlock() - return fmt.Errorf("matrix solver error: %w", err) - } - - // Log solver decision - if len(result.Evict) > 0 { - m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d", - modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost) - } else if len(running) == 0 { - m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID) - } else { - m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL) - } - - // Evict models that need to be stopped - if len(result.Evict) > 0 { - // Wait for any in-flight ProxyRequest calls to register on their - // Process before stopping anything. Without this, a request that - // released m.Lock but has not yet incremented - // Process.inFlightRequests races with Stop() and can be killed - // mid-request. - m.inflight.Wait() - - var wg sync.WaitGroup - for _, evictModel := range result.Evict { - if p, exists := m.processes[evictModel]; exists { - wg.Add(1) - go func(p *Process) { - defer wg.Done() - p.Stop() - }(p) - } - } - wg.Wait() - } - - // Register this request in inflight before releasing m.Lock so a - // concurrent eviction will wait for it to complete. - m.inflight.Add(1) - defer m.inflight.Done() - isFastPath := len(result.Evict) == 0 - m.Unlock() - - if isFastPath && m.testDelayFastPath != nil { - m.testDelayFastPath() - } - - // Proxy the request (Process handles on-demand start) - process.ProxyRequest(w, r) - return nil -} - -// StopProcesses stops all running processes. -func (m *Matrix) StopProcesses(strategy StopStrategy) { - m.Lock() - defer m.Unlock() - - var wg sync.WaitGroup - for _, process := range m.processes { - wg.Add(1) - go func(p *Process) { - defer wg.Done() - switch strategy { - case StopImmediately: - p.StopImmediately() - default: - p.Stop() - } - }(process) - } - wg.Wait() -} - -// StopProcess stops a single process by model ID. -func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error { - process, ok := m.processes[modelID] - if !ok { - return fmt.Errorf("process not found for %s", modelID) - } - - switch strategy { - case StopImmediately: - process.StopImmediately() - default: - process.Stop() - } - return nil -} - -// Shutdown shuts down all processes. -func (m *Matrix) Shutdown() { - var wg sync.WaitGroup - for _, process := range m.processes { - wg.Add(1) - go func(p *Process) { - defer wg.Done() - p.Shutdown() - }(process) - } - wg.Wait() -} - -// RunningModels returns model names currently in an active (non-stopped) state. -func (m *Matrix) RunningModels() []string { - m.Lock() - defer m.Unlock() - return m.runningModels() -} - -// runningModels returns running model names (caller must hold lock). -func (m *Matrix) runningModels() []string { - var running []string - for id, process := range m.processes { - if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown { - running = append(running, id) - } - } - sort.Strings(running) - return running -} - -// GetProcess returns the Process for a model. -func (m *Matrix) GetProcess(modelID string) (*Process, bool) { - p, ok := m.processes[modelID] - return p, ok -} - -// HasModel returns true if the model is managed by this matrix. -func (m *Matrix) HasModel(modelID string) bool { - _, ok := m.processes[modelID] - return ok -} diff --git a/proxy/matrix_test.go b/proxy/matrix_test.go deleted file mode 100644 index 01c3141..0000000 --- a/proxy/matrix_test.go +++ /dev/null @@ -1,349 +0,0 @@ -package proxy - -import ( - "net/http" - "net/http/httptest" - "runtime" - "testing" - "time" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// Helper to build expanded sets for solver tests -func makeExpandedSets(sets ...struct { - name string - models []string -}) []config.ExpandedSet { - var result []config.ExpandedSet - for _, s := range sets { - result = append(result, config.ExpandedSet{ - SetName: s.name, - Models: s.models, - }) - } - return result -} - -func es(name string, models ...string) struct { - name string - models []string -} { - return struct { - name string - models []string - }{name, models} -} - -func TestMatrixSolver_AlreadyRunning(t *testing.T) { - solver := NewMatrixSolver( - makeExpandedSets(es("s1", "a", "b")), - nil, - ) - - result, err := solver.Solve("a", []string{"a"}) - require.NoError(t, err) - assert.Empty(t, result.Evict) - assert.Equal(t, []string{"a"}, result.TargetSet) - assert.Equal(t, "s1", result.SetName) -} - -func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) { - solver := NewMatrixSolver( - makeExpandedSets(es("s1", "a", "b")), - nil, - ) - - // Model "c" not in any set - result, err := solver.Solve("c", []string{"a", "b"}) - require.NoError(t, err) - assert.ElementsMatch(t, []string{"a", "b"}, result.Evict) - assert.Equal(t, []string{"c"}, result.TargetSet) -} - -func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) { - solver := NewMatrixSolver( - makeExpandedSets(es("s1", "a", "b")), - nil, - ) - - result, err := solver.Solve("c", []string{}) - require.NoError(t, err) - assert.Empty(t, result.Evict) - assert.Equal(t, []string{"c"}, result.TargetSet) -} - -func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) { - // Set: [a, b]. Request a when b and c are running. - solver := NewMatrixSolver( - makeExpandedSets(es("s1", "a", "b")), - nil, - ) - - result, err := solver.Solve("a", []string{"b", "c"}) - require.NoError(t, err) - // c is not in the set, so it gets evicted. b is in the set, so it stays. - assert.Equal(t, []string{"c"}, result.Evict) - assert.Equal(t, []string{"a", "b"}, result.TargetSet) -} - -func TestMatrixSolver_PicksLowestCost(t *testing.T) { - // Two sets containing model "a": - // s1: [a, v] — if v is running, cost=0; if L is running, cost=30 - // s2: [a, L] — if L is running, cost=0; if v is running, cost=50 - solver := NewMatrixSolver( - makeExpandedSets( - es("s1", "a", "v"), - es("s2", "a", "L"), - ), - map[string]int{"v": 50, "L": 30}, - ) - - // v is running. Switching to a: - // s1 cost: v is in s1, so 0 - // s2 cost: v is NOT in s2, so 50 - // => pick s1 - result, err := solver.Solve("a", []string{"v"}) - require.NoError(t, err) - assert.Empty(t, result.Evict) - assert.Equal(t, []string{"a", "v"}, result.TargetSet) - - // L is running. Switching to a: - // s1 cost: L is NOT in s1, so 30 - // s2 cost: L is in s2, so 0 - // => pick s2 - result, err = solver.Solve("a", []string{"L"}) - require.NoError(t, err) - assert.Empty(t, result.Evict) - assert.Equal(t, []string{"a", "L"}, result.TargetSet) -} - -func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) { - // Two sets with identical cost. Definition order should win. - solver := NewMatrixSolver( - makeExpandedSets( - es("s1", "a", "x"), - es("s2", "a", "y"), - ), - nil, - ) - - // Nothing running, both sets cost 0. s1 is first. - result, err := solver.Solve("a", []string{}) - require.NoError(t, err) - assert.Empty(t, result.Evict) - assert.Equal(t, []string{"a", "x"}, result.TargetSet) -} - -func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) { - // Model "v" costs 50 to evict, "m" costs 1 (default). - // Sets: [g,v], [g,m] - // Running: v, m. Request g. - // s1=[g,v]: evict m (cost 1), keep v - // s2=[g,m]: evict v (cost 50), keep m - // => pick s1 - solver := NewMatrixSolver( - makeExpandedSets( - es("s1", "g", "v"), - es("s2", "g", "m"), - ), - map[string]int{"v": 50}, - ) - - result, err := solver.Solve("g", []string{"v", "m"}) - require.NoError(t, err) - assert.Equal(t, []string{"m"}, result.Evict) - assert.Equal(t, []string{"g", "v"}, result.TargetSet) -} - -func TestMatrixSolver_NothingRunning(t *testing.T) { - solver := NewMatrixSolver( - makeExpandedSets( - es("s1", "g", "v"), - es("s2", "q", "v"), - ), - nil, - ) - - result, err := solver.Solve("g", []string{}) - require.NoError(t, err) - assert.Empty(t, result.Evict) - assert.Equal(t, []string{"g", "v"}, result.TargetSet) -} - -// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction -// cannot stop a process while an in-flight ProxyRequest for that process is -// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without -// matrix-level inflight tracking, the eviction's Stop() races with the -// pending request and kills it mid-start. -func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) { - cfg := config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - }, - ExpandedSets: []config.ExpandedSet{ - {SetName: "s1", Models: []string{"model1"}}, - {SetName: "s2", Models: []string{"model2"}}, - }, - Matrix: &config.MatrixConfig{}, - } - - m := NewMatrix(cfg, testLogger, testLogger) - defer m.StopProcesses(StopImmediately) - - // Bypass real subprocesses so the test is fast and deterministic. - m.processes["model1"].testHandler = newTestHandler("model1") - m.processes["model2"].testHandler = newTestHandler("model2") - - // Prime: run a request through model1 so it reaches StateReady and - // subsequent requests take the no-eviction path. - primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil) - primeW := httptest.NewRecorder() - require.NoError(t, m.ProxyRequest("model1", primeW, primeReq)) - require.Equal(t, http.StatusOK, primeW.Code) - require.Equal(t, StateReady, m.processes["model1"].CurrentState()) - require.Equal(t, StateStopped, m.processes["model2"].CurrentState()) - - // Install fast-path hook that signals arrival and waits for release. - // This parks R2 at the race window — after m.Lock is released but - // before Process.inFlightRequests.Add(1). - r2Reached := make(chan struct{}) - r2Release := make(chan struct{}) - m.testDelayFastPath = func() { - close(r2Reached) - <-r2Release - } - - // R2: no-eviction request for model1. Will pause at the hook. - r2Done := make(chan struct{}) - w2 := httptest.NewRecorder() - go func() { - defer close(r2Done) - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - assert.NoError(t, m.ProxyRequest("model1", w2, req)) - }() - - // Deterministically wait for R2 to reach the race window. - <-r2Reached - - // R3: request for model2 which requires evicting model1. Must wait for - // R2 to finish before touching model1. - r3Done := make(chan struct{}) - w3 := httptest.NewRecorder() - go func() { - defer close(r3Done) - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - assert.NoError(t, m.ProxyRequest("model2", w3, req)) - }() - - // Spin until R3 has acquired m.Lock and entered the eviction path. In - // the fixed code, R3 then blocks on m.inflight.Wait() while still - // holding the lock, so TryLock keeps failing. - for m.TryLock() { - m.Unlock() - runtime.Gosched() - } - - // Bounded poll: give R3 a chance to demonstrate the bug by mutating - // state. In the fixed code R3 is blocked and nothing changes; in the - // buggy code R3 will Stop() model1 and start model2 within microseconds. - deadline := time.Now().Add(100 * time.Millisecond) - for time.Now().Before(deadline) { - if m.processes["model1"].CurrentState() != StateReady || - m.processes["model2"].CurrentState() != StateStopped { - break - } - done := false - select { - case <-r3Done: - done = true - default: - } - if done { - break - } - runtime.Gosched() - } - - // Invariant: R3 must be blocked while R2 is still in flight. - select { - case <-r3Done: - t.Fatal("eviction completed while in-flight request was still pending — race not prevented") - default: - } - assert.Equal(t, StateReady, m.processes["model1"].CurrentState(), - "model1 must stay Ready while an in-flight request is pending") - assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(), - "model2 must not be started until R2 finishes and model1 is evicted") - - // Release R2 and let both requests finish. - close(r2Release) - <-r2Done - <-r3Done - - assert.Equal(t, http.StatusOK, w2.Code) - assert.Contains(t, w2.Body.String(), "model1") - assert.Equal(t, http.StatusOK, w3.Code) - assert.Contains(t, w3.Body.String(), "model2") -} - -func TestMatrixSolver_FullScenario(t *testing.T) { - // Simulates the example config: - // standard: [g,v], [q,v], [m,v] - // with_rerank: [g,v,e], [q,v,e] - // creative: [g,sd], [q,sd] - // full: [L] - solver := NewMatrixSolver( - makeExpandedSets( - es("standard", "g", "v"), - es("standard", "q", "v"), - es("standard", "m", "v"), - es("with_rerank", "e", "g", "v"), - es("with_rerank", "e", "q", "v"), - es("creative", "g", "sd"), - es("creative", "q", "sd"), - es("full", "L"), - ), - map[string]int{"v": 50, "L": 30, "whisper": 10}, - ) - - // Running: g, v. Request q. - // standard[q,v]: evict g (cost 1), keep v. Total: 1. - // with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1. - // => tie, pick first by definition order = standard[q,v] - result, err := solver.Solve("q", []string{"g", "v"}) - require.NoError(t, err) - assert.Equal(t, []string{"g"}, result.Evict) - assert.Equal(t, []string{"q", "v"}, result.TargetSet) - - // Running: g, v. Request L. - // full[L]: evict g (cost 1) + v (cost 50). Total: 51. - // Only one set contains L, so pick it. - result, err = solver.Solve("L", []string{"g", "v"}) - require.NoError(t, err) - assert.ElementsMatch(t, []string{"g", "v"}, result.Evict) - assert.Equal(t, []string{"L"}, result.TargetSet) - - // Running: g, v. Request sd. - // creative[g,sd]: evict v (cost 50). Total: 50. - // creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51. - // => pick creative[g,sd] - result, err = solver.Solve("sd", []string{"g", "v"}) - require.NoError(t, err) - assert.Equal(t, []string{"v"}, result.Evict) - assert.Equal(t, []string{"g", "sd"}, result.TargetSet) - - // Running: q, v, e. Request g. - // standard[g,v]: evict q (1) + e (1). Total: 2. - // with_rerank[g,v,e]: evict q (1). Total: 1. - // creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52. - // => pick with_rerank[g,v,e] - result, err = solver.Solve("g", []string{"e", "q", "v"}) - require.NoError(t, err) - assert.Equal(t, []string{"q"}, result.Evict) - assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet) -} diff --git a/proxy/metrics_monitor.go b/proxy/metrics_monitor.go deleted file mode 100644 index 5cfa7b7..0000000 --- a/proxy/metrics_monitor.go +++ /dev/null @@ -1,689 +0,0 @@ -package proxy - -import ( - "bytes" - "compress/flate" - "compress/gzip" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "sync" - "time" - - "github.com/fxamacker/cbor/v2" - "github.com/gin-gonic/gin" - "github.com/klauspost/compress/zstd" - "github.com/mostlygeek/llama-swap/internal/cache" - "github.com/mostlygeek/llama-swap/internal/event" - "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/internal/ring" - "github.com/tidwall/gjson" -) - -// zstdEncOptions are the shared zstd encoder options for maximum compression. -var zstdEncOptions = []zstd.EOption{ - zstd.WithEncoderLevel(zstd.SpeedBetterCompression), -} - -// zstdDecOptions are the shared zstd decoder options. -var zstdDecOptions = []zstd.DOption{} - -// zstdEncPool pools zstd.Encoder instances to reduce allocations. -var zstdEncPool = &sync.Pool{ - New: func() interface{} { - enc, _ := zstd.NewWriter(nil, zstdEncOptions...) - return enc - }, -} - -// zstdDecPool pools zstd.Decoder instances to reduce allocations. -var zstdDecPool = &sync.Pool{ - New: func() interface{} { - dec, _ := zstd.NewReader(nil, zstdDecOptions...) - return dec - }, -} - -// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd. -// Returns compressed bytes and the original CBOR byte count for logging. -func compressCapture(c *ReqRespCapture) ([]byte, int, error) { - cborBytes, err := cbor.Marshal(c) - if err != nil { - return nil, 0, fmt.Errorf("marshal capture: %w", err) - } - zenc := zstdEncPool.Get().(*zstd.Encoder) - defer zstdEncPool.Put(zenc) - return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil -} - -// decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture. -func decompressCapture(data []byte) (*ReqRespCapture, error) { - dec := zstdDecPool.Get().(*zstd.Decoder) - defer zstdDecPool.Put(dec) - cborBytes, err := dec.DecodeAll(data, nil) - if err != nil { - return nil, fmt.Errorf("decompress capture: %w", err) - } - var capture ReqRespCapture - if err := cbor.Unmarshal(cborBytes, &capture); err != nil { - return nil, fmt.Errorf("unmarshal capture: %w", err) - } - return &capture, nil -} - -// TokenMetrics holds token usage and performance metrics -type TokenMetrics struct { - CachedTokens int `json:"cache_tokens"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - PromptPerSecond float64 `json:"prompt_per_second"` - TokensPerSecond float64 `json:"tokens_per_second"` -} - -// ActivityLogEntry represents parsed token statistics from llama-server logs -type ActivityLogEntry struct { - ID int `json:"id"` - Timestamp time.Time `json:"timestamp"` - Model string `json:"model"` - ReqPath string `json:"req_path"` - RespContentType string `json:"resp_content_type"` - RespStatusCode int `json:"resp_status_code"` - Tokens TokenMetrics `json:"tokens"` - DurationMs int `json:"duration_ms"` - HasCapture bool `json:"has_capture"` -} - -type ReqRespCapture struct { - ID int `json:"id"` - ReqPath string `json:"req_path"` - ReqHeaders map[string]string `json:"req_headers"` - ReqBody []byte `json:"req_body"` - RespHeaders map[string]string `json:"resp_headers"` - RespBody []byte `json:"resp_body"` -} - -// ActivityLogEvent represents a token metrics event -type ActivityLogEvent struct { - Metrics ActivityLogEntry -} - -func (e ActivityLogEvent) Type() uint32 { - return ActivityLogEventID // defined in events.go -} - -// metricsMonitor parses llama-server output for token statistics -type metricsMonitor struct { - mu sync.RWMutex - metrics ring.Buffer[ActivityLogEntry] - nextID int - logger *logmon.Monitor - - // capture fields - enableCaptures bool - captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture -} - -// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the -// capture buffer size in megabytes; 0 disables captures. -func newMetricsMonitor(logger *logmon.Monitor, maxMetrics int, captureBufferMB int) *metricsMonitor { - mm := &metricsMonitor{ - logger: logger, - metrics: ring.NewBuffer[ActivityLogEntry](maxMetrics), - enableCaptures: captureBufferMB > 0, - } - if captureBufferMB > 0 { - mm.captureCache = cache.New(captureBufferMB * 1024 * 1024) - } - return mm -} - -// queueMetrics adds a new metric to the collection without emitting an event. -// Returns the assigned metric ID. Call emitMetric after capture setup. -func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int { - mp.mu.Lock() - defer mp.mu.Unlock() - - metric.ID = mp.nextID - mp.nextID++ - mp.metrics.Push(metric) - return metric.ID -} - -// emitMetric publishes an ActivityLogEvent for the given metric. -func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) { - event.Emit(ActivityLogEvent{Metrics: metric}) -} - -// addCapture compresses and stores a capture in the cache. -// Returns true if the capture was stored, false otherwise. -func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool { - if !mp.enableCaptures { - return false - } - - compressed, uncompressedBytes, err := compressCapture(&capture) - if err != nil { - mp.logger.Warnf("failed to compress capture: %v, skipping", err) - return false - } - - if err := mp.captureCache.Add(capture.ID, compressed); err != nil { - mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err) - return false - } - - compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100 - mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio) - return true -} - -// getCompressedBytes returns the raw compressed bytes for a capture by ID. -func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) { - if mp.captureCache == nil { - return nil, false - } - data, err := mp.captureCache.Get(id) - if err != nil { - return nil, false - } - return data, true -} - -// getCaptureByID decompresses and unmarshals a capture by ID. -// Returns nil if the capture is not found or decompression fails. -func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture { - if mp.captureCache == nil { - return nil - } - data, exists := mp.getCompressedBytes(id) - if !exists { - return nil - } - - capture, err := decompressCapture(data) - if err != nil { - mp.logger.Warnf("failed to decompress capture %d: %v", id, err) - return nil - } - - return capture -} - -// getMetrics returns a copy of the current metrics with HasCapture resolved from cache. -func (mp *metricsMonitor) getMetrics() []ActivityLogEntry { - mp.mu.RLock() - defer mp.mu.RUnlock() - - result := mp.metrics.Slice() - if result == nil { - return []ActivityLogEntry{} - } - if mp.captureCache != nil { - for i := range result { - result[i].HasCapture = mp.captureCache.Has(result[i].ID) - } - } - return result -} - -// getMetricsJSON returns metrics as JSON with HasCapture resolved from cache. -func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) { - mp.mu.RLock() - defer mp.mu.RUnlock() - - result := mp.metrics.Slice() - if result == nil { - return json.Marshal([]ActivityLogEntry{}) - } - if mp.captureCache != nil { - for i := range result { - result[i].HasCapture = mp.captureCache.Has(result[i].ID) - } - } - return json.Marshal(result) -} - -// Capture field flags for controlling what is saved in ReqRespCapture. -type captureFields uint - -const ( - captureNone captureFields = 1 << iota - captureReqHeaders - captureReqBody - captureRespHeaders - captureRespBody -) - -const ( - captureReqAll = captureReqHeaders | captureReqBody - captureRespAll = captureRespHeaders | captureRespBody - captureAll = captureReqAll | captureRespAll -) - -// wrapHandler wraps the proxy handler to extract token metrics. -// captureFields controls what is saved in the ReqRespCapture using bitwise flags. -// if wrapHandler returns an error it is safe to assume that no -// data was sent to the client -func (mp *metricsMonitor) wrapHandler( - modelID string, - writer gin.ResponseWriter, - request *http.Request, - captureFields captureFields, - next func(modelID string, w http.ResponseWriter, r *http.Request) error, -) error { - // Capture request body and headers if captures enabled - var reqBody []byte - var reqHeaders map[string]string - if mp.enableCaptures && (captureFields&captureReqBody) != 0 { - if request.Body != nil { - var err error - reqBody, err = io.ReadAll(request.Body) - if err != nil { - return fmt.Errorf("failed to read request body for capture: %w", err) - } - request.Body.Close() - request.Body = io.NopCloser(bytes.NewBuffer(reqBody)) - } - } - if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 { - reqHeaders = make(map[string]string) - for key, values := range request.Header { - if len(values) > 0 { - reqHeaders[key] = values[0] - } - } - redactHeaders(reqHeaders) - } - - recorder := newBodyCopier(writer) - - // Filter Accept-Encoding to only include encodings we can decompress for metrics - if ae := request.Header.Get("Accept-Encoding"); ae != "" { - request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae)) - } - - if err := next(modelID, recorder, request); err != nil { - return err - } - - // after this point we have to assume that data was sent to the client - // and we can only log errors but not send them to clients - - // Initialize default metrics - recorded for every request - tm := ActivityLogEntry{ - Timestamp: time.Now(), - Model: modelID, - ReqPath: request.URL.Path, - RespContentType: recorder.Header().Get("Content-Type"), - RespStatusCode: recorder.Status(), - DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()), - } - - if recorder.Status() != http.StatusOK { - mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path) - tm.ID = mp.queueMetrics(tm) - mp.emitMetric(tm) - return nil - } - - body := recorder.body.Bytes() - if len(body) == 0 { - mp.logger.Warn("metrics: empty body, recording minimal metrics") - tm.ID = mp.queueMetrics(tm) - mp.emitMetric(tm) - return nil - } - - // Decompress if needed - if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" { - var err error - body, err = decompressBody(body, encoding) - if err != nil { - mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path) - tm.ID = mp.queueMetrics(tm) - mp.emitMetric(tm) - return nil - } - } - if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") { - if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil { - mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path) - } else { - tm.Tokens = parsed.Tokens - tm.DurationMs = parsed.DurationMs - } - } else { - if gjson.ValidBytes(body) { - parsed := gjson.ParseBytes(body) - usage := parsed.Get("usage") - timings := parsed.Get("timings") - - // extract timings for infill - response is an array, timings are in the last element - // see #463 - if strings.HasPrefix(request.URL.Path, "/infill") { - if arr := parsed.Array(); len(arr) > 0 { - timings = arr[len(arr)-1].Get("timings") - } - } - - if usage.Exists() || timings.Exists() { - if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil { - mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path) - } else { - tm.Tokens = parsedMetrics.Tokens - tm.DurationMs = parsedMetrics.DurationMs - } - } - } else { - mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path) - } - } - - // Build capture if enabled and determine if it will be stored - var capture *ReqRespCapture - if mp.enableCaptures { - var respHeaders map[string]string - var respBody []byte - if (captureFields & captureRespHeaders) != 0 { - respHeaders = make(map[string]string) - for key, values := range recorder.Header() { - if len(values) > 0 { - respHeaders[key] = values[0] - } - } - redactHeaders(respHeaders) - delete(respHeaders, "Content-Encoding") - } - if (captureFields & captureRespBody) != 0 { - respBody = body - } - capture = &ReqRespCapture{ - ReqPath: request.URL.Path, - ReqHeaders: reqHeaders, - ReqBody: reqBody, - RespHeaders: respHeaders, - RespBody: respBody, - } - } - - metricID := mp.queueMetrics(tm) - tm.ID = metricID - - // Store capture if enabled - if capture != nil { - capture.ID = metricID - if mp.addCapture(*capture) { - tm.HasCapture = true - } - } - - mp.emitMetric(tm) - - return nil -} - -// usagePaths lists the JSON paths where a per-event usage object can live. -// v1/chat/completions puts it at top-level "usage"; v1/responses nests under -// "response.usage"; v1/messages emits it at "message.usage" on message_start -// and at "usage" on message_delta. -var usagePaths = []string{"usage", "response.usage", "message.usage"} - -// extractUsageTokens reads input/output/cached token counts from a usage -// gjson.Result, handling the field-name differences across endpoints. -// cached returns -1 when the field is absent. ok is true when at least one -// field was present. -func extractUsageTokens(usage gjson.Result) (input, output, cached int64, ok bool) { - cached = -1 - if !usage.Exists() { - return - } - - if v := usage.Get("prompt_tokens"); v.Exists() { - // v1/chat/completions - input = v.Int() - ok = true - } else if v := usage.Get("input_tokens"); v.Exists() { - // v1/messages, v1/responses - input = v.Int() - ok = true - } - - if v := usage.Get("completion_tokens"); v.Exists() { - // v1/chat/completions - output = v.Int() - ok = true - } else if v := usage.Get("output_tokens"); v.Exists() { - // v1/messages, v1/responses - output = v.Int() - ok = true - } - - if v := usage.Get("cache_read_input_tokens"); v.Exists() { - // v1/messages (Anthropic) - cached = v.Int() - ok = true - } else if v := usage.Get("input_tokens_details.cached_tokens"); v.Exists() { - // v1/responses (OpenAI Responses API) - cached = v.Int() - ok = true - } else if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() { - // v1/chat/completions (OpenAI cache hits) - cached = v.Int() - ok = true - } - return -} - -func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) { - // Walk SSE "data:" lines forward, merging usage info from every event. - // Different endpoints split usage across events: - // - v1/chat/completions: usage on the final chunk before [DONE] - // - v1/responses: usage on response.completed (response.usage) - // - v1/messages: input + cache on message_start (message.usage), - // output_tokens on message_delta (usage) - // We take the latest informative value per field so all three are covered. - - var ( - inputTokens, outputTokens int64 - cachedTokens int64 = -1 - hasAny bool - timings gjson.Result - ) - - prefix := []byte("data:") - for offset := 0; offset < len(body); { - nl := bytes.IndexByte(body[offset:], '\n') - var line []byte - if nl == -1 { - line = body[offset:] - offset = len(body) - } else { - line = body[offset : offset+nl] - offset += nl + 1 - } - - line = bytes.TrimSpace(line) - if len(line) == 0 || !bytes.HasPrefix(line, prefix) { - continue - } - data := bytes.TrimSpace(line[len(prefix):]) - if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { - continue - } - if !gjson.ValidBytes(data) { - continue - } - parsed := gjson.ParseBytes(data) - - for _, path := range usagePaths { - u := parsed.Get(path) - if !u.Exists() { - continue - } - i, o, c, ok := extractUsageTokens(u) - if !ok { - continue - } - hasAny = true - // Take the latest non-zero value so message_start's input_tokens - // is preserved when message_delta's usage omits it, and vice versa - // for output_tokens. - if i > 0 { - inputTokens = i - } - if o > 0 { - outputTokens = o - } - if c >= 0 { - cachedTokens = c - } - } - if t := parsed.Get("timings"); t.Exists() { - timings = t - hasAny = true - } - } - - if !hasAny { - return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream") - } - - return buildMetrics(modelID, start, inputTokens, outputTokens, cachedTokens, timings), nil -} - -func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) { - input, output, cached, _ := extractUsageTokens(usage) - return buildMetrics(modelID, start, input, output, cached, timings), nil -} - -// buildMetrics composes an ActivityLogEntry from accumulated token counts and -// optional llama-server timings (which override input/output and provide rates). -func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, cachedTokens int64, timings gjson.Result) ActivityLogEntry { - wallDurationMs := int(time.Since(start).Milliseconds()) - durationMs := wallDurationMs - tokensPerSecond := -1.0 - promptPerSecond := -1.0 - - if timings.Exists() { - inputTokens = timings.Get("prompt_n").Int() - outputTokens = timings.Get("predicted_n").Int() - promptPerSecond = timings.Get("prompt_per_second").Float() - tokensPerSecond = timings.Get("predicted_per_second").Float() - timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float()) - if timingsDurationMs > durationMs { - durationMs = timingsDurationMs - } - if cachedValue := timings.Get("cache_n"); cachedValue.Exists() { - cachedTokens = cachedValue.Int() - } - } - - return ActivityLogEntry{ - Timestamp: time.Now(), - Model: modelID, - Tokens: TokenMetrics{ - CachedTokens: int(cachedTokens), - InputTokens: int(inputTokens), - OutputTokens: int(outputTokens), - PromptPerSecond: promptPerSecond, - TokensPerSecond: tokensPerSecond, - }, - DurationMs: durationMs, - } -} - -// decompressBody decompresses the body based on Content-Encoding header -func decompressBody(body []byte, encoding string) ([]byte, error) { - switch strings.ToLower(strings.TrimSpace(encoding)) { - case "gzip": - reader, err := gzip.NewReader(bytes.NewReader(body)) - if err != nil { - return nil, err - } - defer reader.Close() - return io.ReadAll(reader) - case "deflate": - reader := flate.NewReader(bytes.NewReader(body)) - defer reader.Close() - return io.ReadAll(reader) - default: - return body, nil // Return as-is for unknown/no encoding - } -} - -// responseBodyCopier records the response body and writes to the original response writer -// while also capturing it in a buffer for later processing -type responseBodyCopier struct { - gin.ResponseWriter - body *bytes.Buffer - tee io.Writer - start time.Time -} - -func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier { - bodyBuffer := &bytes.Buffer{} - return &responseBodyCopier{ - ResponseWriter: w, - body: bodyBuffer, - tee: io.MultiWriter(w, bodyBuffer), - start: time.Now(), - } -} - -func (w *responseBodyCopier) Write(b []byte) (int, error) { - return w.tee.Write(b) -} - -func (w *responseBodyCopier) WriteHeader(statusCode int) { - w.ResponseWriter.WriteHeader(statusCode) -} - -func (w *responseBodyCopier) Header() http.Header { - return w.ResponseWriter.Header() -} - -func (w *responseBodyCopier) StartTime() time.Time { - return w.start -} - -// sensitiveHeaders lists headers that should be redacted in captures -var sensitiveHeaders = map[string]bool{ - "authorization": true, - "proxy-authorization": true, - "cookie": true, - "set-cookie": true, - "x-api-key": true, -} - -// redactHeaders replaces sensitive header values in-place with "[REDACTED]" -func redactHeaders(headers map[string]string) { - for key := range headers { - if sensitiveHeaders[strings.ToLower(key)] { - headers[key] = "[REDACTED]" - } - } -} - -// filterAcceptEncoding filters the Accept-Encoding header to only include -// encodings we can decompress (gzip, deflate). This respects the client's -// preferences while ensuring we can parse response bodies for metrics. -func filterAcceptEncoding(acceptEncoding string) string { - if acceptEncoding == "" { - return "" - } - - supported := map[string]bool{"gzip": true, "deflate": true} - var filtered []string - - for part := range strings.SplitSeq(acceptEncoding, ",") { - // Parse encoding and optional quality value (e.g., "gzip;q=1.0") - encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";") - if supported[strings.ToLower(encoding)] { - filtered = append(filtered, strings.TrimSpace(part)) - } - } - - return strings.Join(filtered, ", ") -} diff --git a/proxy/metrics_monitor_test.go b/proxy/metrics_monitor_test.go deleted file mode 100644 index d589b32..0000000 --- a/proxy/metrics_monitor_test.go +++ /dev/null @@ -1,1507 +0,0 @@ -package proxy - -import ( - "bytes" - "compress/flate" - "compress/gzip" - "encoding/json" - "math/rand" - "net/http" - "net/http/httptest" - "sync" - "testing" - "time" - - "github.com/fxamacker/cbor/v2" - "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/internal/cache" - "github.com/mostlygeek/llama-swap/internal/event" - "github.com/stretchr/testify/assert" - "github.com/tidwall/gjson" -) - -func TestMetricsMonitor_AddMetrics(t *testing.T) { - t.Run("adds metrics and assigns ID", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - metric := ActivityLogEntry{ - Model: "test-model", - Tokens: TokenMetrics{ - InputTokens: 100, - OutputTokens: 50, - }, - } - - mm.queueMetrics(metric) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, 0, metrics[0].ID) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 100, metrics[0].Tokens.InputTokens) - assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) - }) - - t.Run("increments ID for each metric", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - for i := 0; i < 5; i++ { - mm.queueMetrics(ActivityLogEntry{Model: "model"}) - } - - metrics := mm.getMetrics() - assert.Equal(t, 5, len(metrics)) - for i := 0; i < 5; i++ { - assert.Equal(t, i, metrics[i].ID) - } - }) - - t.Run("respects max metrics limit", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 3, 0) - - // Add 5 metrics - for i := 0; i < 5; i++ { - mm.queueMetrics(ActivityLogEntry{ - Model: "model", - Tokens: TokenMetrics{ - InputTokens: i, - }, - }) - } - - metrics := mm.getMetrics() - assert.Equal(t, 3, len(metrics)) - - // Should keep the last 3 metrics (IDs 2, 3, 4) - assert.Equal(t, 2, metrics[0].ID) - assert.Equal(t, 3, metrics[1].ID) - assert.Equal(t, 4, metrics[2].ID) - }) - - t.Run("emits ActivityLogEvent", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - receivedEvent := make(chan ActivityLogEvent, 1) - cancel := event.On(func(e ActivityLogEvent) { - receivedEvent <- e - }) - defer cancel() - - metric := ActivityLogEntry{ - Model: "test-model", - Tokens: TokenMetrics{ - InputTokens: 100, - OutputTokens: 50, - }, - } - - mm.queueMetrics(metric) - mm.emitMetric(metric) - - select { - case evt := <-receivedEvent: - assert.Equal(t, 0, evt.Metrics.ID) - assert.Equal(t, "test-model", evt.Metrics.Model) - assert.Equal(t, 100, evt.Metrics.Tokens.InputTokens) - assert.Equal(t, 50, evt.Metrics.Tokens.OutputTokens) - case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for event") - } - }) -} - -func TestMetricsMonitor_GetMetrics(t *testing.T) { - t.Run("returns empty slice when no metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - metrics := mm.getMetrics() - assert.NotNil(t, metrics) - assert.Equal(t, 0, len(metrics)) - }) - - t.Run("returns copy of metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - mm.queueMetrics(ActivityLogEntry{Model: "model1"}) - mm.queueMetrics(ActivityLogEntry{Model: "model2"}) - - metrics1 := mm.getMetrics() - metrics2 := mm.getMetrics() - - // Verify we got copies - assert.Equal(t, 2, len(metrics1)) - assert.Equal(t, 2, len(metrics2)) - - // Modify the returned slice shouldn't affect the original - metrics1[0].Model = "modified" - metrics3 := mm.getMetrics() - assert.Equal(t, "model1", metrics3[0].Model) - }) -} - -func TestMetricsMonitor_GetMetricsJSON(t *testing.T) { - t.Run("returns valid JSON for empty metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - jsonData, err := mm.getMetricsJSON() - assert.NoError(t, err) - assert.NotNil(t, jsonData) - - var metrics []ActivityLogEntry - err = json.Unmarshal(jsonData, &metrics) - assert.NoError(t, err) - assert.Equal(t, 0, len(metrics)) - }) - - t.Run("returns valid JSON with metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - mm.queueMetrics(ActivityLogEntry{ - Model: "model1", - Tokens: TokenMetrics{ - InputTokens: 100, - OutputTokens: 50, - TokensPerSecond: 25.5, - }, - }) - mm.queueMetrics(ActivityLogEntry{ - Model: "model2", - Tokens: TokenMetrics{ - InputTokens: 200, - OutputTokens: 100, - TokensPerSecond: 30.0, - }, - }) - - jsonData, err := mm.getMetricsJSON() - assert.NoError(t, err) - - var metrics []ActivityLogEntry - err = json.Unmarshal(jsonData, &metrics) - assert.NoError(t, err) - assert.Equal(t, 2, len(metrics)) - assert.Equal(t, "model1", metrics[0].Model) - assert.Equal(t, "model2", metrics[1].Model) - }) -} - -func TestMetricsMonitor_WrapHandler(t *testing.T) { - t.Run("successful non-streaming request with usage data", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `{ - "usage": { - "prompt_tokens": 100, - "completion_tokens": 50 - } - }` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 100, metrics[0].Tokens.InputTokens) - assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) - }) - - t.Run("successful request with timings data", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `{ - "timings": { - "prompt_n": 100, - "predicted_n": 50, - "prompt_per_second": 150.5, - "predicted_per_second": 25.5, - "prompt_ms": 500.0, - "predicted_ms": 1500.0, - "cache_n": 20 - } - }` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 100, metrics[0].Tokens.InputTokens) - assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) - assert.Equal(t, 20, metrics[0].Tokens.CachedTokens) - assert.Equal(t, 150.5, metrics[0].Tokens.PromptPerSecond) - assert.Equal(t, 25.5, metrics[0].Tokens.TokensPerSecond) - assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500 - }) - - t.Run("streaming request with SSE format", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - // Note: SSE format requires proper line breaks - each data line followed by blank line - responseBody := `data: {"choices":[{"text":"Hello"}]} - -data: {"choices":[{"text":" World"}]} - -data: {"usage":{"prompt_tokens":10,"completion_tokens":20},"timings":{"prompt_n":10,"predicted_n":20,"prompt_per_second":100.0,"predicted_per_second":50.0,"prompt_ms":100.0,"predicted_ms":400.0}} - -data: [DONE] - -` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - // When timings data is present, it takes precedence - assert.Equal(t, 10, metrics[0].Tokens.InputTokens) - assert.Equal(t, 20, metrics[0].Tokens.OutputTokens) - }) - - t.Run("non-OK status code records partial metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("error")) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, "/test", metrics[0].ReqPath) - assert.Equal(t, http.StatusBadRequest, metrics[0].RespStatusCode) - assert.Equal(t, 0, metrics[0].Tokens.InputTokens) - assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) - }) - - t.Run("empty response body records minimal metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.WriteHeader(http.StatusOK) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].Tokens.InputTokens) - assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) - }) - - t.Run("invalid JSON records minimal metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte("not valid json")) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) // Errors after response is sent are logged, not returned - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].Tokens.InputTokens) - assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) - }) - - t.Run("next handler error is propagated", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - expectedErr := assert.AnError - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - return expectedErr - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.Equal(t, expectedErr, err) - - metrics := mm.getMetrics() - assert.Equal(t, 0, len(metrics)) - }) - - t.Run("response without usage or timings records minimal metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `{"result": "ok"}` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].Tokens.InputTokens) - assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) - }) - - t.Run("infill request extracts timings from last array element", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - // Infill response is an array with timings in the last element - responseBody := `[ - {"content": "first chunk"}, - {"content": "second chunk"}, - {"content": "final", "timings": { - "prompt_n": 150, - "predicted_n": 75, - "prompt_per_second": 200.5, - "predicted_per_second": 35.5, - "prompt_ms": 600.0, - "predicted_ms": 1800.0, - "cache_n": 30 - }} - ]` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/infill", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 150, metrics[0].Tokens.InputTokens) - assert.Equal(t, 75, metrics[0].Tokens.OutputTokens) - assert.Equal(t, 30, metrics[0].Tokens.CachedTokens) - assert.Equal(t, 200.5, metrics[0].Tokens.PromptPerSecond) - assert.Equal(t, 35.5, metrics[0].Tokens.TokensPerSecond) - assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800 - }) - - t.Run("infill request with empty array records minimal metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `[]` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/infill", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].Tokens.InputTokens) - assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) - }) -} - -func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) { - t.Run("captures response body", func(t *testing.T) { - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - copier := newBodyCopier(ginCtx.Writer) - - testData := []byte("test response body") - n, err := copier.Write(testData) - - assert.NoError(t, err) - assert.Equal(t, len(testData), n) - assert.Equal(t, testData, copier.body.Bytes()) - assert.Equal(t, string(testData), rec.Body.String()) - }) - - t.Run("sets start time on creation", func(t *testing.T) { - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - copier := newBodyCopier(ginCtx.Writer) - - assert.False(t, copier.StartTime().IsZero()) - }) - - t.Run("preserves headers", func(t *testing.T) { - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - copier := newBodyCopier(ginCtx.Writer) - - copier.Header().Set("X-Test", "value") - - assert.Equal(t, "value", rec.Header().Get("X-Test")) - }) - - t.Run("preserves status code", func(t *testing.T) { - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - copier := newBodyCopier(ginCtx.Writer) - - copier.WriteHeader(http.StatusCreated) - - // Gin's ResponseWriter tracks status internally - assert.Equal(t, http.StatusCreated, copier.Status()) - }) -} - -func TestMetricsMonitor_Concurrent(t *testing.T) { - t.Run("concurrent queueMetrics is safe", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 1000, 0) - - var wg sync.WaitGroup - numGoroutines := 10 - metricsPerGoroutine := 100 - - for i := 0; i < numGoroutines; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for j := 0; j < metricsPerGoroutine; j++ { - mm.queueMetrics(ActivityLogEntry{ - Model: "test-model", - Tokens: TokenMetrics{ - InputTokens: id*1000 + j, - OutputTokens: j, - }, - }) - } - }(i) - } - - wg.Wait() - - metrics := mm.getMetrics() - assert.Equal(t, numGoroutines*metricsPerGoroutine, len(metrics)) - }) - - t.Run("concurrent reads and writes are safe", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 100, 0) - - done := make(chan bool) - - // Writer goroutine - go func() { - for i := 0; i < 50; i++ { - mm.queueMetrics(ActivityLogEntry{Model: "test-model"}) - time.Sleep(1 * time.Millisecond) - } - done <- true - }() - - // Multiple reader goroutines - var wg sync.WaitGroup - for i := 0; i < 5; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 20; j++ { - _ = mm.getMetrics() - _, _ = mm.getMetricsJSON() - time.Sleep(2 * time.Millisecond) - } - }() - } - - <-done - wg.Wait() - - // Final check - metrics := mm.getMetrics() - assert.Equal(t, 50, len(metrics)) - }) -} - -func TestMetricsMonitor_ParseMetrics(t *testing.T) { - t.Run("keeps wall clock duration when timings underreport request time", func(t *testing.T) { - start := time.Now().Add(-5 * time.Second) - usage := gjson.Parse(`{"prompt_tokens": 5, "completion_tokens": 1}`) - timings := gjson.Parse(`{ - "prompt_n": 5, - "predicted_n": 1, - "prompt_per_second": 10.0, - "predicted_per_second": 2.0, - "prompt_ms": 5.0, - "predicted_ms": 15.0 - }`) - - metrics, err := parseMetrics("test-model", start, usage, timings) - assert.NoError(t, err) - assert.Equal(t, 5, metrics.Tokens.InputTokens) - assert.Equal(t, 1, metrics.Tokens.OutputTokens) - assert.Equal(t, 10.0, metrics.Tokens.PromptPerSecond) - assert.Equal(t, 2.0, metrics.Tokens.TokensPerSecond) - assert.GreaterOrEqual(t, metrics.DurationMs, 5000) - }) - - t.Run("prefers timings over usage data", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - // Timings should take precedence over usage - responseBody := `{ - "usage": { - "prompt_tokens": 50, - "completion_tokens": 25 - }, - "timings": { - "prompt_n": 100, - "predicted_n": 50, - "prompt_per_second": 150.5, - "predicted_per_second": 25.5, - "prompt_ms": 500.0, - "predicted_ms": 1500.0 - } - }` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - // Should use timings values, not usage values - assert.Equal(t, 100, metrics[0].Tokens.InputTokens) - assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) - }) - - t.Run("handles missing cache_n in timings", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `{ - "timings": { - "prompt_n": 100, - "predicted_n": 50, - "prompt_per_second": 150.5, - "predicted_per_second": 25.5, - "prompt_ms": 500.0, - "predicted_ms": 1500.0 - } - }` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, -1, metrics[0].Tokens.CachedTokens) // Default value when not present - }) -} - -func TestMetricsMonitor_StreamingResponse(t *testing.T) { - t.Run("finds metrics in last valid SSE data", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - // Metrics should be found in the last data line before [DONE] - responseBody := `data: {"choices":[{"text":"First"}]} - -data: {"choices":[{"text":"Second"}]} - -data: {"usage":{"prompt_tokens":100,"completion_tokens":50}} - -data: [DONE] - -` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, 100, metrics[0].Tokens.InputTokens) - assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) - }) - - t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `data: not json - -data: [DONE] - -` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].Tokens.InputTokens) - assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) - }) - - t.Run("v1/responses format with nested response.usage", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - // v1/responses SSE format: usage is nested under response.usage - responseBody := "event: response.completed\n" + - `data: {"type":"response.completed","response":{"id":"resp_abc","object":"response","created_at":1773416985,"status":"completed","model":"test-model","output":[],"usage":{"input_tokens":17,"output_tokens":23,"total_tokens":40}}}` + - "\n\n" - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/v1/responses", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 17, metrics[0].Tokens.InputTokens) - assert.Equal(t, 23, metrics[0].Tokens.OutputTokens) - }) - - t.Run("v1/responses full stream with deltas, output, and cached tokens", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - // Realistic v1/responses stream: multiple delta events followed by - // done/completed events. Usage lives on response.completed and includes - // the OpenAI Responses cached-token shape (input_tokens_details.cached_tokens). - responseBody := "event: response.created\n" + - `data: {"type":"response.created","response":{"id":"resp_1","status":"in_progress"}}` + "\n\n" + - "event: response.output_item.added\n" + - `data: {"type":"response.output_item.added","item":{"id":"msg_1","role":"assistant","status":"in_progress","type":"message"}}` + "\n\n" + - "event: response.content_part.added\n" + - `data: {"type":"response.content_part.added","item_id":"msg_1","part":{"type":"output_text","text":""}}` + "\n\n" + - "event: response.output_text.delta\n" + - `data: {"type":"response.output_text.delta","item_id":"msg_1","delta":"Hello"}` + "\n\n" + - "event: response.output_text.delta\n" + - `data: {"type":"response.output_text.delta","item_id":"msg_1","delta":" world"}` + "\n\n" + - "event: response.output_text.done\n" + - `data: {"type":"response.output_text.done","item_id":"msg_1","text":"Hello world"}` + "\n\n" + - "event: response.content_part.done\n" + - `data: {"type":"response.content_part.done","item_id":"msg_1","part":{"type":"output_text","text":"Hello world"}}` + "\n\n" + - "event: response.output_item.done\n" + - `data: {"type":"response.output_item.done","item":{"type":"message","status":"completed","id":"msg_1","content":[{"type":"output_text","text":"Hello world"}],"role":"assistant"}}` + "\n\n" + - "event: response.completed\n" + - `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","status":"completed","model":"test-model","output":[{"type":"message","status":"completed","id":"msg_1","content":[{"type":"output_text","text":"Hello world"}],"role":"assistant"}],"usage":{"input_tokens":14,"output_tokens":24,"total_tokens":38,"input_tokens_details":{"cached_tokens":13}}}}` + "\n\n" - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/v1/responses", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 14, metrics[0].Tokens.InputTokens) - assert.Equal(t, 24, metrics[0].Tokens.OutputTokens) - assert.Equal(t, 13, metrics[0].Tokens.CachedTokens) - }) - - t.Run("v1/messages merges usage from message_start and message_delta", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - // v1/messages splits usage across two events: - // message_start.message.usage has input_tokens + cache_read_input_tokens - // message_delta.usage has the final output_tokens - // Without merging, output_tokens (last seen) would clobber the input fields. - responseBody := "event: message_start\n" + - `data: {"type":"message_start","message":{"id":"m1","type":"message","role":"assistant","content":[],"model":"test-model","usage":{"cache_read_input_tokens":5,"input_tokens":9,"output_tokens":0}}}` + "\n\n" + - "event: content_block_start\n" + - `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n" + - "event: content_block_delta\n" + - `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}` + "\n\n" + - "event: content_block_delta\n" + - `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" there"}}` + "\n\n" + - "event: content_block_stop\n" + - `data: {"type":"content_block_stop","index":0}` + "\n\n" + - "event: message_delta\n" + - `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":24}}` + "\n\n" + - "event: message_stop\n" + - `data: {"type":"message_stop"}` + "\n\n" - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/v1/messages", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, 9, metrics[0].Tokens.InputTokens) - assert.Equal(t, 24, metrics[0].Tokens.OutputTokens) - assert.Equal(t, 5, metrics[0].Tokens.CachedTokens) - }) - - t.Run("v1/chat/completions OpenAI prompt_tokens_details.cached_tokens", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `data: {"choices":[{"delta":{"content":"hi"}}]}` + "\n\n" + - `data: {"choices":[{"delta":{}}],"usage":{"prompt_tokens":50,"completion_tokens":12,"prompt_tokens_details":{"cached_tokens":42}}}` + "\n\n" + - "data: [DONE]\n\n" - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, 50, metrics[0].Tokens.InputTokens) - assert.Equal(t, 12, metrics[0].Tokens.OutputTokens) - assert.Equal(t, 42, metrics[0].Tokens.CachedTokens) - }) - - t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].Tokens.InputTokens) - assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) - }) -} - -// Benchmark tests -func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) { - mm := newMetricsMonitor(testLogger, 1000, 0) - - metric := ActivityLogEntry{ - Model: "test-model", - Tokens: TokenMetrics{ - CachedTokens: 100, - InputTokens: 500, - OutputTokens: 250, - PromptPerSecond: 1200.5, - TokensPerSecond: 45.8, - }, - DurationMs: 5000, - Timestamp: time.Now(), - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mm.queueMetrics(metric) - } -} - -func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) { - // Test performance with a smaller buffer where wrapping occurs more frequently - mm := newMetricsMonitor(testLogger, 100, 0) - - metric := ActivityLogEntry{ - Model: "test-model", - Tokens: TokenMetrics{ - CachedTokens: 100, - InputTokens: 500, - OutputTokens: 250, - PromptPerSecond: 1200.5, - TokensPerSecond: 45.8, - }, - DurationMs: 5000, - Timestamp: time.Now(), - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mm.queueMetrics(metric) - } -} - -func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) { - t.Run("gzip encoded response", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}` - - // Compress with gzip - var buf bytes.Buffer - gzWriter := gzip.NewWriter(&buf) - gzWriter.Write([]byte(responseBody)) - gzWriter.Close() - compressedBody := buf.Bytes() - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Encoding", "gzip") - w.WriteHeader(http.StatusOK) - w.Write(compressedBody) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 100, metrics[0].Tokens.InputTokens) - assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) - }) - - t.Run("deflate encoded response", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}` - - // Compress with deflate - var buf bytes.Buffer - flateWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression) - flateWriter.Write([]byte(responseBody)) - flateWriter.Close() - compressedBody := buf.Bytes() - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Encoding", "deflate") - w.WriteHeader(http.StatusOK) - w.Write(compressedBody) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 200, metrics[0].Tokens.InputTokens) - assert.Equal(t, 75, metrics[0].Tokens.OutputTokens) - }) - - t.Run("invalid gzip data records minimal metrics", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - // Invalid compressed data - invalidData := []byte("this is not gzip data") - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Encoding", "gzip") - w.WriteHeader(http.StatusOK) - w.Write(invalidData) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) // Should not return error, just log warning - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].Tokens.InputTokens) - assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) - }) - - t.Run("unknown encoding treated as uncompressed", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Encoding", "unknown-encoding") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", nil) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, 300, metrics[0].Tokens.InputTokens) - assert.Equal(t, 100, metrics[0].Tokens.OutputTokens) - }) -} - -func TestReqRespCapture_CompressedSize(t *testing.T) { - t.Run("compressed size is smaller than uncompressed", func(t *testing.T) { - capture := ReqRespCapture{ - ID: 1, - ReqPath: "/v1/chat/completions", - ReqBody: []byte(`{"model":"test","prompt":"hello world this is a test request body that is reasonably long"}`), - RespBody: []byte(`{"id":"resp-123","object":"chat.completion","created":1234567890,"model":"test-model","choices":[{"index":0,"message":{"role":"assistant","content":"This is a test response body with some meaningful content to compress"}},{"index":1,"message":{"role":"user","content":"Another message here"}}]}`), - } - - compressed, uncompressed, err := compressCapture(&capture) - assert.NoError(t, err) - assert.Greater(t, uncompressed, 0) - assert.True(t, len(compressed) < uncompressed, "compressed (%d bytes) should be smaller than uncompressed JSON (%d bytes)", len(compressed), uncompressed) - }) - - t.Run("empty capture produces compressed output", func(t *testing.T) { - capture := ReqRespCapture{} - compressed, _, err := compressCapture(&capture) - assert.NoError(t, err) - assert.NotNil(t, compressed) - assert.True(t, len(compressed) > 0) - }) -} - -func TestMetricsMonitor_AddCapture(t *testing.T) { - t.Run("does nothing when captures disabled", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - capture := ReqRespCapture{ - ID: 0, - ReqBody: []byte("test"), - } - mm.addCapture(capture) - - // Should not store capture - assert.Nil(t, mm.getCaptureByID(0)) - }) - - t.Run("adds capture when enabled", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 5) - - capture := ReqRespCapture{ - ID: 0, - ReqBody: []byte("test request"), - RespBody: []byte("test response"), - } - mm.addCapture(capture) - - captured := mm.getCaptureByID(0) - assert.NotNil(t, captured) - assert.Equal(t, 0, captured.ID) - assert.Equal(t, []byte("test request"), captured.ReqBody) - assert.Equal(t, []byte("test response"), captured.RespBody) - }) - - t.Run("evicts oldest when exceeding max size", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 5) - // Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes. - // 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit. - mm.captureCache = cache.New(450) - - // Use random-looking data that doesn't compress well with zstd - rng := rand.New(rand.NewSource(42)) - capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 80)} - rng.Read(capture1.ReqBody) - capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 80)} - rng.Read(capture2.ReqBody) - capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 80)} - rng.Read(capture3.ReqBody) - - mm.addCapture(capture1) - mm.addCapture(capture2) - // Adding capture3 should evict capture1 - mm.addCapture(capture3) - - assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted") - assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist") - assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist") - }) - - t.Run("skips capture larger than max size", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 5) - mm.captureCache = cache.New(100) - - // Use random data that doesn't compress well to create an oversized capture - rng := rand.New(rand.NewSource(99)) - largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 300)} - rng.Read(largeCapture.ReqBody) - mm.addCapture(largeCapture) - - assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored") - }) -} - -func TestMetricsMonitor_GetCaptureByID(t *testing.T) { - t.Run("returns nil for non-existent ID", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 5) - - assert.Nil(t, mm.getCaptureByID(999)) - }) - - t.Run("returns decompressed capture by ID", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 5) - - capture := ReqRespCapture{ - ID: 42, - ReqBody: []byte("test request"), - RespBody: []byte("test response"), - } - mm.addCapture(capture) - - captured := mm.getCaptureByID(42) - assert.NotNil(t, captured) - assert.Equal(t, 42, captured.ID) - assert.Equal(t, []byte("test request"), captured.ReqBody) - assert.Equal(t, []byte("test response"), captured.RespBody) - }) - - t.Run("stores data as compressed bytes", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 5) - - capture := ReqRespCapture{ - ID: 42, - ReqBody: []byte("test request body"), - RespBody: []byte("test response body"), - } - mm.addCapture(capture) - - compressed, exists := mm.getCompressedBytes(42) - assert.True(t, exists) - assert.NotNil(t, compressed) - // Compressed data should not be valid CBOR (it's zstd-compressed) - var decoded ReqRespCapture - assert.Error(t, cbor.Unmarshal(compressed, &decoded)) - }) -} - -func TestRedactHeaders(t *testing.T) { - t.Run("redacts sensitive headers", func(t *testing.T) { - headers := map[string]string{ - "Authorization": "Bearer secret-token", - "Proxy-Authorization": "Basic creds", - "Cookie": "session=abc123", - "Set-Cookie": "session=xyz789", - "X-Api-Key": "sk-12345", - "Content-Type": "application/json", - "X-Custom": "safe-value", - } - - redactHeaders(headers) - - assert.Equal(t, "[REDACTED]", headers["Authorization"]) - assert.Equal(t, "[REDACTED]", headers["Proxy-Authorization"]) - assert.Equal(t, "[REDACTED]", headers["Cookie"]) - assert.Equal(t, "[REDACTED]", headers["Set-Cookie"]) - assert.Equal(t, "[REDACTED]", headers["X-Api-Key"]) - assert.Equal(t, "application/json", headers["Content-Type"]) - assert.Equal(t, "safe-value", headers["X-Custom"]) - }) - - t.Run("handles mixed case header names", func(t *testing.T) { - headers := map[string]string{ - "authorization": "Bearer token", - "COOKIE": "session=abc", - "x-api-key": "key123", - } - - redactHeaders(headers) - - assert.Equal(t, "[REDACTED]", headers["authorization"]) - assert.Equal(t, "[REDACTED]", headers["COOKIE"]) - assert.Equal(t, "[REDACTED]", headers["x-api-key"]) - }) - - t.Run("handles empty headers", func(t *testing.T) { - headers := map[string]string{} - redactHeaders(headers) - assert.Empty(t, headers) - }) -} - -func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) { - t.Run("captures request and response when enabled", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 5) - - requestBody := `{"model": "test", "prompt": "hello"}` - responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Custom", "header-value") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer secret") - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - // Check metric was recorded - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - metricID := metrics[0].ID - - // Check capture was stored with same ID (decompressed) - capture := mm.getCaptureByID(metricID) - assert.NotNil(t, capture) - assert.Equal(t, metricID, capture.ID) - assert.Equal(t, []byte(requestBody), capture.ReqBody) - assert.Equal(t, []byte(responseBody), capture.RespBody) - assert.Equal(t, "/test", capture.ReqPath) - assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"]) - assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"]) - assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"]) - assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"]) - }) - - t.Run("does not capture when disabled", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 0) - - requestBody := `{"model": "test"}` - responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) - assert.NoError(t, err) - - // Metrics should still be recorded - metrics := mm.getMetrics() - assert.Equal(t, 1, len(metrics)) - - // But no capture - assert.Nil(t, mm.getCaptureByID(metrics[0].ID)) - }) -} - -func TestMetricsMonitor_WrapHandler_PartialCaptures(t *testing.T) { - requestBody := `{"model": "test"}` - responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}` - - nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Custom", "header-value") - w.WriteHeader(http.StatusOK) - w.Write([]byte(responseBody)) - return nil - } - - t.Run("only request headers", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 100) - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer secret") - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders, nextHandler) - assert.NoError(t, err) - - capture := mm.getCaptureByID(mm.getMetrics()[0].ID) - assert.NotNil(t, capture) - assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"]) - assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"]) - assert.Nil(t, capture.ReqBody) - assert.Nil(t, capture.RespHeaders) - assert.Nil(t, capture.RespBody) - }) - - t.Run("only request body", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 100) - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - req.Header.Set("Content-Type", "application/json") - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqBody, nextHandler) - assert.NoError(t, err) - - capture := mm.getCaptureByID(mm.getMetrics()[0].ID) - assert.NotNil(t, capture) - assert.Nil(t, capture.ReqHeaders) - assert.Equal(t, []byte(requestBody), capture.ReqBody) - assert.Nil(t, capture.RespHeaders) - assert.Nil(t, capture.RespBody) - }) - - t.Run("only response headers", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 100) - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespHeaders, nextHandler) - assert.NoError(t, err) - - capture := mm.getCaptureByID(mm.getMetrics()[0].ID) - assert.NotNil(t, capture) - assert.Nil(t, capture.ReqHeaders) - assert.Nil(t, capture.ReqBody) - assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"]) - assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"]) - assert.Nil(t, capture.RespBody) - }) - - t.Run("only response body", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 100) - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespBody, nextHandler) - assert.NoError(t, err) - - capture := mm.getCaptureByID(mm.getMetrics()[0].ID) - assert.NotNil(t, capture) - assert.Nil(t, capture.ReqHeaders) - assert.Nil(t, capture.ReqBody) - assert.Nil(t, capture.RespHeaders) - assert.Equal(t, []byte(responseBody), capture.RespBody) - }) - - t.Run("captureReqAll", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 100) - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer secret") - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqAll, nextHandler) - assert.NoError(t, err) - - capture := mm.getCaptureByID(mm.getMetrics()[0].ID) - assert.NotNil(t, capture) - assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"]) - assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"]) - assert.Equal(t, []byte(requestBody), capture.ReqBody) - assert.Nil(t, capture.RespHeaders) - assert.Nil(t, capture.RespBody) - }) - - t.Run("captureRespAll", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 100) - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespAll, nextHandler) - assert.NoError(t, err) - - capture := mm.getCaptureByID(mm.getMetrics()[0].ID) - assert.NotNil(t, capture) - assert.Nil(t, capture.ReqHeaders) - assert.Nil(t, capture.ReqBody) - assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"]) - assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"]) - assert.Equal(t, []byte(responseBody), capture.RespBody) - }) - - t.Run("no flags", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 100) - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - req.Header.Set("Content-Type", "application/json") - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureFields(0), nextHandler) - assert.NoError(t, err) - - capture := mm.getCaptureByID(mm.getMetrics()[0].ID) - assert.NotNil(t, capture) - assert.Nil(t, capture.ReqHeaders) - assert.Nil(t, capture.ReqBody) - assert.Nil(t, capture.RespHeaders) - assert.Nil(t, capture.RespBody) - }) - - t.Run("mixed flags req headers and resp body", func(t *testing.T) { - mm := newMetricsMonitor(testLogger, 10, 100) - req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer secret") - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - - err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders|captureRespBody, nextHandler) - assert.NoError(t, err) - - capture := mm.getCaptureByID(mm.getMetrics()[0].ID) - assert.NotNil(t, capture) - assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"]) - assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"]) - assert.Nil(t, capture.ReqBody) - assert.Nil(t, capture.RespHeaders) - assert.Equal(t, []byte(responseBody), capture.RespBody) - }) -} diff --git a/proxy/peerproxy.go b/proxy/peerproxy.go deleted file mode 100644 index 98e2ba1..0000000 --- a/proxy/peerproxy.go +++ /dev/null @@ -1,144 +0,0 @@ -package proxy - -import ( - "fmt" - "net" - "net/http" - "net/http/httputil" - "runtime" - "sort" - "strings" - "time" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/logmon" -) - -type peerProxyMember struct { - peerID string - reverseProxy *httputil.ReverseProxy - apiKey string -} - -type PeerProxy struct { - peers config.PeerDictionaryConfig - proxyMap map[string]*peerProxyMember -} - -func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *logmon.Monitor) (*PeerProxy, error) { - proxyMap := make(map[string]*peerProxyMember) - - // Sort peer IDs for consistent iteration order - peerIDs := make([]string, 0, len(peers)) - for peerID := range peers { - peerIDs = append(peerIDs, peerID) - } - sort.Strings(peerIDs) - - for _, peerID := range peerIDs { - peer := peers[peerID] - - // Create a transport with per-peer timeout configuration - peerTransport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: time.Duration(peer.Timeouts.Connect) * time.Second, - KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second, - }).DialContext, - TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second, - ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second, - ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second, - } - - // Create reverse proxy for this peer - reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL) - reverseProxy.Transport = peerTransport - - // Wrap Director to set Host header for remote hosts (not localhost) - originalDirector := reverseProxy.Director - reverseProxy.Director = func(req *http.Request) { - originalDirector(req) - // Ensure Host header matches target URL for remote proxying - req.Host = req.URL.Host - } - - reverseProxy.ModifyResponse = func(resp *http.Response) error { - if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") { - resp.Header.Set("X-Accel-Buffering", "no") - } - return nil - } - - reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err) - errMsg := fmt.Sprintf("peer proxy error: %v", err) - if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") { - errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)" - } - http.Error(w, errMsg, http.StatusBadGateway) - } - - pp := &peerProxyMember{ - peerID: peerID, - reverseProxy: reverseProxy, - apiKey: peer.ApiKey, - } - - // Map each model to this peer's proxy - for _, modelID := range peer.Models { - if _, found := proxyMap[modelID]; found { - proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID) - continue - } - proxyMap[modelID] = pp - } - } - - return &PeerProxy{ - peers: peers, - proxyMap: proxyMap, - }, nil -} - -func (p *PeerProxy) HasPeerModel(modelID string) bool { - _, found := p.proxyMap[modelID] - return found -} - -// GetPeerFilters returns the filters for a peer model, or empty filters if not found -func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters { - pp, found := p.proxyMap[modelID] - if !found { - return config.Filters{} - } - // Get the peer config using the peerID - peer, found := p.peers[pp.peerID] - if !found { - return config.Filters{} - } - return peer.Filters -} - -func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig { - return p.peers -} - -func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error { - pp, found := p.proxyMap[model_id] - if !found { - return fmt.Errorf("no peer proxy found for model %s", model_id) - } - - // Inject API key if configured for this peer - if pp.apiKey != "" { - request.Header.Set("Authorization", "Bearer "+pp.apiKey) - request.Header.Set("x-api-key", pp.apiKey) - } - - pp.reverseProxy.ServeHTTP(writer, request) - return nil -} diff --git a/proxy/peerproxy_test.go b/proxy/peerproxy_test.go deleted file mode 100644 index 1837c6e..0000000 --- a/proxy/peerproxy_test.go +++ /dev/null @@ -1,311 +0,0 @@ -package proxy - -import ( - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - "time" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewPeerProxy_EmptyPeers(t *testing.T) { - peers := config.PeerDictionaryConfig{} - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - assert.NotNil(t, pm) - assert.Empty(t, pm.proxyMap) -} - -func TestNewPeerProxy_SinglePeer(t *testing.T) { - proxyURL, _ := url.Parse("http://peer1.example.com:8080") - peers := config.PeerDictionaryConfig{ - "peer1": config.PeerConfig{ - Proxy: "http://peer1.example.com:8080", - ProxyURL: proxyURL, - ApiKey: "test-key", - Models: []string{"model-a", "model-b"}, - }, - } - - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - assert.Len(t, pm.proxyMap, 2) - assert.True(t, pm.HasPeerModel("model-a")) - assert.True(t, pm.HasPeerModel("model-b")) - assert.False(t, pm.HasPeerModel("model-c")) -} - -func TestNewPeerProxy_MultiplePeers(t *testing.T) { - proxyURL1, _ := url.Parse("http://peer1.example.com:8080") - proxyURL2, _ := url.Parse("http://peer2.example.com:8080") - peers := config.PeerDictionaryConfig{ - "peer1": config.PeerConfig{ - Proxy: "http://peer1.example.com:8080", - ProxyURL: proxyURL1, - Models: []string{"model-a", "model-b"}, - }, - "peer2": config.PeerConfig{ - Proxy: "http://peer2.example.com:8080", - ProxyURL: proxyURL2, - Models: []string{"model-c", "model-d"}, - }, - } - - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - assert.Len(t, pm.proxyMap, 4) - assert.True(t, pm.HasPeerModel("model-a")) - assert.True(t, pm.HasPeerModel("model-b")) - assert.True(t, pm.HasPeerModel("model-c")) - assert.True(t, pm.HasPeerModel("model-d")) -} - -func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) { - // When the same model is in multiple peers, only the first (lexicographically by peer ID) - // should be mapped, and a warning should be logged - proxyURL1, _ := url.Parse("http://peer1.example.com:8080") - proxyURL2, _ := url.Parse("http://peer2.example.com:8080") - peers := config.PeerDictionaryConfig{ - "alpha-peer": config.PeerConfig{ - Proxy: "http://peer1.example.com:8080", - ProxyURL: proxyURL1, - Models: []string{"duplicate-model"}, - }, - "beta-peer": config.PeerConfig{ - Proxy: "http://peer2.example.com:8080", - ProxyURL: proxyURL2, - Models: []string{"duplicate-model"}, - }, - } - - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - // Should only have one entry for the duplicate model - assert.Len(t, pm.proxyMap, 1) - assert.True(t, pm.HasPeerModel("duplicate-model")) -} - -func TestHasPeerModel(t *testing.T) { - proxyURL, _ := url.Parse("http://peer1.example.com:8080") - peers := config.PeerDictionaryConfig{ - "peer1": config.PeerConfig{ - Proxy: "http://peer1.example.com:8080", - ProxyURL: proxyURL, - Models: []string{"existing-model"}, - }, - } - - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - - assert.True(t, pm.HasPeerModel("existing-model")) - assert.False(t, pm.HasPeerModel("non-existing-model")) -} - -func TestProxyRequest_ModelNotFound(t *testing.T) { - peers := config.PeerDictionaryConfig{} - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - w := httptest.NewRecorder() - - err = pm.ProxyRequest("non-existing-model", w, req) - assert.Error(t, err) - assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model") -} - -func TestProxyRequest_Success(t *testing.T) { - // Create a test server to act as the peer - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("response from peer")) - })) - defer testServer.Close() - - proxyURL, _ := url.Parse(testServer.URL) - peers := config.PeerDictionaryConfig{ - "peer1": config.PeerConfig{ - Proxy: testServer.URL, - ProxyURL: proxyURL, - Models: []string{"test-model"}, - }, - } - - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - w := httptest.NewRecorder() - - err = pm.ProxyRequest("test-model", w, req) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "response from peer", w.Body.String()) -} - -func TestProxyRequest_ApiKeyInjection(t *testing.T) { - // Create a test server that checks for the Authorization header - var receivedAuthHeader string - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedAuthHeader = r.Header.Get("Authorization") - w.WriteHeader(http.StatusOK) - })) - defer testServer.Close() - - proxyURL, _ := url.Parse(testServer.URL) - peers := config.PeerDictionaryConfig{ - "peer1": config.PeerConfig{ - Proxy: testServer.URL, - ProxyURL: proxyURL, - ApiKey: "secret-api-key", - Models: []string{"test-model"}, - }, - } - - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - w := httptest.NewRecorder() - - err = pm.ProxyRequest("test-model", w, req) - assert.NoError(t, err) - assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader) -} - -func TestProxyRequest_NoApiKey(t *testing.T) { - // Create a test server that checks for the Authorization header - var receivedAuthHeader string - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedAuthHeader = r.Header.Get("Authorization") - w.WriteHeader(http.StatusOK) - })) - defer testServer.Close() - - proxyURL, _ := url.Parse(testServer.URL) - peers := config.PeerDictionaryConfig{ - "peer1": config.PeerConfig{ - Proxy: testServer.URL, - ProxyURL: proxyURL, - ApiKey: "", // No API key - Models: []string{"test-model"}, - }, - } - - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - w := httptest.NewRecorder() - - err = pm.ProxyRequest("test-model", w, req) - assert.NoError(t, err) - assert.Empty(t, receivedAuthHeader) -} - -func TestProxyRequest_HostHeaderSet(t *testing.T) { - // Create a test server that checks the Host header - var receivedHost string - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHost = r.Host - w.WriteHeader(http.StatusOK) - })) - defer testServer.Close() - - proxyURL, _ := url.Parse(testServer.URL) - peers := config.PeerDictionaryConfig{ - "peer1": config.PeerConfig{ - Proxy: testServer.URL, - ProxyURL: proxyURL, - Models: []string{"test-model"}, - }, - } - - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - w := httptest.NewRecorder() - - err = pm.ProxyRequest("test-model", w, req) - assert.NoError(t, err) - // The Host header should be set to the target URL's host - assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:")) -} - -func TestProxyRequest_SSEHeaderModification(t *testing.T) { - // Create a test server that returns SSE content type - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - })) - defer testServer.Close() - - proxyURL, _ := url.Parse(testServer.URL) - peers := config.PeerDictionaryConfig{ - "peer1": config.PeerConfig{ - Proxy: testServer.URL, - ProxyURL: proxyURL, - Models: []string{"test-model"}, - }, - } - - pm, err := NewPeerProxy(peers, testLogger) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - w := httptest.NewRecorder() - - err = pm.ProxyRequest("test-model", w, req) - assert.NoError(t, err) - // The X-Accel-Buffering header should be set to "no" for SSE - assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering")) -} - -func TestNewPeerProxy_CustomTimeouts(t *testing.T) { - proxyURL, _ := url.Parse("http://localhost:8080") - - peers := config.PeerDictionaryConfig{ - "test-peer": config.PeerConfig{ - Proxy: "http://localhost:8080", - ProxyURL: proxyURL, - Models: []string{"model1"}, - Timeouts: config.TimeoutsConfig{ - Connect: 45, - ResponseHeader: 300, - TLSHandshake: 15, - ExpectContinue: 2, - IdleConn: 120, - }, - }, - } - - peerProxy, err := NewPeerProxy(peers, testLogger) - - assert.NoError(t, err) - assert.NotNil(t, peerProxy) - assert.True(t, peerProxy.HasPeerModel("model1")) - - // Verify the timeout values are actually applied to the transport - member, found := peerProxy.proxyMap["model1"] - require.True(t, found, "model1 should exist in proxyMap") - assert.NotNil(t, member.reverseProxy) - assert.NotNil(t, member.reverseProxy.Transport) - - transport, ok := member.reverseProxy.Transport.(*http.Transport) - require.True(t, ok, "Transport should be *http.Transport") - - // Verify all timeout values are correctly applied - assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout) - assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout) - assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout) - assert.Equal(t, 120*time.Second, transport.IdleConnTimeout) - // ForceAttemptHTTP2 should be enabled - assert.True(t, transport.ForceAttemptHTTP2) -} diff --git a/proxy/process.go b/proxy/process.go deleted file mode 100644 index e1117a8..0000000 --- a/proxy/process.go +++ /dev/null @@ -1,956 +0,0 @@ -package proxy - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "math/rand" - "net" - "net/http" - "net/http/httputil" - "net/url" - "os/exec" - "strings" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/event" - "github.com/mostlygeek/llama-swap/internal/logmon" -) - -type ProcessState string - -const ( - StateStopped ProcessState = ProcessState("stopped") - StateStarting ProcessState = ProcessState("starting") - StateReady ProcessState = ProcessState("ready") - StateStopping ProcessState = ProcessState("stopping") - - // process is shutdown and will not be restarted - StateShutdown ProcessState = ProcessState("shutdown") -) - -type StopStrategy int - -const ( - StopImmediately StopStrategy = iota - StopWaitForInflightRequest -) - -type Process struct { - ID string - config config.ModelConfig - cmd *exec.Cmd - reverseProxy *httputil.ReverseProxy - - // PR #155 called to cancel the upstream process - cmdMutex sync.RWMutex - cancelUpstream context.CancelFunc - - // closed when command exits - cmdWaitChan chan struct{} - - processLogger *logmon.Monitor - proxyLogger *logmon.Monitor - - healthCheckTimeout int - healthCheckLoopInterval time.Duration - - lastRequestHandledMutex sync.RWMutex - lastRequestHandled time.Time - - stateMutex sync.RWMutex - state ProcessState - - inFlightRequests sync.WaitGroup - inFlightRequestsCount atomic.Int32 - - // used to block on multiple start() calls - waitStarting sync.WaitGroup - - // for managing concurrency limits - concurrencyLimitSemaphore chan struct{} - - // used for testing to override the default value - gracefulStopTimeout time.Duration - - // used for testing to bypass subprocess and reverse proxy - testHandler http.Handler - - // track the number of failed starts - failedStartCount int -} - -func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *logmon.Monitor, proxyLogger *logmon.Monitor) *Process { - concurrentLimit := 10 - if config.ConcurrencyLimit > 0 { - concurrentLimit = config.ConcurrencyLimit - } - - // Setup the reverse proxy. - proxyURL, err := url.Parse(config.Proxy) - if err != nil { - proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err) - } - - var reverseProxy *httputil.ReverseProxy - if proxyURL != nil { - reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL) - - // Create custom transport with configured timeouts - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: time.Duration(config.Timeouts.Connect) * time.Second, - KeepAlive: time.Duration(config.Timeouts.KeepAlive) * time.Second, - }).DialContext, - TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second, - ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second, - ExpectContinueTimeout: time.Duration(config.Timeouts.ExpectContinue) * time.Second, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: time.Duration(config.Timeouts.IdleConn) * time.Second, - } - reverseProxy.Transport = transport - - reverseProxy.ModifyResponse = func(resp *http.Response) error { - // prevent nginx from buffering streaming responses (e.g., SSE) - if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") { - resp.Header.Set("X-Accel-Buffering", "no") - } - return nil - } - } - - return &Process{ - ID: ID, - config: config, - cmd: nil, - reverseProxy: reverseProxy, - cancelUpstream: nil, - processLogger: processLogger, - proxyLogger: proxyLogger, - healthCheckTimeout: healthCheckTimeout, - healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */ - state: StateStopped, - - // concurrency limit - concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit), - - // To be removed when migration over exec.CommandContext is complete - // stop timeout - gracefulStopTimeout: 10 * time.Second, - cmdWaitChan: make(chan struct{}), - } -} - -// LogMonitor returns the log monitor associated with the process. -func (p *Process) LogMonitor() *logmon.Monitor { - return p.processLogger -} - -// setLastRequestHandled sets the last request handled time in a thread-safe manner. -func (p *Process) setLastRequestHandled(t time.Time) { - p.lastRequestHandledMutex.Lock() - defer p.lastRequestHandledMutex.Unlock() - p.lastRequestHandled = t -} - -// getLastRequestHandled gets the last request handled time in a thread-safe manner. -func (p *Process) getLastRequestHandled() time.Time { - p.lastRequestHandledMutex.RLock() - defer p.lastRequestHandledMutex.RUnlock() - return p.lastRequestHandled -} - -// custom error types for swapping state -var ( - ErrExpectedStateMismatch = errors.New("expected state mismatch") - ErrInvalidStateTransition = errors.New("invalid state transition") -) - -// swapState performs a compare and swap of the state atomically. It returns the current state -// and an error if the swap failed. -func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) { - p.stateMutex.Lock() - defer p.stateMutex.Unlock() - - if p.state != expectedState { - p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState) - return p.state, ErrExpectedStateMismatch - } - - if !isValidTransition(p.state, newState) { - p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState) - return p.state, ErrInvalidStateTransition - } - - p.state = newState - - // Atomically increment waitStarting when entering StateStarting - // This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented - if newState == StateStarting { - p.waitStarting.Add(1) - } - - p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState) - event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState}) - return p.state, nil -} - -// Helper function to encapsulate transition rules -func isValidTransition(from, to ProcessState) bool { - switch from { - case StateStopped: - return to == StateStarting - case StateStarting: - return to == StateReady || to == StateStopping || to == StateStopped - case StateReady: - return to == StateStopping - case StateStopping: - return to == StateStopped || to == StateShutdown - case StateShutdown: - return false // No transitions allowed from these states - } - return false -} - -func (p *Process) CurrentState() ProcessState { - p.stateMutex.RLock() - defer p.stateMutex.RUnlock() - return p.state -} - -// forceState forces the process state to the new state with mutex protection. -// This should only be used in exceptional cases where the normal state transition -// validation via swapState() cannot be used. -func (p *Process) forceState(newState ProcessState) { - p.stateMutex.Lock() - defer p.stateMutex.Unlock() - p.state = newState -} - -// start starts the upstream command, checks the health endpoint, and sets the state to Ready -// it is a private method because starting is automatic but stopping can be called -// at any time. -func (p *Process) start() error { - - // test-only fast path: skip subprocess, health check, and TTL goroutine - if p.testHandler != nil { - if curState, err := p.swapState(StateStopped, StateStarting); err != nil { - if err == ErrExpectedStateMismatch { - if curState == StateStarting { - p.waitStarting.Wait() - curState = p.CurrentState() - if curState == StateReady { - return nil - } - return fmt.Errorf("process was already starting but wound up in state %v", curState) - } - return fmt.Errorf("process was in state %v when start() was called", curState) - } - return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err) - } - defer p.waitStarting.Done() - - // Mimic the real stop path: cancelUpstream transitions - // StateStopping -> StateStopped and closes cmdWaitChan, - // matching what waitForCmd does for real subprocesses. - ch := make(chan struct{}) - p.cmdMutex.Lock() - p.cancelUpstream = func() { - if curState := p.CurrentState(); curState == StateStopping { - if _, err := p.swapState(StateStopping, StateStopped); err != nil { - p.forceState(StateStopped) - } - } else { - p.forceState(StateStopped) - } - close(ch) - } - p.cmdWaitChan = ch - p.cmdMutex.Unlock() - - if curState, err := p.swapState(StateStarting, StateReady); err != nil { - return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err) - } - p.failedStartCount = 0 - return nil - } - - if p.config.Proxy == "" { - return fmt.Errorf("can not start(), upstream proxy missing") - } - - args, err := p.config.SanitizedCommand() - if err != nil { - return fmt.Errorf("unable to get sanitized command: %v", err) - } - - if curState, err := p.swapState(StateStopped, StateStarting); err != nil { - if err == ErrExpectedStateMismatch { - // already starting, just wait for it to complete and expect - // it to be be in the Ready start after. If not, return an error - if curState == StateStarting { - p.waitStarting.Wait() - if state := p.CurrentState(); state == StateReady { - return nil - } else { - return fmt.Errorf("process was already starting but wound up in state %v", state) - } - } else { - return fmt.Errorf("process was in state %v when start() was called", curState) - } - } else { - return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err) - } - } - - // waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting - defer p.waitStarting.Done() - cmdContext, ctxCancelUpstream := context.WithCancel(context.Background()) - - p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...) - p.cmd.Stdout = p.processLogger - p.cmd.Stderr = p.processLogger - p.cmd.Env = append(p.cmd.Environ(), p.config.Env...) - p.cmd.Cancel = p.cmdStopUpstreamProcess - p.cmd.WaitDelay = p.gracefulStopTimeout - setProcAttributes(p.cmd) - - p.cmdMutex.Lock() - p.cancelUpstream = ctxCancelUpstream - p.cmdWaitChan = make(chan struct{}) - p.cmdMutex.Unlock() - - p.failedStartCount++ // this will be reset to zero when the process has successfully started - - p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", ")) - err = p.cmd.Start() - - // Set process state to failed - if err != nil { - if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil { - p.forceState(StateStopped) // force it into a stopped state - return fmt.Errorf( - "failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v", - strings.Join(args, " "), err, curState, swapErr, - ) - } - return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err) - } - - // Capture the exit error for later signalling - go p.waitForCmd() - - // One of three things can happen at this stage: - // 1. The command exits unexpectedly - // 2. The health check fails - // 3. The health check passes - // - // only in the third case will the process be considered Ready to accept - <-time.After(250 * time.Millisecond) // give process a bit of time to start - - checkStartTime := time.Now() - maxDuration := time.Second * time.Duration(p.healthCheckTimeout) - checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint) - - // a "none" means don't check for health ... I could have picked a better word :facepalm: - if checkEndpoint != "none" { - proxyTo := p.config.Proxy - healthURL, err := url.JoinPath(proxyTo, checkEndpoint) - if err != nil { - return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint) - } - - // Ready Check loop - for { - currentState := p.CurrentState() - if currentState != StateStarting { - if currentState == StateStopped { - return fmt.Errorf("upstream command exited prematurely but successfully") - } - return errors.New("health check interrupted due to shutdown") - } - - if time.Since(checkStartTime) > maxDuration { - p.stopCommand() - return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds()) - } - - if err := p.checkHealthEndpoint(healthURL); err == nil { - p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL) - break - } else { - if strings.Contains(err.Error(), "connection refused") { - ttl := time.Until(checkStartTime.Add(maxDuration)) - p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds()) - } else { - p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err) - } - } - <-time.After(p.healthCheckLoopInterval) - } - } - - if p.config.UnloadAfter > 0 { - // start a goroutine to check every second if - // the process should be stopped - go func() { - maxDuration := time.Duration(p.config.UnloadAfter) * time.Second - - for range time.Tick(time.Second) { - if p.CurrentState() != StateReady { - return - } - - // skip the TTL check if there are inflight requests - if p.inFlightRequestsCount.Load() != 0 { - continue - } - - if time.Since(p.getLastRequestHandled()) > maxDuration { - p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter) - p.Stop() - return - } - } - }() - } - - if curState, err := p.swapState(StateStarting, StateReady); err != nil { - return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err) - } else { - p.failedStartCount = 0 - return nil - } -} - -// Stop will wait for inflight requests to complete before stopping the process. -func (p *Process) Stop() { - - // guard to prevent multiple goroutines from stopping - if !isValidTransition(p.CurrentState(), StateStopping) { - p.proxyLogger.Debugf("<%s> Stop() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState()) - return - } - - // wait for any inflight requests before proceeding - p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID) - p.inFlightRequests.Wait() - p.StopImmediately() -} - -// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM. -// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL. -func (p *Process) StopImmediately() { - - // guard to prevent multiple goroutines from stopping the process - enterState := p.CurrentState() - if !isValidTransition(enterState, StateStopping) { - p.proxyLogger.Debugf("<%s> StopImmediate() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState()) - return - } - - p.proxyLogger.Debugf("<%s> Stopping process, enter state: %s", p.ID, enterState) - if curState, err := p.swapState(enterState, StateStopping); err != nil { - p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v, current state: %v", p.ID, enterState, err, curState) - return - } - - p.stopCommand() -} - -// Shutdown is called when llama-swap is shutting down. It will give a little bit -// of time for any inflight requests to complete before shutting down. If the Process -// is in the state of starting, it will cancel it and shut it down. Once a process is in -// the StateShutdown state, it can not be started again. -func (p *Process) Shutdown() { - if !isValidTransition(p.CurrentState(), StateStopping) { - return - } - - p.stopCommand() - // just force it to this state since there is no recovery from shutdown - p.forceState(StateShutdown) -} - -// stopCommand will send a SIGTERM to the process and wait for it to exit. -// If it does not exit within 5 seconds, it will send a SIGKILL. -func (p *Process) stopCommand() { - stopStartTime := time.Now() - defer func() { - p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime)) - - // free the buffer in processLogger so the memory can be recovered - p.processLogger.Clear() - }() - - p.cmdMutex.RLock() - cancelUpstream := p.cancelUpstream - cmdWaitChan := p.cmdWaitChan - p.cmdMutex.RUnlock() - - if cancelUpstream == nil { - p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID) - return - } - - cancelUpstream() - <-cmdWaitChan -} - -func (p *Process) checkHealthEndpoint(healthURL string) error { - - client := &http.Client{ - // wait a short time for a tcp connection to be established - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 500 * time.Millisecond, - }).DialContext, - }, - - // give a long time to respond to the health check endpoint - // after the connection is established. See issue: 276 - Timeout: 5000 * time.Millisecond, - } - - req, err := http.NewRequest("GET", healthURL, nil) - if err != nil { - return err - } - - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - // got a response but it was not an OK - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("status code: %d", resp.StatusCode) - } - - return nil -} - -func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { - - if p.reverseProxy == nil { - http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError) - return - } - - requestBeginTime := time.Now() - var startDuration time.Duration - - // prevent new requests from being made while stopping or irrecoverable - currentState := p.CurrentState() - if currentState == StateShutdown || currentState == StateStopping { - http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable) - return - } - - select { - case p.concurrencyLimitSemaphore <- struct{}{}: - defer func() { <-p.concurrencyLimitSemaphore }() - default: - http.Error(w, "Too many requests", http.StatusTooManyRequests) - return - } - - p.inFlightRequests.Add(1) - p.inFlightRequestsCount.Add(1) - defer func() { - p.setLastRequestHandled(time.Now()) - p.inFlightRequestsCount.Add(-1) - p.inFlightRequests.Done() - }() - - // for #366 - // - extract streaming param from request context, should have been set by proxymanager - var srw *statusResponseWriter - swapCtx, cancelLoadCtx := context.WithCancel(r.Context()) - // start the process on demand - if p.CurrentState() != StateReady { - // start a goroutine to stream loading status messages into the response writer - // add a sync so the streaming client only runs when the goroutine has exited - - isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool) - - // PR #417 (no support for anthropic v1/messages yet) - isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions") - if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions { - srw = newStatusResponseWriter(p, w) - go srw.statusUpdates(swapCtx) - } else { - p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID) - } - - beginStartTime := time.Now() - if err := p.start(); err != nil { - errstr := fmt.Sprintf("unable to start process: %s", err) - cancelLoadCtx() - if srw != nil { - srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr)) - // Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages - // before closing the connection. Without this, the connection would close before - // the goroutine can write its cleanup messages, causing incomplete SSE output. - srw.waitForCompletion(100 * time.Millisecond) - } else { - http.Error(w, errstr, http.StatusBadGateway) - } - return - } - startDuration = time.Since(beginStartTime) - } - - // should trigger srw to stop sending loading events ... - cancelLoadCtx() - - // recover from http.ErrAbortHandler panics that can occur when the client - // disconnects before the response is sent - defer func() { - if r := recover(); r != nil { - if r == http.ErrAbortHandler { - p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID) - } else { - p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r) - } - } - }() - - if srw != nil { - // Wait for the goroutine to finish writing its final messages - const completionTimeout = 1 * time.Second - if !srw.waitForCompletion(completionTimeout) { - p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout) - } - } - - if p.testHandler != nil { - p.testHandler.ServeHTTP(w, r) - } else if srw != nil { - p.reverseProxy.ServeHTTP(srw, r) - } else { - p.reverseProxy.ServeHTTP(w, r) - } - - totalTime := time.Since(requestBeginTime) - p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v", - p.ID, r.RequestURI, startDuration, totalTime) -} - -// waitForCmd waits for the command to exit and handles exit conditions depending on current state -func (p *Process) waitForCmd() { - exitErr := p.cmd.Wait() - p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr) - - if exitErr != nil { - if errno, ok := exitErr.(syscall.Errno); ok { - p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno) - } else if exitError, ok := exitErr.(*exec.ExitError); ok { - if strings.Contains(exitError.String(), "signal: terminated") { - p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID) - } else if strings.Contains(exitError.String(), "signal: interrupt") { - p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID) - } else { - p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode()) - } - } else { - if exitErr.Error() != "context canceled" /* this is normal */ { - p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr) - } - } - } - - currentState := p.CurrentState() - switch currentState { - case StateStopping: - if curState, err := p.swapState(StateStopping, StateStopped); err != nil { - p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err) - p.forceState(StateStopped) - } - default: - p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState) - p.forceState(StateStopped) // force it to be in this state - } - - p.cmdMutex.Lock() - close(p.cmdWaitChan) - p.cmdMutex.Unlock() -} - -// cmdStopUpstreamProcess attemps to stop the upstream process gracefully -func (p *Process) cmdStopUpstreamProcess() error { - p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID) - - // this should never happen ... - if p.cmd == nil || p.cmd.Process == nil { - p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID) - return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID) - } - - if p.config.CmdStop != "" { - // replace ${PID} with the pid of the process - stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid))) - if err != nil { - p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err) - return err - } - - p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " ")) - - stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...) - stopCmd.Stdout = p.processLogger - stopCmd.Stderr = p.processLogger - setProcAttributes(stopCmd) - stopCmd.Env = p.cmd.Env - - if err := stopCmd.Run(); err != nil { - p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err) - return err - } - } else { - if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil { - p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err) - return err - } - } - - return nil -} - -// Logger returns the logger for this process. -func (p *Process) Logger() *logmon.Monitor { - return p.processLogger -} - -var loadingRemarks = []string{ - "Still faster than your last standup meeting...", - "Reticulating splines...", - "Waking up the hamsters...", - "Teaching the model manners...", - "Convincing the GPU to participate...", - "Loading weights (they're heavy)...", - "Herding electrons...", - "Compiling excuses for the delay...", - "Downloading more RAM...", - "Asking the model nicely to boot up...", - "Bribing CUDA with cookies...", - "Still loading (blame VRAM)...", - "The model is fashionably late...", - "Warming up those tensors...", - "Making the neural net do push-ups...", - "Your patience is appreciated (really)...", - "Almost there (probably)...", - "Loading like it's 1999...", - "The model forgot where it put its keys...", - "Quantum tunneling through layers...", - "Negotiating with the PCIe bus...", - "Defrosting frozen parameters...", - "Teaching attention heads to focus...", - "Running the matrix (slowly)...", - "Untangling transformer blocks...", - "Calibrating the flux capacitor...", - "Spinning up the probability wheels...", - "Waiting for the GPU to wake from its nap...", - "Converting caffeine to compute...", - "Allocating virtual patience...", - "Performing arcane CUDA rituals...", - "The model is stuck in traffic...", - "Inflating embeddings...", - "Summoning computational demons...", - "Pleading with the OOM killer...", - "Calculating the meaning of life (still at 42)...", - "Training the training wheels...", - "Optimizing the optimizer...", - "Bootstrapping the bootstrapper...", - "Loading loading screen...", - "Processing processing logs...", - "Buffering buffer overflow jokes...", - "The model hit snooze...", - "Debugging the debugger...", - "Compiling the compiler...", - "Parsing the parser (meta)...", - "Tokenizing tokens...", - "Encoding the encoder...", - "Hashing hash browns...", - "Forking spoons (not forks)...", - "The model is contemplating existence...", - "Transcending dimensional barriers...", - "Invoking elder tensor gods...", - "Unfurling probability clouds...", - "Synchronizing parallel universes...", - "The GPU is having second thoughts...", - "Recalibrating reality matrices...", - "Time is an illusion, loading doubly so...", - "Convincing bits to flip themselves...", - "The model is reading its own documentation...", -} - -type statusResponseWriter struct { - hasWritten bool - writer http.ResponseWriter - process *Process - wg sync.WaitGroup // Track goroutine completion - start time.Time -} - -func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter { - s := &statusResponseWriter{ - writer: w, - process: p, - start: time.Now(), - } - - s.Header().Set("Content-Type", "text/event-stream") // SSE - s.Header().Set("Cache-Control", "no-cache") // no-cache - s.Header().Set("Connection", "keep-alive") // keep-alive - s.WriteHeader(http.StatusOK) // send status code 200 - s.sendLine("━━━━━") - s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID)) - return s -} - -// statusUpdates sends status updates to the client while the model is loading -func (s *statusResponseWriter) statusUpdates(ctx context.Context) { - s.wg.Add(1) - defer s.wg.Done() - - // Recover from panics caused by client disconnection - // Note: recover() only works within the same goroutine, so we need it here - defer func() { - if r := recover(); r != nil { - s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r) - } - }() - - defer func() { - duration := time.Since(s.start) - s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds())) - s.sendLine("━━━━━") - s.sendLine(" ") - }() - - // Create a shuffled copy of loadingRemarks - remarks := make([]string, len(loadingRemarks)) - copy(remarks, loadingRemarks) - rand.Shuffle(len(remarks), func(i, j int) { - remarks[i], remarks[j] = remarks[j], remarks[i] - }) - ri := 0 - - // Pick a random duration to send a remark - nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second - lastRemarkTime := time.Now() - - ticker := time.NewTicker(time.Second) - defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if s.process.CurrentState() == StateReady { - return - } - - // Check if it's time for a snarky remark - if time.Since(lastRemarkTime) >= nextRemarkIn { - remark := remarks[ri%len(remarks)] - ri++ - s.sendLine(fmt.Sprintf("\n%s", remark)) - lastRemarkTime = time.Now() - // Pick a new random duration for the next remark - nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second - } else { - s.sendData(".") - } - } - } -} - -// waitForCompletion waits for the statusUpdates goroutine to finish -func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool { - done := make(chan struct{}) - go func() { - s.wg.Wait() - close(done) - }() - - select { - case <-done: - return true - case <-time.After(timeout): - return false - } -} - -func (s *statusResponseWriter) sendLine(line string) { - s.sendData(line + "\n") -} - -func (s *statusResponseWriter) sendData(data string) { - // Create the proper SSE JSON structure - type Delta struct { - ReasoningContent string `json:"reasoning_content"` - } - type Choice struct { - Delta Delta `json:"delta"` - } - type SSEMessage struct { - Choices []Choice `json:"choices"` - } - - msg := SSEMessage{ - Choices: []Choice{ - { - Delta: Delta{ - ReasoningContent: data, - }, - }, - }, - } - - jsonData, err := json.Marshal(msg) - if err != nil { - s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err) - return - } - - // Write SSE formatted data, panic if not able to write - _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData) - if err != nil { - panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err)) - } - s.Flush() -} - -func (s *statusResponseWriter) Header() http.Header { - return s.writer.Header() -} - -func (s *statusResponseWriter) Write(data []byte) (int, error) { - return s.writer.Write(data) -} - -func (s *statusResponseWriter) WriteHeader(statusCode int) { - if s.hasWritten { - return - } - s.hasWritten = true - s.writer.WriteHeader(statusCode) - s.Flush() -} - -func (s *statusResponseWriter) Flush() { - if flusher, ok := s.writer.(http.Flusher); ok { - flusher.Flush() - } -} diff --git a/proxy/process_test.go b/proxy/process_test.go deleted file mode 100644 index d6083c8..0000000 --- a/proxy/process_test.go +++ /dev/null @@ -1,609 +0,0 @@ -package proxy - -import ( - "fmt" - "io" - "net/http" - "net/http/httptest" - "os" - "runtime" - "sync" - "testing" - "time" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/stretchr/testify/assert" -) - -var ( - debugLogger = logmon.NewWriter(os.Stdout) -) - -func init() { - // flip to help with debugging tests - if false { - debugLogger.SetLogLevel(logmon.LevelDebug) - } else { - debugLogger.SetLogLevel(logmon.LevelError) - } -} - -func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { - - expectedMessage := "testing91931" - config := getTestSimpleResponderConfig(expectedMessage) - - // Create a process - process := NewProcess("test-process", 5, config, debugLogger, debugLogger) - defer process.Stop() - - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - - // process is automatically started - assert.Equal(t, StateStopped, process.CurrentState()) - process.ProxyRequest(w, req) - assert.Equal(t, StateReady, process.CurrentState()) - - assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), expectedMessage) - - // Stop the process - process.Stop() - - req = httptest.NewRequest("GET", "/", nil) - w = httptest.NewRecorder() - - // Proxy the request - process.ProxyRequest(w, req) - - // should have automatically started the process again - if w.Code != http.StatusOK { - t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) - } -} - -// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests -// are all handled successfully, even though they all may ask for the process to .start() -func TestProcess_WaitOnMultipleStarts(t *testing.T) { - - expectedMessage := "testing91931" - config := getTestSimpleResponderConfig(expectedMessage) - - process := NewProcess("test-process", 5, config, debugLogger, debugLogger) - defer process.Stop() - - var wg sync.WaitGroup - for i := 0; i < 5; i++ { - wg.Add(1) - go func(reqID int) { - defer wg.Done() - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - process.ProxyRequest(w, req) - assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID) - assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID) - }(i) - } - wg.Wait() - assert.Equal(t, StateReady, process.CurrentState()) -} - -// test that the automatic start returns the expected error type -func TestProcess_BrokenModelConfig(t *testing.T) { - // Create a process configuration - config := config.ModelConfig{ - Cmd: "nonexistent-command", - Proxy: "http://127.0.0.1:9913", - CheckEndpoint: "/health", - } - - process := NewProcess("broken", 1, config, debugLogger, debugLogger) - - req := httptest.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - process.ProxyRequest(w, req) - assert.Equal(t, http.StatusBadGateway, w.Code) - assert.Contains(t, w.Body.String(), "unable to start process") - - w = httptest.NewRecorder() - process.ProxyRequest(w, req) - assert.Equal(t, http.StatusBadGateway, w.Code) - assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':") -} - -func TestProcess_UnloadAfterTTL(t *testing.T) { - if testing.Short() { - t.Skip("skipping long auto unload TTL test") - } - - expectedMessage := "I_sense_imminent_danger" - conf := getTestSimpleResponderConfig(expectedMessage) - assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter) - conf.UnloadAfter = 3 // seconds - assert.Equal(t, 3, conf.UnloadAfter) - - process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger) - defer process.Stop() - - // this should take 4 seconds - req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil) - req2 := httptest.NewRequest("GET", "/test", nil) - - w := httptest.NewRecorder() - - // Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter - process.ProxyRequest(w, req1) - - t.Log("sending slow first request (4 seconds)") - assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "1234") - assert.Equal(t, StateReady, process.CurrentState()) - - // ensure the TTL timeout does not race slow requests (see issue #25) - t.Log("sending second request (1 second)") - time.Sleep(time.Second) - w = httptest.NewRecorder() - process.ProxyRequest(w, req2) - assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), expectedMessage) - assert.Equal(t, StateReady, process.CurrentState()) - - // wait 5 seconds - t.Log("sleep 5 seconds and check if unloaded") - time.Sleep(5 * time.Second) - assert.Equal(t, StateStopped, process.CurrentState()) -} - -func TestProcess_LowTTLValue(t *testing.T) { - if true { // change this code to run this ... - t.Skip("skipping test, edit process_test.go to run it ") - } - - conf := getTestSimpleResponderConfig("fast_ttl") - assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter) - conf.UnloadAfter = 1 // second - assert.Equal(t, 1, conf.UnloadAfter) - - process := NewProcess("ttl", 2, conf, debugLogger, debugLogger) - defer process.Stop() - - for i := 0; i < 100; i++ { - t.Logf("Waiting before sending request %d", i) - time.Sleep(1500 * time.Millisecond) - - expected := fmt.Sprintf("echo=test_%d", i) - req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil) - w := httptest.NewRecorder() - process.ProxyRequest(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), expected) - } - -} - -// issue #19 -// This test makes sure using Process.Stop() does not affect pending HTTP -// requests. All HTTP requests in this test should complete successfully. -func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { - if testing.Short() { - t.Skip("skipping slow test") - } - - expectedMessage := "12345" - config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("t", 10, config, debugLogger, debugLogger) - defer process.Stop() - - results := map[string]string{ - "12345": "", - "abcde": "", - "fghij": "", - } - - var wg sync.WaitGroup - var mu sync.Mutex - - for key := range results { - wg.Add(1) - go func(key string) { - defer wg.Done() - // send a request where simple-responder is will wait 300ms before responding - // this will simulate an in-progress request. - req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil) - w := httptest.NewRecorder() - - process.ProxyRequest(w, req) - - if w.Code != http.StatusOK { - t.Errorf("Expected status OK, got %d for key %s", w.Code, key) - } - - mu.Lock() - results[key] = w.Body.String() - mu.Unlock() - - }(key) - } - - // Stop the process while requests are still being processed - go func() { - <-time.After(150 * time.Millisecond) - process.Stop() - }() - - wg.Wait() - - for key, result := range results { - assert.Equal(t, key, result) - } -} - -func TestProcess_SwapState(t *testing.T) { - tests := []struct { - name string - currentState ProcessState - expectedState ProcessState - newState ProcessState - expectedError error - expectedResult ProcessState - }{ - {"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting}, - {"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady}, - {"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping}, - {"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped}, - {"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping}, - {"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped}, - {"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown}, - {"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped}, - {"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady}, - {"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping}, - {"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown}, - {"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown}, - {"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger) - p.state = test.currentState - - resultState, err := p.swapState(test.expectedState, test.newState) - if err != nil && test.expectedError == nil { - t.Errorf("Unexpected error: %v", err) - } else if err == nil && test.expectedError != nil { - t.Errorf("Expected error: %v, but got none", test.expectedError) - } else if err != nil && test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("Expected error: %v, got: %v", test.expectedError, err) - } - } - - if resultState != test.expectedResult { - t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState) - } - }) - } -} - -func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) { - if testing.Short() { - t.Skip("skipping long shutdown test") - } - - expectedMessage := "testing91931" - - // make a config where the healthcheck will always fail because port is wrong - config := getTestSimpleResponderConfigPort(expectedMessage, 9999) - config.Proxy = "http://localhost:9998/test" - - healthCheckTTLSeconds := 30 - process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger) - - // make it a lot faster - process.healthCheckLoopInterval = time.Second - - // start a goroutine to simulate a shutdown - var wg sync.WaitGroup - go func() { - defer wg.Done() - <-time.After(time.Millisecond * 500) - process.Shutdown() - }() - wg.Add(1) - - // start the process, this is a blocking call - err := process.start() - - wg.Wait() - assert.ErrorContains(t, err, "health check interrupted due to shutdown") - assert.Equal(t, StateShutdown, process.CurrentState()) -} - -func TestProcess_ExitInterruptsHealthCheck(t *testing.T) { - if testing.Short() { - t.Skip("skipping Exit Interrupts Health Check test") - } - - // should run and exit but interrupt the long checkHealthTimeout - checkHealthTimeout := 5 - config := config.ModelConfig{ - Cmd: "sleep 1", - Proxy: "http://127.0.0.1:9913", - CheckEndpoint: "/health", - } - - process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger) - process.healthCheckLoopInterval = time.Second // make it faster - err := process.start() - assert.Equal(t, "upstream command exited prematurely but successfully", err.Error()) - assert.Equal(t, process.CurrentState(), StateStopped) -} - -func TestProcess_ConcurrencyLimit(t *testing.T) { - if testing.Short() { - t.Skip("skipping long concurrency limit test") - } - - expectedMessage := "concurrency_limit_test" - config := getTestSimpleResponderConfig(expectedMessage) - - // only allow 1 concurrent request at a time - config.ConcurrencyLimit = 1 - - process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) - assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore)) - defer process.Stop() - - // launch a goroutine first to take up the semaphore - go func() { - req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil) - w := httptest.NewRecorder() - process.ProxyRequest(w, req1) - assert.Equal(t, http.StatusOK, w.Code) - }() - - // let the goroutine start - <-time.After(time.Millisecond * 25) - - denied := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - process.ProxyRequest(w, denied) - assert.Equal(t, http.StatusTooManyRequests, w.Code) -} - -func TestProcess_StopImmediately(t *testing.T) { - expectedMessage := "test_stop_immediate" - config := getTestSimpleResponderConfig(expectedMessage) - - process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger) - defer process.Stop() - - err := process.start() - assert.Nil(t, err) - assert.Equal(t, process.CurrentState(), StateReady) - go func() { - // slow, but will get killed by StopImmediate - req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil) - w := httptest.NewRecorder() - process.ProxyRequest(w, req) - }() - <-time.After(time.Millisecond) - process.StopImmediately() - assert.Equal(t, process.CurrentState(), StateStopped) -} - -// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates -// the upstream command -func TestProcess_ForceStopWithKill(t *testing.T) { - if testing.Short() { - t.Skip("skipping slow test") - } - - if runtime.GOOS == "windows" { - t.Skip("skipping SIGTERM test on Windows ") - } - - expectedMessage := "test_sigkill" - binaryPath := getSimpleResponderPath() - port := getTestPort() - - conf := config.ModelConfig{ - // note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent - // to force the process to exit - Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage), - Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), - CheckEndpoint: "/health", - } - - process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger) - defer process.Stop() - - // reduce to make testing go faster - process.gracefulStopTimeout = time.Second - - err := process.start() - assert.Nil(t, err) - assert.Equal(t, process.CurrentState(), StateReady) - - waitChan := make(chan struct{}) - go func() { - // slow, but will get killed by StopImmediate - req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil) - w := httptest.NewRecorder() - process.ProxyRequest(w, req) - - // StatusOK because that was already sent before the kill - assert.Equal(t, http.StatusOK, w.Code) - - // unexpected EOF because the kill happened, the "1" is sent before the kill - // then the unexpected EOF is sent after the kill - if runtime.GOOS == "windows" { - assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host") - } else { - // Upstream may be killed mid-response. - // Assert an incomplete or partial response. - assert.NotEqual(t, "12345", w.Body.String()) - } - - close(waitChan) - }() - - <-time.After(time.Millisecond) - process.StopImmediately() - assert.Equal(t, process.CurrentState(), StateStopped) - - // the request should have been interrupted by SIGKILL - <-waitChan -} - -func TestProcess_StopCmd(t *testing.T) { - conf := getTestSimpleResponderConfig("test_stop_cmd") - - if runtime.GOOS == "windows" { - conf.CmdStop = "taskkill /f /t /pid ${PID}" - } else { - conf.CmdStop = "kill -TERM ${PID}" - } - - process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger) - defer process.Stop() - - err := process.start() - assert.Nil(t, err) - assert.Equal(t, process.CurrentState(), StateReady) - process.StopImmediately() - assert.Equal(t, process.CurrentState(), StateStopped) -} - -func TestProcess_EnvironmentSetCorrectly(t *testing.T) { - expectedMessage := "test_env_not_emptied" - conf := getTestSimpleResponderConfig(expectedMessage) - - // ensure that the the default config does not blank out the inherited environment - configWEnv := conf - - // ensure the additiona variables are appended to the process' environment - configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2") - - process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger) - process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger) - - process1.start() - defer process1.Stop() - process2.start() - defer process2.Stop() - - assert.NotZero(t, len(process1.cmd.Environ())) - assert.NotZero(t, len(process2.cmd.Environ())) - assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1") - -} - -// TestProcess_ReverseProxyPanicIsHandled tests that panics from -// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are -// handled appropriately. -// -// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers -// can't copy the body. This can be caused by a client disconnecting before the full -// response is sent from some reason. -// -// bug: https://github.com/mostlygeek/llama-swap/issues/362 -// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy) -func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) { - // Add defer/recover to catch any panics that aren't handled by ProxyRequest - // If this recover() is hit, it means ProxyRequest didn't handle the panic properly - defer func() { - if r := recover(); r != nil { - t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r) - } - }() - - expectedMessage := "panic_test" - config := getTestSimpleResponderConfig(expectedMessage) - - process := NewProcess("panic-test", 5, config, debugLogger, debugLogger) - defer process.Stop() - - // Start the process - err := process.start() - assert.Nil(t, err) - assert.Equal(t, StateReady, process.CurrentState()) - - // Create a custom ResponseWriter that simulates a client disconnect - // by panicking when Write is called after headers are sent - panicWriter := &panicOnWriteResponseWriter{ - ResponseRecorder: httptest.NewRecorder(), - shouldPanic: true, - } - - // Make a request that will trigger the panic - req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil) - - // This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called. - // ProxyRequest should catch and handle this panic gracefully. - process.ProxyRequest(panicWriter, req) - - // If we get here, the panic was properly recovered in ProxyRequest - // The process should still be in a ready state - assert.Equal(t, StateReady, process.CurrentState()) -} - -// panicOnWriteResponseWriter is a ResponseWriter that panics on Write -// to simulate a client disconnect after headers are sent -// used by: TestProcess_ReverseProxyPanicIsHandled -type panicOnWriteResponseWriter struct { - *httptest.ResponseRecorder - shouldPanic bool - headerWritten bool -} - -func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) { - w.headerWritten = true - w.ResponseRecorder.WriteHeader(statusCode) -} - -func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) { - if w.shouldPanic && w.headerWritten { - // Simulate the panic that httputil.ReverseProxy throws - panic(http.ErrAbortHandler) - } - return w.ResponseRecorder.Write(b) -} - -func TestProcess_CustomTimeouts(t *testing.T) { - modelConfig := config.ModelConfig{ - Cmd: "echo test", - Proxy: "http://localhost:8080", - CheckEndpoint: "/health", - Timeouts: config.TimeoutsConfig{ - Connect: 45, - ResponseHeader: 120, - TLSHandshake: 15, - ExpectContinue: 2, - IdleConn: 120, - }, - } - - debugLogger := logmon.NewWriter(io.Discard) - process := NewProcess("test-model", 30, modelConfig, debugLogger, debugLogger) - - // Verify the process was created successfully - assert.NotNil(t, process) - assert.Equal(t, "test-model", process.ID) - assert.NotNil(t, process.reverseProxy) - assert.NotNil(t, process.reverseProxy.Transport) - - // Verify it's using http.Transport (not some other type) - transport, ok := process.reverseProxy.Transport.(*http.Transport) - assert.True(t, ok, "Transport should be *http.Transport") - assert.NotNil(t, transport) - - // Verify the timeouts are correctly applied - assert.Equal(t, 120*time.Second, transport.ResponseHeaderTimeout) - assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout) - assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout) - assert.Equal(t, 120*time.Second, transport.IdleConnTimeout) - assert.True(t, transport.ForceAttemptHTTP2) -} diff --git a/proxy/process_unix.go b/proxy/process_unix.go deleted file mode 100644 index 3e8d5d7..0000000 --- a/proxy/process_unix.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !windows - -package proxy - -import ( - "os/exec" -) - -// setProcAttributes sets platform-specific process attributes -func setProcAttributes(cmd *exec.Cmd) { - // No-op on Unix systems -} diff --git a/proxy/process_windows.go b/proxy/process_windows.go deleted file mode 100644 index 28a988b..0000000 --- a/proxy/process_windows.go +++ /dev/null @@ -1,16 +0,0 @@ -//go:build windows - -package proxy - -import ( - "os/exec" - "syscall" -) - -// setProcAttributes sets platform-specific process attributes -func setProcAttributes(cmd *exec.Cmd) { - cmd.SysProcAttr = &syscall.SysProcAttr{ - HideWindow: true, - CreationFlags: 0x08000000, // CREATE_NO_WINDOW - } -} diff --git a/proxy/processgroup.go b/proxy/processgroup.go deleted file mode 100644 index 4ceb9db..0000000 --- a/proxy/processgroup.go +++ /dev/null @@ -1,194 +0,0 @@ -package proxy - -import ( - "fmt" - "net/http" - "slices" - "sync" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/logmon" -) - -type ProcessGroup struct { - sync.Mutex - - config config.Config - id string - swap bool - exclusive bool - persistent bool - - proxyLogger *logmon.Monitor - upstreamLogger *logmon.Monitor - - // map of current processes - processes map[string]*Process - lastUsedProcess string - - // inflight tracks fast-path requests (requests for the already-selected - // model in a swap group). Fast-path requests Add(1) while holding pg.Lock - // and Done() on completion; a concurrent swap request calls inflight.Wait() - // under pg.Lock before stopping the current process. Without this tracking, - // a fast-path request that has released pg.Lock but has not yet called - // Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be - // killed mid-request. - inflight sync.WaitGroup - - // testDelayFastPath is a test-only hook that, when non-nil, is invoked in - // the fast path after pg.Lock is released but before the request is - // dispatched to Process.ProxyRequest. Tests use it to park a fast-path - // request at the exact race window to deterministically reproduce the - // fast-path vs swap race. - testDelayFastPath func() -} - -func NewProcessGroup(id string, config config.Config, proxyLogger *logmon.Monitor, upstreamLogger *logmon.Monitor) *ProcessGroup { - groupConfig, ok := config.Groups[id] - if !ok { - panic("Unable to find configuration for group id: " + id) - } - - pg := &ProcessGroup{ - id: id, - config: config, - swap: groupConfig.Swap, - exclusive: groupConfig.Exclusive, - persistent: groupConfig.Persistent, - proxyLogger: proxyLogger, - upstreamLogger: upstreamLogger, - processes: make(map[string]*Process), - } - - // Create a Process for each member in the group - for _, modelID := range groupConfig.Members { - modelConfig, modelID, _ := pg.config.FindConfig(modelID) - processLogger := logmon.NewWriter(upstreamLogger) - process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger) - pg.processes[modelID] = process - } - - return pg -} - -// ProxyRequest proxies a request to the specified model -func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error { - if !pg.HasMember(modelID) { - return fmt.Errorf("model %s not part of group %s", modelID, pg.id) - } - - if pg.swap { - pg.Lock() - if pg.lastUsedProcess != modelID { - - // Wait for in-flight fast-path requests to drain before stopping - // the previous process. Without this, a fast-path request that has - // released pg.Lock but has not yet incremented - // Process.inFlightRequests races with Stop() and can be killed - // mid-request. - pg.inflight.Wait() - - // is there something already running? - if pg.lastUsedProcess != "" { - pg.processes[pg.lastUsedProcess].Stop() - } - - // wait for the request to the new model to be fully handled - // and prevent race conditions see issue #277 - pg.processes[modelID].ProxyRequest(writer, request) - pg.lastUsedProcess = modelID - - // short circuit and exit - pg.Unlock() - return nil - } - - // Fast path: register this request in inflight before releasing - // pg.Lock so a concurrent swap will wait for it to complete. - pg.inflight.Add(1) - defer pg.inflight.Done() - pg.Unlock() - - if pg.testDelayFastPath != nil { - pg.testDelayFastPath() - } - } - - pg.processes[modelID].ProxyRequest(writer, request) - return nil -} - -func (pg *ProcessGroup) HasMember(modelName string) bool { - return slices.Contains(pg.config.Groups[pg.id].Members, modelName) -} - -func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) { - if pg.HasMember(modelName) { - return pg.processes[modelName], true - } - return nil, false -} - -func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error { - pg.Lock() - - process, exists := pg.processes[modelID] - if !exists { - pg.Unlock() - return fmt.Errorf("process not found for %s", modelID) - } - - if pg.lastUsedProcess == modelID { - pg.lastUsedProcess = "" - } - pg.Unlock() - - switch strategy { - case StopImmediately: - process.StopImmediately() - default: - process.Stop() - } - return nil -} - -func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) { - pg.Lock() - defer pg.Unlock() - - if strategy != StopImmediately { - pg.inflight.Wait() - } - - if len(pg.processes) == 0 { - return - } - - // stop Processes in parallel - var wg sync.WaitGroup - for _, process := range pg.processes { - wg.Add(1) - go func(process *Process) { - defer wg.Done() - switch strategy { - case StopImmediately: - process.StopImmediately() - default: - process.Stop() - } - }(process) - } - wg.Wait() -} - -func (pg *ProcessGroup) Shutdown() { - var wg sync.WaitGroup - for _, process := range pg.processes { - wg.Add(1) - go func(process *Process) { - defer wg.Done() - process.Shutdown() - }(process) - } - wg.Wait() -} diff --git a/proxy/processgroup_test.go b/proxy/processgroup_test.go deleted file mode 100644 index e1284a9..0000000 --- a/proxy/processgroup_test.go +++ /dev/null @@ -1,345 +0,0 @@ -package proxy - -import ( - "bytes" - "net/http" - "net/http/httptest" - "runtime" - "sync" - "testing" - "time" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - "model3": getTestSimpleResponderConfig("model3"), - "model4": getTestSimpleResponderConfig("model4"), - "model5": getTestSimpleResponderConfig("model5"), - }, - Groups: map[string]config.GroupConfig{ - "G1": { - Swap: true, - Exclusive: true, - Members: []string{"model1", "model2"}, - }, - "G2": { - Swap: false, - Exclusive: true, - Members: []string{"model3", "model4"}, - }, - }, -}) - -func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) { - pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger) - assert.True(t, pg.HasMember("model5")) -} - -func TestProcessGroup_HasMember(t *testing.T) { - pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) - assert.True(t, pg.HasMember("model1")) - assert.True(t, pg.HasMember("model2")) - assert.False(t, pg.HasMember("model3")) -} - -// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true -// and multiple requests are made in parallel, only one process is running at a time. -func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { - if testing.Short() { - t.Skip("skipping slow test") - } - - var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - // use the same listening so if a model is already running, it will fail - // this is a way to test that swap isolation is working - // properly when there are parallel requests made at the - // same time. - "model1": getTestSimpleResponderConfigPort("model1", 9832), - "model2": getTestSimpleResponderConfigPort("model2", 9832), - "model3": getTestSimpleResponderConfigPort("model3", 9832), - "model4": getTestSimpleResponderConfigPort("model4", 9832), - "model5": getTestSimpleResponderConfigPort("model5", 9832), - }, - Groups: map[string]config.GroupConfig{ - "G1": { - Swap: true, - Members: []string{"model1", "model2", "model3", "model4", "model5"}, - }, - }, - }) - - pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) - defer pg.StopProcesses(StopWaitForInflightRequest) - - tests := []string{"model1", "model2", "model3", "model4", "model5"} - - var wg sync.WaitGroup - - wg.Add(len(tests)) - for _, modelName := range tests { - go func(modelName string) { - defer wg.Done() - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - w := httptest.NewRecorder() - assert.NoError(t, pg.ProxyRequest(modelName, w, req)) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), modelName) - }(modelName) - } - wg.Wait() -} - -// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap -// request cannot stop the current process while a fast-path request (for the -// already-selected model) is in flight. Without ProcessGroup-level inflight -// tracking, a fast-path request that has released pg.Lock but has not yet -// incremented Process.inFlightRequests races with Stop()'s Wait() and the -// process is killed mid-request. -func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) { - cfg := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - }, - Groups: map[string]config.GroupConfig{ - "G1": { - Swap: true, - Members: []string{"model1", "model2"}, - }, - }, - }) - - pg := NewProcessGroup("G1", cfg, testLogger, testLogger) - defer pg.StopProcesses(StopImmediately) - - // Bypass real subprocesses so the test is fast and deterministic. - pg.processes["model1"].testHandler = newTestHandler("model1") - pg.processes["model2"].testHandler = newTestHandler("model2") - - // Prime: run a request through model1 via the swap path so that - // lastUsedProcess == "model1" and subsequent model1 requests take the - // fast path. - primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil) - primeW := httptest.NewRecorder() - require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq)) - require.Equal(t, http.StatusOK, primeW.Code) - require.Equal(t, StateReady, pg.processes["model1"].CurrentState()) - require.Equal(t, StateStopped, pg.processes["model2"].CurrentState()) - - // Fast-path hook: signal arrival at the race window, then wait for - // release. This parks R2 deterministically at the point where pg.Lock - // has been released but Process.inFlightRequests has not yet been - // incremented — the exact window the race exploits. - r2Reached := make(chan struct{}) - r2Release := make(chan struct{}) - pg.testDelayFastPath = func() { - close(r2Reached) - <-r2Release - } - - // R2: fast-path request for model1. Will pause at the test hook. - r2Done := make(chan struct{}) - w2 := httptest.NewRecorder() - go func() { - defer close(r2Done) - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - assert.NoError(t, pg.ProxyRequest("model1", w2, req)) - }() - - // Deterministically wait for R2 to reach the race window. - <-r2Reached - - // R3: swap request for model2. Must wait for R2 to finish before touching - // model1, otherwise model1 gets killed mid-request. - r3Done := make(chan struct{}) - w3 := httptest.NewRecorder() - go func() { - defer close(r3Done) - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - assert.NoError(t, pg.ProxyRequest("model2", w3, req)) - }() - - // Spin until R3 has acquired pg.Lock and entered the swap critical - // section. In the fixed code, R3 then blocks on pg.inflight.Wait() while - // still holding the lock, so TryLock keeps failing. - for pg.TryLock() { - pg.Unlock() - runtime.Gosched() - } - - // Bounded poll: give R3 a chance to demonstrate the bug by mutating - // state. In the fixed code, R3 is blocked on pg.inflight.Wait() and - // nothing changes, so we wait the full window. In the buggy code, R3 - // will Stop() model1 and start serving via model2 within microseconds — - // we exit early once the mutation is observable. - deadline := time.Now().Add(100 * time.Millisecond) - for time.Now().Before(deadline) { - if pg.processes["model1"].CurrentState() != StateReady || - pg.processes["model2"].CurrentState() != StateStopped { - break - } - done := false - select { - case <-r3Done: - done = true - default: - } - if done { - break - } - runtime.Gosched() - } - - // Invariant: R3 must be blocked while R2 is still in flight. - select { - case <-r3Done: - t.Fatal("swap completed while fast-path request was still in flight — race not prevented") - default: - } - assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(), - "model1 must stay Ready while a fast-path request is in flight") - assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(), - "model2 must not be started until R2 finishes and model1 is swapped out") - - // Release R2 and let both requests finish. - close(r2Release) - <-r2Done - <-r3Done - - assert.Equal(t, http.StatusOK, w2.Code) - assert.Contains(t, w2.Body.String(), "model1") - assert.Equal(t, http.StatusOK, w3.Code) - assert.Contains(t, w3.Body.String(), "model2") -} - -// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses -// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a -// process while a fast-path ProxyRequest is in the [pg.Unlock, -// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in -// StopProcesses, the external caller bypasses the inflight guard and kills the -// process mid-request. -func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) { - cfg := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - }, - Groups: map[string]config.GroupConfig{ - "G1": { - Swap: true, - Members: []string{"model1", "model2"}, - }, - }, - }) - - pg := NewProcessGroup("G1", cfg, testLogger, testLogger) - defer pg.StopProcesses(StopImmediately) - - pg.processes["model1"].testHandler = newTestHandler("model1") - pg.processes["model2"].testHandler = newTestHandler("model2") - - // Prime: model1 is active so subsequent model1 requests take the fast path. - primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil) - primeW := httptest.NewRecorder() - require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq)) - require.Equal(t, http.StatusOK, primeW.Code) - require.Equal(t, StateReady, pg.processes["model1"].CurrentState()) - - // Park a fast-path request at the race window. - r2Reached := make(chan struct{}) - r2Release := make(chan struct{}) - pg.testDelayFastPath = func() { - close(r2Reached) - <-r2Release - } - - r2Done := make(chan struct{}) - w2 := httptest.NewRecorder() - go func() { - defer close(r2Done) - req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - assert.NoError(t, pg.ProxyRequest("model1", w2, req)) - }() - - <-r2Reached - - // Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping - // the group while a fast-path request is in flight. - r3Done := make(chan struct{}) - go func() { - defer close(r3Done) - pg.StopProcesses(StopWaitForInflightRequest) - }() - - // Spin until StopProcesses has acquired pg.Lock. - for pg.TryLock() { - pg.Unlock() - runtime.Gosched() - } - - // Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait() - // and model1 stays Ready. In the buggy code it proceeds immediately and - // kills model1. - deadline := time.Now().Add(100 * time.Millisecond) - for time.Now().Before(deadline) { - if pg.processes["model1"].CurrentState() != StateReady { - break - } - select { - case <-r3Done: - goto done - default: - } - runtime.Gosched() - } -done: - - select { - case <-r3Done: - t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented") - default: - } - assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(), - "model1 must stay Ready while a fast-path request is in flight") - - close(r2Release) - <-r2Done - <-r3Done - - assert.Equal(t, http.StatusOK, w2.Code) - assert.Contains(t, w2.Body.String(), "model1") -} - -func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { - pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) - defer pg.StopProcesses(StopWaitForInflightRequest) - - tests := []string{"model3", "model4"} - - for _, modelName := range tests { - t.Run(modelName, func(t *testing.T) { - reqBody := `{"x", "y"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() - assert.NoError(t, pg.ProxyRequest(modelName, w, req)) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), modelName) - }) - } - - // make sure all the processes are running - for _, process := range pg.processes { - assert.Equal(t, StateReady, process.CurrentState()) - } -} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go deleted file mode 100644 index a06ce5f..0000000 --- a/proxy/proxymanager.go +++ /dev/null @@ -1,1232 +0,0 @@ -package proxy - -import ( - "bytes" - "context" - "encoding/base64" - "fmt" - "io" - "mime/multipart" - "net/http" - "os" - "sort" - "strconv" - "strings" - "sync" - "time" - - "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/event" - "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/internal/perf" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -const ( - PROFILE_SPLIT_CHAR = ":" -) - -type proxyCtxKey string - -type InflightCounter struct { - mu sync.Mutex - total int -} - -func newInflightCounter() *InflightCounter { - return &InflightCounter{} -} - -func (ic *InflightCounter) Current() int { - ic.mu.Lock() - total := ic.total - ic.mu.Unlock() - return total -} - -func (ic *InflightCounter) Increment() int { - ic.mu.Lock() - ic.total++ - total := ic.total - ic.mu.Unlock() - return total -} - -func (ic *InflightCounter) Decrement() int { - ic.mu.Lock() - if ic.total > 0 { - ic.total-- - } - total := ic.total - ic.mu.Unlock() - return total -} - -type ProxyManager struct { - sync.Mutex - - config config.Config - ginEngine *gin.Engine - - // logging - proxyLogger *logmon.Monitor - upstreamLogger *logmon.Monitor - muxLogger *logmon.Monitor - - metricsMonitor *metricsMonitor - perfMonitor *perf.Monitor - - processGroups map[string]*ProcessGroup - - // matrix-based swap (mutually exclusive with processGroups) - matrix *Matrix - - inFlightCounter *InflightCounter - - // shutdown signaling - shutdownCtx context.Context - shutdownCancel context.CancelFunc - - // version info - buildDate string - commit string - version string - - // peer proxy see: #296, #433 - peerProxy *PeerProxy -} - -func New(proxyConfig config.Config) *ProxyManager { - // set up loggers - - var muxLogger, upstreamLogger, proxyLogger *logmon.Monitor - switch proxyConfig.LogToStdout { - case config.LogToStdoutNone: - muxLogger = logmon.NewWriter(io.Discard) - upstreamLogger = logmon.NewWriter(io.Discard) - proxyLogger = logmon.NewWriter(io.Discard) - case config.LogToStdoutBoth: - muxLogger = logmon.NewWriter(os.Stdout) - upstreamLogger = logmon.NewWriter(muxLogger) - proxyLogger = logmon.NewWriter(muxLogger) - case config.LogToStdoutUpstream: - muxLogger = logmon.NewWriter(os.Stdout) - upstreamLogger = logmon.NewWriter(muxLogger) - proxyLogger = logmon.NewWriter(io.Discard) - default: - // same as config.LogToStdoutProxy - // helpful because some old tests create a config.Config directly and it - // may not have LogToStdout set explicitly - muxLogger = logmon.NewWriter(os.Stdout) - upstreamLogger = logmon.NewWriter(io.Discard) - proxyLogger = logmon.NewWriter(muxLogger) - } - - if proxyConfig.LogRequests { - proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.") - } - - switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) { - case "debug": - proxyLogger.SetLogLevel(logmon.LevelDebug) - upstreamLogger.SetLogLevel(logmon.LevelDebug) - case "info": - proxyLogger.SetLogLevel(logmon.LevelInfo) - upstreamLogger.SetLogLevel(logmon.LevelInfo) - case "warn": - proxyLogger.SetLogLevel(logmon.LevelWarn) - upstreamLogger.SetLogLevel(logmon.LevelWarn) - case "error": - proxyLogger.SetLogLevel(logmon.LevelError) - upstreamLogger.SetLogLevel(logmon.LevelError) - default: - proxyLogger.SetLogLevel(logmon.LevelInfo) - upstreamLogger.SetLogLevel(logmon.LevelInfo) - } - - // see: https://go.dev/src/time/format.go - timeFormats := map[string]string{ - "ansic": time.ANSIC, - "unixdate": time.UnixDate, - "rubydate": time.RubyDate, - "rfc822": time.RFC822, - "rfc822z": time.RFC822Z, - "rfc850": time.RFC850, - "rfc1123": time.RFC1123, - "rfc1123z": time.RFC1123Z, - "rfc3339": time.RFC3339, - "rfc3339nano": time.RFC3339Nano, - "kitchen": time.Kitchen, - "stamp": time.Stamp, - "stampmilli": time.StampMilli, - "stampmicro": time.StampMicro, - "stampnano": time.StampNano, - } - - if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok { - proxyLogger.SetLogTimeFormat(timeFormat) - upstreamLogger.SetLogTimeFormat(timeFormat) - } - - shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) - - var maxMetrics int - if proxyConfig.MetricsMaxInMemory <= 0 { - maxMetrics = 1000 // Default fallback - } else { - maxMetrics = proxyConfig.MetricsMaxInMemory - } - - peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger) - if err != nil { - proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err) - peerProxy = nil - } - - pm := &ProxyManager{ - config: proxyConfig, - ginEngine: gin.New(), - - proxyLogger: proxyLogger, - muxLogger: muxLogger, - upstreamLogger: upstreamLogger, - - metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics, proxyConfig.CaptureBuffer), - - processGroups: make(map[string]*ProcessGroup), - - inFlightCounter: newInflightCounter(), - - shutdownCtx: shutdownCtx, - shutdownCancel: shutdownCancel, - - buildDate: "unknown", - commit: "abcd1234", - version: "0", - - peerProxy: peerProxy, - } - - // create either matrix or process groups (mutually exclusive) - if proxyConfig.Matrix != nil { - pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger) - } else { - for groupID := range proxyConfig.Groups { - processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger) - pm.processGroups[groupID] = processGroup - } - } - - pm.setupGinEngine() - - // run any startup hooks - if len(proxyConfig.Hooks.OnStartup.Preload) > 0 { - // do it in the background, don't block startup -- not sure if good idea yet - go func() { - discardWriter := &DiscardWriter{} - for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload { - modelID, ok := proxyConfig.RealModelName(preloadModelName) - - if !ok { - proxyLogger.Warnf("Preload model %s not found in config", preloadModelName) - continue - } - - proxyLogger.Infof("Preloading model: %s", modelID) - - var preloadErr error - req, _ := http.NewRequest("GET", "/", nil) - - if pm.matrix != nil { - preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req) - } else { - processGroup, err := pm.swapProcessGroup(modelID) - if err != nil { - preloadErr = err - } else { - preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req) - } - } - - if preloadErr != nil { - event.Emit(ModelPreloadedEvent{ - ModelName: modelID, - Success: false, - }) - proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr) - continue - } else { - event.Emit(ModelPreloadedEvent{ - ModelName: modelID, - Success: true, - }) - } - } - }() - } - - return pm -} - -func (pm *ProxyManager) setupGinEngine() { - - pm.ginEngine.Use(func(c *gin.Context) { - - for _, prefix := range []string{ - "/wol-health", - "/api/performance", - "/metrics", - } { - if strings.HasPrefix(c.Request.URL.Path, prefix) { - c.Next() - return - } - } - - start := time.Now() - - // capture these because /upstream/:model rewrites them in c.Next() - clientIP := c.ClientIP() - method := c.Request.Method - path := c.Request.URL.Path - - c.Next() - - duration := time.Since(start) - statusCode := c.Writer.Status() - bodySize := c.Writer.Size() - - pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v", - clientIP, - method, - path, - c.Request.Proto, - statusCode, - bodySize, - c.Request.UserAgent(), - duration, - ) - }) - - // see: issue: #81, #77 and #42 for CORS issues - // respond with permissive OPTIONS for any endpoint - pm.ginEngine.Use(func(c *gin.Context) { - if c.Request.Method == "OPTIONS" { - c.Header("Access-Control-Allow-Origin", "*") - c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") - - // allow whatever the client requested by default - if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" { - sanitized := SanitizeAccessControlRequestHeaderValues(headers) - c.Header("Access-Control-Allow-Headers", sanitized) - } else { - c.Header( - "Access-Control-Allow-Headers", - "Content-Type, Authorization, Accept, X-Requested-With", - ) - } - c.Header("Access-Control-Max-Age", "86400") - c.AbortWithStatus(http.StatusNoContent) - return - } - c.Next() - }) - - // Set up routes using the Gin engine - // Protected routes use pm.apiKeyAuth() middleware - llmHandler := pm.mkProxyJSONHandler(captureAll) - pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - // Support legacy /v1/completions api, see issue #12 - pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - // Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570) - pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - // Support anthropic count_tokens API (Also added in the above PR) - pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - - // Support embeddings and reranking - pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - - // llama-server's /reranking endpoint + aliases - pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - - // Unversioned API endpoints, see issue #728 - pm.ginEngine.POST("/v/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - pm.ginEngine.POST("/v/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - - // llama-server's /infill endpoint for code infilling - pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - - // llama-server's /completion endpoint - pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) - - // Support audio/speech endpoint - pm.ginEngine.POST( - "/v1/audio/speech", - pm.apiKeyAuth(), - pm.trackInflight(), - pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders), - ) - pm.ginEngine.POST( - "/v1/audio/voices", - pm.apiKeyAuth(), - pm.trackInflight(), - pm.mkProxyJSONHandler(captureReqHeaders|captureRespAll), - ) - pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler) - - pm.ginEngine.POST( - "/v1/audio/transcriptions", - pm.apiKeyAuth(), - pm.trackInflight(), - pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders|captureRespBody), - ) - pm.ginEngine.POST( - "/v1/images/generations", - pm.apiKeyAuth(), - pm.trackInflight(), - pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders), - ) - - pm.ginEngine.POST( - "/v1/images/edits", - pm.apiKeyAuth(), - pm.trackInflight(), - pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders), - ) - - // sd.cpp /sdapi/v1 endpoints - pm.ginEngine.POST("/sdapi/v1/txt2img", - pm.apiKeyAuth(), - pm.trackInflight(), - pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders), - ) - pm.ginEngine.POST("/sdapi/v1/img2img", - pm.apiKeyAuth(), - pm.trackInflight(), - pm.mkProxyJSONHandler(captureReqHeaders|captureRespHeaders), - ) - pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler) - - pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler) - - // in proxymanager_loghandlers.go - pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers) - pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler) - pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler) - - /** - * User Interface Endpoints - */ - pm.ginEngine.GET("/", func(c *gin.Context) { - c.Redirect(http.StatusFound, "/ui") - }) - - pm.ginEngine.GET("/upstream", func(c *gin.Context) { - c.Redirect(http.StatusFound, "/ui/models") - }) - pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyToUpstream) - pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler) - pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler) - pm.ginEngine.GET("/health", func(c *gin.Context) { - c.String(http.StatusOK, "OK") - }) - - pm.ginEngine.GET("/metrics", pm.prometheusMetricsHandler) - - // see cmd/wol-proxy/wol-proxy.go, not logged - pm.ginEngine.GET("/wol-health", func(c *gin.Context) { - c.String(http.StatusOK, "OK") - }) - - pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) { - if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil { - c.Data(http.StatusOK, "image/x-icon", data) - } else { - c.String(http.StatusInternalServerError, err.Error()) - } - }) - - reactFS, err := GetReactFS() - if err != nil { - pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err) - } else { - // Serve files with compression support under /ui/* - // This handler checks for pre-compressed .br and .gz files - pm.ginEngine.GET("/ui/*filepath", func(c *gin.Context) { - filepath := strings.TrimPrefix(c.Param("filepath"), "/") - // Default to index.html for directory-like paths - if filepath == "" { - filepath = "index.html" - } - - ServeCompressedFile(reactFS, c.Writer, c.Request, filepath) - }) - - // Serve SPA for UI under /ui/* - fallback to index.html for client-side routing - pm.ginEngine.NoRoute(func(c *gin.Context) { - if !strings.HasPrefix(c.Request.URL.Path, "/ui") { - c.AbortWithStatus(http.StatusNotFound) - return - } - - // Check if this looks like a file request (has extension) - path := c.Request.URL.Path - if strings.Contains(path, ".") && !strings.HasSuffix(path, "/") { - // This was likely a file request that wasn't found - c.AbortWithStatus(http.StatusNotFound) - return - } - - // Serve index.html for SPA routing - ServeCompressedFile(reactFS, c.Writer, c.Request, "index.html") - }) - } - - // see: proxymanager_api.go - // add API handler functions - addApiHandlers(pm) - - // Disable console color for testing - gin.DisableConsoleColor() -} - -func (pm *ProxyManager) trackInflight() gin.HandlerFunc { - return func(c *gin.Context) { - event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Increment()}) - defer event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Decrement()}) - c.Next() - } -} - -// ServeHTTP implements http.Handler interface -func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) { - pm.ginEngine.ServeHTTP(w, r) -} - -// StopProcesses acquires a lock and stops all running upstream processes. -// This is the public method safe for concurrent calls. -// Unlike Shutdown, this method only stops the processes but doesn't perform -// a complete shutdown, allowing for process replacement without full termination. -func (pm *ProxyManager) StopProcesses(strategy StopStrategy) { - pm.Lock() - defer pm.Unlock() - - if pm.matrix != nil { - pm.matrix.StopProcesses(strategy) - return - } - - // stop Processes in parallel - var wg sync.WaitGroup - for _, processGroup := range pm.processGroups { - wg.Add(1) - go func(processGroup *ProcessGroup) { - defer wg.Done() - processGroup.StopProcesses(strategy) - }(processGroup) - } - - wg.Wait() -} - -// Shutdown stops all processes managed by this ProxyManager -func (pm *ProxyManager) Shutdown() { - pm.Lock() - defer pm.Unlock() - - pm.proxyLogger.Debug("Shutdown() called in proxy manager") - - if pm.matrix != nil { - pm.matrix.Shutdown() - pm.shutdownCancel() - return - } - - var wg sync.WaitGroup - // Send shutdown signal to all process in groups - for _, processGroup := range pm.processGroups { - wg.Add(1) - go func(processGroup *ProcessGroup) { - defer wg.Done() - processGroup.Shutdown() - }(processGroup) - } - wg.Wait() - pm.shutdownCancel() -} - -func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) { - processGroup := pm.findGroupByModelName(realModelName) - if processGroup == nil { - return nil, fmt.Errorf("could not find process group for model %s", realModelName) - } - - if processGroup.exclusive { - pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id) - for groupId, otherGroup := range pm.processGroups { - if groupId != processGroup.id && !otherGroup.persistent { - otherGroup.StopProcesses(StopWaitForInflightRequest) - } - } - } - - return processGroup, nil -} - -func (pm *ProxyManager) listModelsHandler(c *gin.Context) { - data := make([]gin.H, 0, len(pm.config.Models)) - createdTime := time.Now().Unix() - - newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H { - record := gin.H{ - "id": modelId, - "object": "model", - "created": createdTime, - "owned_by": "llama-swap", - } - - if name := strings.TrimSpace(modelConfig.Name); name != "" { - record["name"] = name - } - if desc := strings.TrimSpace(modelConfig.Description); desc != "" { - record["description"] = desc - } - - // Add metadata if present - if len(modelConfig.Metadata) > 0 { - record["meta"] = gin.H{ - "llamaswap": modelConfig.Metadata, - } - } - return record - } - - for id, modelConfig := range pm.config.Models { - if modelConfig.Unlisted { - continue - } - - data = append(data, newRecord(id, modelConfig)) - - // Include aliases - if pm.config.IncludeAliasesInList { - for _, alias := range modelConfig.Aliases { - if alias := strings.TrimSpace(alias); alias != "" { - data = append(data, newRecord(alias, modelConfig)) - } - } - } - } - - if pm.peerProxy != nil { - for peerID, peer := range pm.peerProxy.ListPeers() { - // add peer models - for _, modelID := range peer.Models { - // Skip unlisted models if not showing them - record := newRecord(modelID, config.ModelConfig{ - Name: fmt.Sprintf("%s: %s", peerID, modelID), - Metadata: map[string]any{ - "peerID": peerID, - }, - }) - - data = append(data, record) - } - } - } - - // Sort by the "id" key - sort.Slice(data, func(i, j int) bool { - si, _ := data[i]["id"].(string) - sj, _ := data[j]["id"].(string) - return si < sj - }) - - // Set CORS headers if origin exists - if origin := c.GetHeader("Origin"); origin != "" { - c.Header("Access-Control-Allow-Origin", origin) - } - - // Use gin's JSON method which handles content-type and encoding - c.JSON(http.StatusOK, gin.H{ - "object": "list", - "data": data, - }) -} - -// findModelInPath searches for a valid model name in a path with slashes. -// It iteratively builds up path segments until it finds a matching model. -// Returns: (searchModelName, realModelName, remainingPath, found) -// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true) -func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) { - parts := strings.Split(strings.TrimSpace(path), "/") - searchModelName := "" - - for i, part := range parts { - if part == "" { - continue - } - - if searchModelName == "" { - searchModelName = part - } else { - searchModelName = searchModelName + "/" + part - } - - if modelID, ok := pm.config.RealModelName(searchModelName); ok { - return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true - } - } - - return "", "", "", false -} - -func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { - upstreamPath := c.Param("upstreamPath") - - searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath) - - if !modelFound { - pm.sendErrorResponse(c, http.StatusNotFound, "model not found") - return - } - - // Redirect /upstream/modelname to /upstream/modelname/ for URL consistency. - // This ensures relative URLs in upstream responses resolve correctly and - // provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the - // HTTP method (301 would downgrade to GET). - if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") { - newPath := "/upstream/" + searchModelName + "/" - if c.Request.URL.RawQuery != "" { - newPath += "?" + c.Request.URL.RawQuery - } - if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { - c.Redirect(http.StatusMovedPermanently, newPath) - } else { - c.Redirect(http.StatusPermanentRedirect, newPath) - } - return - } - - var handler func(string, http.ResponseWriter, *http.Request) error - if pm.matrix != nil { - handler = pm.matrix.ProxyRequest - } else { - processGroup, err := pm.swapProcessGroup(modelID) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) - return - } - handler = processGroup.ProxyRequest - } - - // rewrite the path - originalPath := c.Request.URL.Path - c.Request.URL.Path = remainingPath - - // attempt to record metrics if it is a POST request - if pm.metricsMonitor != nil && c.Request.Method == "POST" { - if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, captureNone, handler); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) - pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath) - return - } - } else { - if err := handler(modelID, c.Writer, c.Request); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath) - return - } - } -} - -func (pm *ProxyManager) mkProxyJSONHandler(cf captureFields) func(*gin.Context) { - return func(c *gin.Context) { - bodyBytes, err := io.ReadAll(c.Request.Body) - if err != nil { - pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body") - return - } - - requestedModel := gjson.GetBytes(bodyBytes, "model").String() - if requestedModel == "" { - pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") - return - } - - // Look for a matching local model first - var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error - - modelID, found := pm.config.RealModelName(requestedModel) - if found { - var localHandler func(string, http.ResponseWriter, *http.Request) error - if pm.matrix != nil { - localHandler = pm.matrix.ProxyRequest - } else { - processGroup, err := pm.swapProcessGroup(modelID) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) - return - } - localHandler = processGroup.ProxyRequest - } - - // issue #69 allow custom model names to be sent to upstream - useModelName := pm.config.Models[modelID].UseModelName - if useModelName != "" { - bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) - return - } - } - - // issue #174 strip parameters from the JSON body - stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams() - if err != nil { // just log it and continue - pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error()) - } else { - for _, param := range stripParams { - pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param) - bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param)) - return - } - } - } - - // issue #453 set/override parameters in the JSON body - setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams() - for _, key := range setParamKeys { - pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key) - bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) - return - } - } - - // setParamsByID: set params based on the requested model ID (runs after setParams, can override it) - setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel) - for _, key := range setParamsByIDKeys { - pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key) - bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key]) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) - return - } - } - - pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) - nextHandler = localHandler - } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { - pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) - modelID = requestedModel - - // issue #453 apply filters for peer requests - peerFilters := pm.peerProxy.GetPeerFilters(requestedModel) - - // Apply stripParams - remove specified parameters from request - stripParams := peerFilters.SanitizedStripParams() - for _, param := range stripParams { - pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param) - bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param)) - return - } - } - - // Apply setParams - set/override specified parameters in request - setParams, setParamKeys := peerFilters.SanitizedSetParams() - for _, key := range setParamKeys { - pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key) - bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) - return - } - } - - nextHandler = pm.peerProxy.ProxyRequest - } - - if nextHandler == nil { - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel)) - return - } - - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - - // dechunk it as we already have all the body bytes see issue #11 - c.Request.Header.Del("transfer-encoding") - c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes))) - c.Request.ContentLength = int64(len(bodyBytes)) - - // issue #728 support versionless API requests - if strings.HasPrefix(c.Request.URL.Path, "/v/") { - c.Request.URL.Path = strings.TrimPrefix(c.Request.URL.Path, "/v") - } - - // issue #366 extract values that downstream handlers may need - isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool() - ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming) - ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID) - c.Request = c.Request.WithContext(ctx) - - if pm.metricsMonitor != nil && c.Request.Method == "POST" { - if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, cf, nextHandler); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID) - return - } - } else { - if err := nextHandler(modelID, c.Writer, c.Request); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID) - return - } - } - } -} - -// mkPostFormHandler creates a POST form handler for inference backends -// with a custom captureFields to filter out large binary requests or responses. -func (pm *ProxyManager) mkPostFormHandler(cf captureFields) func(*gin.Context) { - return func(c *gin.Context) { - // Parse multipart form - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) - return - } - - // Get model parameter from the form - requestedModel := c.Request.FormValue("model") - if requestedModel == "" { - pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data") - return - } - - // Look for a matching local model first, then check peers - var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error - var useModelName string - - modelID, found := pm.config.RealModelName(requestedModel) - if found { - if pm.matrix != nil { - nextHandler = pm.matrix.ProxyRequest - } else { - processGroup, err := pm.swapProcessGroup(modelID) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) - return - } - nextHandler = processGroup.ProxyRequest - } - - useModelName = pm.config.Models[modelID].UseModelName - pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) - } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { - pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) - modelID = requestedModel - nextHandler = pm.peerProxy.ProxyRequest - } - - if nextHandler == nil { - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel)) - return - } - - // We need to reconstruct the multipart form in any case since the body is consumed - // Create a new buffer for the reconstructed request - var requestBuffer bytes.Buffer - multipartWriter := multipart.NewWriter(&requestBuffer) - - // Copy all form values - for key, values := range c.Request.MultipartForm.Value { - for _, value := range values { - fieldValue := value - // If this is the model field and we have a profile, use just the model name - if key == "model" { - // # issue #69 allow custom model names to be sent to upstream - if useModelName != "" { - fieldValue = useModelName - } else { - fieldValue = requestedModel - } - } - field, err := multipartWriter.CreateFormField(key) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field") - return - } - if _, err = field.Write([]byte(fieldValue)); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field") - return - } - } - } - - // Copy all files from the original request - for key, fileHeaders := range c.Request.MultipartForm.File { - for _, fileHeader := range fileHeaders { - formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file") - return - } - - file, err := fileHeader.Open() - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file") - return - } - - if _, err = io.Copy(formFile, file); err != nil { - file.Close() - pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data") - return - } - file.Close() - } - } - - // Close the multipart writer to finalize the form - if err := multipartWriter.Close(); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form") - return - } - - // Create a new request with the reconstructed form data - modifiedReq, err := http.NewRequestWithContext( - c.Request.Context(), - c.Request.Method, - c.Request.URL.String(), - &requestBuffer, - ) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request") - return - } - - // Copy the headers from the original request - modifiedReq.Header = c.Request.Header.Clone() - modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) - - // set the content length of the body - modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len())) - modifiedReq.ContentLength = int64(requestBuffer.Len()) - - // Use the modified request for proxying - if pm.metricsMonitor != nil { - if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, modifiedReq, cf, nextHandler); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID) - return - } - } else { - if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID) - return - } - } - } -} - -func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) { - requestedModel := c.Query("model") - if requestedModel == "" { - pm.sendErrorResponse(c, http.StatusBadRequest, "missing required 'model' query parameter") - return - } - - var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error - var modelID string - - if realModelID, found := pm.config.RealModelName(requestedModel); found { - modelID = realModelID - if pm.matrix != nil { - nextHandler = pm.matrix.ProxyRequest - } else { - processGroup, err := pm.swapProcessGroup(realModelID) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) - return - } - nextHandler = processGroup.ProxyRequest - } - pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) - } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { - modelID = requestedModel - pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) - nextHandler = pm.peerProxy.ProxyRequest - } - - if nextHandler == nil { - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel)) - return - } - - if err := nextHandler(modelID, c.Writer, c.Request); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying GET Request for model %s", modelID) - return - } -} - -func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) { - acceptHeader := c.GetHeader("Accept") - - if strings.Contains(acceptHeader, "application/json") { - c.JSON(statusCode, gin.H{"error": message}) - } else { - c.String(statusCode, message) - } -} - -// apiKeyAuth returns a middleware that validates API keys if configured. -// Returns a pass-through handler if no API keys are configured. -func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc { - if len(pm.config.RequiredAPIKeys) == 0 { - return func(c *gin.Context) { c.Next() } - } - - return func(c *gin.Context) { - xApiKey := c.GetHeader("x-api-key") - - var bearerKey string - var basicKey string - if auth := c.GetHeader("Authorization"); auth != "" { - if strings.HasPrefix(auth, "Bearer ") { - bearerKey = strings.TrimPrefix(auth, "Bearer ") - } else if strings.HasPrefix(auth, "Basic ") { - // Basic Auth: base64(username:password), password is the API key - encoded := strings.TrimPrefix(auth, "Basic ") - if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil { - parts := strings.SplitN(string(decoded), ":", 2) - if len(parts) == 2 { - basicKey = parts[1] // password is the API key - } - } - } - } - - // Use first key found: Basic, then Bearer, then x-api-key - var providedKey string - if basicKey != "" { - providedKey = basicKey - } else if bearerKey != "" { - providedKey = bearerKey - } else { - providedKey = xApiKey - } - - // Validate key - valid := false - for _, key := range pm.config.RequiredAPIKeys { - if providedKey == key { - valid = true - break - } - } - - if !valid { - c.Header("WWW-Authenticate", `Basic realm="llama-swap"`) - pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key") - c.Abort() - return - } - - // Strip auth headers to prevent leakage to upstream - c.Request.Header.Del("Authorization") - c.Request.Header.Del("x-api-key") - - c.Next() - } -} - -func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { - pm.StopProcesses(StopImmediately) - c.String(http.StatusOK, "OK") -} - -func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) { - context.Header("Content-Type", "application/json") - runningProcesses := make([]gin.H, 0) // Default to an empty response. - - if pm.matrix != nil { - for _, modelID := range pm.matrix.RunningModels() { - if process, ok := pm.matrix.GetProcess(modelID); ok { - runningProcesses = append(runningProcesses, gin.H{ - "model": process.ID, - "state": process.CurrentState(), - "cmd": process.config.Cmd, - "proxy": process.config.Proxy, - "ttl": process.config.UnloadAfter, - "name": process.config.Name, - "description": process.config.Description, - }) - } - } - } else { - for _, processGroup := range pm.processGroups { - for _, process := range processGroup.processes { - if process.CurrentState() == StateReady { - runningProcesses = append(runningProcesses, gin.H{ - "model": process.ID, - "state": process.CurrentState(), - "cmd": process.config.Cmd, - "proxy": process.config.Proxy, - "ttl": process.config.UnloadAfter, - "name": process.config.Name, - "description": process.config.Description, - }) - } - } - } - } - - // Put the results under the `running` key. - response := gin.H{ - "running": runningProcesses, - } - - context.JSON(http.StatusOK, response) // Always return 200 OK -} - -func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup { - for _, group := range pm.processGroups { - if group.HasMember(modelName) { - return group - } - } - return nil -} - -func (pm *ProxyManager) SetVersion(buildDate string, commit string, version string) { - pm.Lock() - defer pm.Unlock() - pm.buildDate = buildDate - pm.commit = commit - pm.version = version -} - -func (pm *ProxyManager) SetPerfMonitor(m *perf.Monitor) { - pm.Lock() - defer pm.Unlock() - pm.perfMonitor = m -} diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go deleted file mode 100644 index b3f8437..0000000 --- a/proxy/proxymanager_api.go +++ /dev/null @@ -1,358 +0,0 @@ -package proxy - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "sort" - "strconv" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/internal/event" - "github.com/mostlygeek/llama-swap/internal/perf" -) - -type Model struct { - Id string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - State string `json:"state"` - Unlisted bool `json:"unlisted"` - PeerID string `json:"peerID"` - Aliases []string `json:"aliases,omitempty"` -} - -func addApiHandlers(pm *ProxyManager) { - // Add API endpoints for React to consume - // Protected with API key authentication - apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth()) - { - apiGroup.POST("/models/unload", pm.apiUnloadAllModels) - apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler) - apiGroup.GET("/events", pm.apiSendEvents) - apiGroup.GET("/metrics", pm.apiGetMetrics) - apiGroup.GET("/performance", pm.apiGetPerformance) - apiGroup.GET("/version", pm.apiGetVersion) - apiGroup.GET("/captures/:id", pm.apiGetCapture) - } -} - -func (pm *ProxyManager) apiUnloadAllModels(c *gin.Context) { - pm.StopProcesses(StopImmediately) - c.JSON(http.StatusOK, gin.H{"msg": "ok"}) -} - -func (pm *ProxyManager) getModelStatus() []Model { - // Extract keys and sort them - models := []Model{} - - modelIDs := make([]string, 0, len(pm.config.Models)) - for modelID := range pm.config.Models { - modelIDs = append(modelIDs, modelID) - } - sort.Strings(modelIDs) - - // Iterate over sorted keys - for _, modelID := range modelIDs { - // Get process state - state := "unknown" - var process *Process - if pm.matrix != nil { - process, _ = pm.matrix.GetProcess(modelID) - } else { - processGroup := pm.findGroupByModelName(modelID) - if processGroup != nil { - process = processGroup.processes[modelID] - } - } - if process != nil { - switch process.CurrentState() { - case StateReady: - state = "ready" - case StateStarting: - state = "starting" - case StateStopping: - state = "stopping" - case StateShutdown: - state = "shutdown" - case StateStopped: - state = "stopped" - } - } - models = append(models, Model{ - Id: modelID, - Name: pm.config.Models[modelID].Name, - Description: pm.config.Models[modelID].Description, - State: state, - Unlisted: pm.config.Models[modelID].Unlisted, - Aliases: pm.config.Models[modelID].Aliases, - }) - } - - // Iterate over the peer models - if pm.peerProxy != nil { - for peerID, peer := range pm.peerProxy.ListPeers() { - for _, modelID := range peer.Models { - models = append(models, Model{ - Id: modelID, - PeerID: peerID, - }) - } - } - } - - return models -} - -type messageType string - -const ( - msgTypeModelStatus messageType = "modelStatus" - msgTypeLogData messageType = "logData" - msgTypeMetrics messageType = "metrics" - msgTypeInFlight messageType = "inflight" -) - -type messageEnvelope struct { - Type messageType `json:"type"` - Data string `json:"data"` -} - -// sends a stream of different message types that happen on the server -func (pm *ProxyManager) apiSendEvents(c *gin.Context) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Content-Type-Options", "nosniff") - // prevent nginx from buffering SSE - c.Header("X-Accel-Buffering", "no") - - sendBuffer := make(chan messageEnvelope, 25) - ctx, cancel := context.WithCancel(c.Request.Context()) - sendModels := func() { - data, err := json.Marshal(pm.getModelStatus()) - if err == nil { - msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)} - select { - case sendBuffer <- msg: - case <-ctx.Done(): - return - default: - } - - } - } - - sendLogData := func(source string, data []byte) { - data, err := json.Marshal(gin.H{ - "source": source, - "data": string(data), - }) - if err == nil { - select { - case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}: - case <-ctx.Done(): - return - default: - } - } - } - - sendMetrics := func(metrics []ActivityLogEntry) { - jsonData, err := json.Marshal(metrics) - if err == nil { - select { - case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}: - case <-ctx.Done(): - return - default: - } - } - } - - sendInFlight := func(total int) { - jsonData, err := json.Marshal(gin.H{"total": total}) - if err == nil { - select { - case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}: - case <-ctx.Done(): - return - default: - } - } - } - - /** - * Send updated models list - */ - defer event.On(func(e ProcessStateChangeEvent) { - sendModels() - })() - defer event.On(func(e ConfigFileChangedEvent) { - sendModels() - })() - - /** - * Send Log data - */ - defer pm.proxyLogger.OnLogData(func(data []byte) { - sendLogData("proxy", data) - })() - defer pm.upstreamLogger.OnLogData(func(data []byte) { - sendLogData("upstream", data) - })() - - /** - * Send Metrics data - */ - defer event.On(func(e ActivityLogEvent) { - sendMetrics([]ActivityLogEntry{e.Metrics}) - })() - - /** - * Send in-flight request stats related to token stats "Waiting: N" count. - */ - defer event.On(func(e InFlightRequestsEvent) { - sendInFlight(e.Total) - })() - - // send initial batch of data - sendLogData("proxy", pm.proxyLogger.GetHistory()) - sendLogData("upstream", pm.upstreamLogger.GetHistory()) - sendModels() - sendMetrics(pm.metricsMonitor.getMetrics()) - sendInFlight(pm.inFlightCounter.Current()) - - for { - select { - case <-c.Request.Context().Done(): - cancel() - return - case <-pm.shutdownCtx.Done(): - cancel() - return - case msg := <-sendBuffer: - c.SSEvent("message", msg) - c.Writer.Flush() - } - } -} - -func (pm *ProxyManager) apiGetMetrics(c *gin.Context) { - jsonData, err := pm.metricsMonitor.getMetricsJSON() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"}) - return - } - c.Data(http.StatusOK, "application/json", jsonData) -} - -func (pm *ProxyManager) prometheusMetricsHandler(c *gin.Context) { - if pm.perfMonitor == nil { - c.String(http.StatusServiceUnavailable, "# performance monitor not available\n") - return - } - pm.perfMonitor.MetricsHandler().ServeHTTP(c.Writer, c.Request) -} - -func (pm *ProxyManager) apiGetPerformance(c *gin.Context) { - if pm.perfMonitor == nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "performance monitor not available"}) - return - } - - sysStats, gpuStats := pm.perfMonitor.Current() - - var after time.Time - if afterStr := c.Query("after"); afterStr != "" { - ts, err := time.Parse(time.RFC3339, afterStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid 'after' timestamp, use RFC3339 format"}) - return - } - after = ts - } - - if !after.IsZero() { - filtered := make([]perf.SysStat, 0, len(sysStats)) - for _, s := range sysStats { - if s.Timestamp.After(after) { - filtered = append(filtered, s) - } - } - sysStats = filtered - - filteredGpu := make([]perf.GpuStat, 0, len(gpuStats)) - for _, g := range gpuStats { - if g.Timestamp.After(after) { - filteredGpu = append(filteredGpu, g) - } - } - gpuStats = filteredGpu - } - - c.JSON(http.StatusOK, gin.H{ - "sys_stats": sysStats, - "gpu_stats": gpuStats, - }) -} - -func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) { - requestedModel := strings.TrimPrefix(c.Param("model"), "/") - realModelName, found := pm.config.RealModelName(requestedModel) - if !found { - pm.sendErrorResponse(c, http.StatusNotFound, "Model not found") - return - } - - var stopErr error - if pm.matrix != nil { - stopErr = pm.matrix.StopProcess(realModelName, StopImmediately) - } else { - processGroup := pm.findGroupByModelName(realModelName) - if processGroup == nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel)) - return - } - stopErr = processGroup.StopProcess(realModelName, StopImmediately) - } - - if stopErr != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", stopErr.Error())) - return - } - c.String(http.StatusOK, "OK") -} - -func (pm *ProxyManager) apiGetVersion(c *gin.Context) { - c.JSON(http.StatusOK, map[string]string{ - "version": pm.version, - "commit": pm.commit, - "build_date": pm.buildDate, - }) -} - -func (pm *ProxyManager) apiGetCapture(c *gin.Context) { - idStr := c.Param("id") - id, err := strconv.Atoi(idStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"}) - return - } - - capture := pm.metricsMonitor.getCaptureByID(id) - if capture == nil || (capture.ReqPath == "" && capture.ReqHeaders == nil && capture.ReqBody == nil && capture.RespHeaders == nil && capture.RespBody == nil) { - c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"}) - return - } - - jsonBytes, err := json.Marshal(capture) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal capture"}) - return - } - c.Data(http.StatusOK, "application/json", jsonBytes) -} diff --git a/proxy/proxymanager_loghandlers.go b/proxy/proxymanager_loghandlers.go deleted file mode 100644 index dc94d6d..0000000 --- a/proxy/proxymanager_loghandlers.go +++ /dev/null @@ -1,121 +0,0 @@ -package proxy - -import ( - "context" - "fmt" - "net/http" - "strings" - - "github.com/gin-gonic/gin" - "github.com/mostlygeek/llama-swap/internal/logmon" -) - -func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) { - accept := c.GetHeader("Accept") - if strings.Contains(accept, "text/html") { - c.Redirect(http.StatusFound, "/ui/") - } else { - c.Header("Content-Type", "text/plain") - history := pm.muxLogger.GetHistory() - _, err := c.Writer.Write(history) - if err != nil { - c.AbortWithError(http.StatusInternalServerError, err) - return - } - } -} - -func (pm *ProxyManager) streamLogsHandler(c *gin.Context) { - c.Header("Content-Type", "text/plain") - c.Header("Transfer-Encoding", "chunked") - c.Header("X-Content-Type-Options", "nosniff") - // prevent nginx from buffering streamed logs - c.Header("X-Accel-Buffering", "no") - - logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/") - - // Handle case where query string might be included in the parameter - // (can happen with catch-all routes on some versions/setups) - if idx := strings.Index(logMonitorId, "?"); idx != -1 { - logMonitorId = logMonitorId[:idx] - } - - logger, err := pm.getLogger(logMonitorId) - if err != nil { - c.String(http.StatusBadRequest, err.Error()) - return - } - - flusher, ok := c.Writer.(http.Flusher) - if !ok { - c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported")) - return - } - - _, skipHistory := c.GetQuery("no-history") - // Send history first if not skipped - - if !skipHistory { - history := logger.GetHistory() - if len(history) != 0 { - c.Writer.Write(history) - flusher.Flush() - } - } - - sendChan := make(chan []byte, 10) - ctx, cancel := context.WithCancel(c.Request.Context()) - defer logger.OnLogData(func(data []byte) { - select { - case sendChan <- data: - case <-ctx.Done(): - return - default: - } - })() - - for { - select { - case <-c.Request.Context().Done(): - cancel() - return - case <-pm.shutdownCtx.Done(): - cancel() - return - case data := <-sendChan: - c.Writer.Write(data) - flusher.Flush() - } - } -} - -// getLogger searches for the appropriate logger based on the logMonitorId -func (pm *ProxyManager) getLogger(logMonitorId string) (*logmon.Monitor, error) { - switch logMonitorId { - case "": - // maintain the default - return pm.muxLogger, nil - case "proxy": - return pm.proxyLogger, nil - case "upstream": - return pm.upstreamLogger, nil - default: - // search for a models specific logger using findModelInPath - // to handle model names with slashes (e.g., "author/model") - if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found { - for _, group := range pm.processGroups { - if process, found := group.GetMember(name); found { - return process.Logger(), nil - } - } - // also check the matrix when processGroups doesn't contain the model - if pm.matrix != nil { - if process, found := pm.matrix.GetProcess(name); found { - return process.Logger(), nil - } - } - } - - return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID") - } -} diff --git a/proxy/proxymanager_loghandlers_test.go b/proxy/proxymanager_loghandlers_test.go deleted file mode 100644 index 4e3af50..0000000 --- a/proxy/proxymanager_loghandlers_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package proxy - -import ( - "context" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestLogMonitorIdQueryParameterStripping(t *testing.T) { - tests := []struct { - name string - input string - expected string - }{ - { - name: "upstream without query param", - input: "upstream", - expected: "upstream", - }, - { - name: "upstream with query param", - input: "upstream?no-history", - expected: "upstream", - }, - { - name: "proxy with multiple query params", - input: "proxy?no-history&foo=bar", - expected: "proxy", - }, - { - name: "model with slash and query param", - input: "author/model?no-history", - expected: "author/model", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Simulate the query parameter stripping logic - logMonitorId := tt.input - if idx := strings.Index(logMonitorId, "?"); idx != -1 { - logMonitorId = logMonitorId[:idx] - } - - if logMonitorId != tt.expected { - t.Errorf("Query parameter stripping failed: got %q, want %q", logMonitorId, tt.expected) - } - }) - } -} - -// TestProxyManager_GetLogger_ProcessGroups verifies getLogger resolves the -// well-known "proxy"/"upstream" loggers and a model ID managed by processGroups. -func TestProxyManager_GetLogger_ProcessGroups(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - pm := New(cfg) - defer pm.StopProcesses(StopImmediately) - - tests := []struct { - id string - wantErr bool - }{ - {"proxy", false}, - {"upstream", false}, - {"model1", false}, - {"does-not-exist", true}, - } - - for _, tt := range tests { - t.Run(tt.id, func(t *testing.T) { - logger, err := pm.getLogger(tt.id) - if tt.wantErr { - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid logger") - } else { - require.NoError(t, err) - assert.NotNil(t, logger) - } - }) - } -} - -// TestProxyManager_GetLogger_Matrix verifies that getLogger can resolve a model -// ID when the proxy is configured with a swap matrix (pm.processGroups is empty -// for matrix-managed models). -func TestProxyManager_GetLogger_Matrix(t *testing.T) { - cfg := config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - "model2": getTestSimpleResponderConfig("model2"), - }, - ExpandedSets: []config.ExpandedSet{ - {SetName: "s1", Models: []string{"model1", "model2"}}, - }, - Matrix: &config.MatrixConfig{}, - } - - pm := New(cfg) - defer pm.StopProcesses(StopImmediately) - - tests := []struct { - id string - wantErr bool - }{ - {"proxy", false}, - {"upstream", false}, - {"model1", false}, - {"model2", false}, - {"does-not-exist", true}, - } - - for _, tt := range tests { - t.Run(tt.id, func(t *testing.T) { - logger, err := pm.getLogger(tt.id) - if tt.wantErr { - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid logger") - } else { - require.NoError(t, err) - assert.NotNil(t, logger) - } - }) - } -} - -// TestProxyManager_StreamLogs_Matrix verifies that /logs/stream/ -// returns 200 (not 400) for a model managed by the swap matrix. -func TestProxyManager_StreamLogs_Matrix(t *testing.T) { - cfg := config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "matrix-model": getTestSimpleResponderConfig("matrix-model"), - }, - ExpandedSets: []config.ExpandedSet{ - {SetName: "s1", Models: []string{"matrix-model"}}, - }, - Matrix: &config.MatrixConfig{}, - } - - pm := New(cfg) - defer pm.StopProcesses(StopImmediately) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - req := httptest.NewRequest("GET", "/logs/stream/matrix-model", nil) - req = req.WithContext(ctx) - rec := CreateTestResponseRecorder() - - done := make(chan struct{}) - go func() { - defer close(done) - pm.ServeHTTP(rec, req) - }() - - <-ctx.Done() - <-done - - assert.Equal(t, 200, rec.Code) -} diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go deleted file mode 100644 index f637eba..0000000 --- a/proxy/proxymanager_test.go +++ /dev/null @@ -1,1881 +0,0 @@ -package proxy - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "fmt" - "math/rand" - "mime/multipart" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/event" - "github.com/stretchr/testify/assert" - "github.com/tidwall/gjson" -) - -// TestResponseRecorder adds CloseNotify to httptest.ResponseRecorder. -// "If you want to write your own tests around streams you will need a Recorder that can handle CloseNotifier." -// The tests can panic otherwise: -// panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify -// See: https://github.com/gin-gonic/gin/issues/1815 -// TestResponseRecorder is taken from gin's own tests: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765 -type TestResponseRecorder struct { - *httptest.ResponseRecorder - closeChannel chan bool -} - -func (r *TestResponseRecorder) CloseNotify() <-chan bool { - return r.closeChannel -} - -func CreateTestResponseRecorder() *TestResponseRecorder { - return &TestResponseRecorder{ - httptest.NewRecorder(), - make(chan bool, 1), - } -} - -func TestProxyManager_SwapProcessCorrectly(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - model2: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - for _, modelName := range []string{"model1", "model2"} { - reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), modelName) - } -} -func TestProxyManager_SwapMultiProcess(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - model2: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 -groups: - G1: - swap: true - exclusive: false - members: - - model1 - G2: - swap: true - exclusive: false - members: - - model2 -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - tests := []string{"model1", "model2"} - for _, requestedModel := range tests { - t.Run(requestedModel, func(t *testing.T) { - reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), requestedModel) - }) - } - - // make sure there's two loaded models - assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady) - assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady) -} - -// Test that a persistent group is not affected by the swapping behaviour of -// other groups. -func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - model2: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 -groups: - forever: - swap: true - exclusive: false - persistent: true - members: - - model2 -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - // make requests to load all models, loading model1 should not affect model2 - tests := []string{"model2", "model1"} - for _, requestedModel := range tests { - reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), requestedModel) - } - - assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady) - assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady) -} - -// When a request for a different model comes in ProxyManager should wait until -// the first request is complete before swapping. Both requests should complete -func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { - if testing.Short() { - t.Skip("skipping slow test") - } - - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - model2: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 - model3: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model3 -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - results := map[string]string{} - - var wg sync.WaitGroup - var mu sync.Mutex - - for key := range cfg.Models { - wg.Add(1) - go func(key string) { - defer wg.Done() - - reqBody := fmt.Sprintf(`{"model":"%s"}`, key) - req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Errorf("Expected status OK, got %d for key %s", w.Code, key) - } - - mu.Lock() - var response map[string]interface{} - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) - result, ok := response["responseMessage"].(string) - assert.Equal(t, ok, true) - results[key] = result - mu.Unlock() - }(key) - - <-time.After(time.Millisecond) - } - - wg.Wait() - assert.Len(t, results, len(cfg.Models)) - - for key, result := range results { - assert.Equal(t, key, result) - } -} - -func TestProxyManager_ListModelsHandler(t *testing.T) { - - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - name: "Model 1" - description: "Model 1 description is used for testing" - model2: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 - name: " " - description: " " - model3: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model3 -peers: - peer1: - proxy: http://peer1:8080 - models: - - peer-model-a - - peer-model-b -`) - - proxy := New(cfg) - - // Create a test request - req := httptest.NewRequest("GET", "/v1/models", nil) - req.Header.Add("Origin", "i-am-the-origin") - w := CreateTestResponseRecorder() - - // Call the listModelsHandler - proxy.ServeHTTP(w, req) - - // Check the response status code - assert.Equal(t, http.StatusOK, w.Code) - - // Check for Access-Control-Allow-Origin - assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin")) - - // Parse the JSON response - var response struct { - Data []map[string]interface{} `json:"data"` - } - - if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { - t.Fatalf("Failed to parse JSON response: %v", err) - } - - // Check the number of models returned (3 local + 2 peer models) - assert.Len(t, response.Data, 5) - - // Check the details of each model - expectedModels := map[string]struct{}{ - "model1": {}, - "model2": {}, - "model3": {}, - "peer-model-a": {}, - "peer-model-b": {}, - } - - // make all models - for _, model := range response.Data { - modelID, ok := model["id"].(string) - assert.True(t, ok, "model ID should be a string") - _, exists := expectedModels[modelID] - assert.True(t, exists, "unexpected model ID: %s", modelID) - delete(expectedModels, modelID) - - object, ok := model["object"].(string) - assert.True(t, ok, "object should be a string") - assert.Equal(t, "model", object) - - created, ok := model["created"].(float64) - assert.True(t, ok, "created should be a number") - assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive - - ownedBy, ok := model["owned_by"].(string) - assert.True(t, ok, "owned_by should be a string") - assert.Equal(t, "llama-swap", ownedBy) - - // check for optional name and description - if modelID == "model1" { - name, ok := model["name"].(string) - assert.True(t, ok, "name should be a string") - assert.Equal(t, "Model 1", name) - description, ok := model["description"].(string) - assert.True(t, ok, "description should be a string") - assert.Equal(t, "Model 1 description is used for testing", description) - } else if modelID == "peer-model-a" || modelID == "peer-model-b" { - // Peer models should have meta.llamaswap.peerID - meta, exists := model["meta"] - assert.True(t, exists, "peer model should have meta field") - metaMap, ok := meta.(map[string]interface{}) - assert.True(t, ok, "meta should be a map") - llamaswap, exists := metaMap["llamaswap"] - assert.True(t, exists, "meta should have llamaswap field") - llamaswapMap, ok := llamaswap.(map[string]interface{}) - assert.True(t, ok, "llamaswap should be a map") - peerID, exists := llamaswapMap["peerID"] - assert.True(t, exists, "llamaswap should have peerID field") - assert.Equal(t, "peer1", peerID) - } else { - _, exists := model["name"] - assert.False(t, exists, "unexpected name field for model: %s", modelID) - _, exists = model["description"] - assert.False(t, exists, "unexpected description field for model: %s", modelID) - } - } - - // Ensure all expected models were returned - assert.Empty(t, expectedModels, "not all expected models were returned") -} - -func TestProxyManager_ListModelsHandler_WithMetadata(t *testing.T) { - // Process config through LoadConfigFromReader to apply macro substitution - configYaml := ` -healthCheckTimeout: 15 -logLevel: error -startPort: 10000 -models: - model1: - cmd: /path/to/server -p ${PORT} - macros: - PORT_NUM: 10001 - TEMP: 0.7 - NAME: "llama" - metadata: - port: ${PORT_NUM} - temperature: ${TEMP} - enabled: true - note: "Running on port ${PORT_NUM}" - nested: - value: ${TEMP} - model2: - cmd: /path/to/server -p ${PORT} -` - processedConfig, err := config.LoadConfigFromReader(strings.NewReader(configYaml)) - assert.NoError(t, err) - - proxy := New(processedConfig) - - req := httptest.NewRequest("GET", "/v1/models", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - - var response struct { - Data []map[string]any `json:"data"` - } - - err = json.Unmarshal(w.Body.Bytes(), &response) - assert.NoError(t, err) - assert.Len(t, response.Data, 2) - - // Find model1 and model2 in response - var model1Data, model2Data map[string]any - for _, model := range response.Data { - if model["id"] == "model1" { - model1Data = model - } else if model["id"] == "model2" { - model2Data = model - } - } - - // Verify model1 has llamaswap_meta - assert.NotNil(t, model1Data) - meta, exists := model1Data["meta"] - if !assert.True(t, exists, "model1 should have meta key") { - t.FailNow() - } - - metaMap := meta.(map[string]any) - - lsmeta, exists := metaMap["llamaswap"] - if !assert.True(t, exists, "model1 should have meta.llamaswap key") { - t.FailNow() - } - - lsmetamap := lsmeta.(map[string]any) - - // Verify type preservation - assert.Equal(t, float64(10001), lsmetamap["port"]) // JSON numbers are float64 - assert.Equal(t, 0.7, lsmetamap["temperature"]) - assert.Equal(t, true, lsmetamap["enabled"]) - // Verify string interpolation - assert.Equal(t, "Running on port 10001", lsmetamap["note"]) - // Verify nested structure - nested := lsmetamap["nested"].(map[string]any) - assert.Equal(t, 0.7, nested["value"]) - - // Verify model2 does NOT have llamaswap_meta - assert.NotNil(t, model2Data) - _, exists = model2Data["llamaswap_meta"] - assert.False(t, exists, "model2 should not have llamaswap_meta") -} - -func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) { - // Intentionally add models in non-sorted order and with an unlisted model - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - zeta: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond zeta - alpha: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond alpha - beta: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond beta - hidden: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond hidden - unlisted: true -`) - - proxy := New(cfg) - - // Request models list - req := httptest.NewRequest("GET", "/v1/models", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - - var response struct { - Data []map[string]interface{} `json:"data"` - } - if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { - t.Fatalf("Failed to parse JSON response: %v", err) - } - - // We expect only the listed models in sorted order by id - expectedOrder := []string{"alpha", "beta", "zeta"} - if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") { - got := make([]string, 0, len(response.Data)) - for _, m := range response.Data { - id, _ := m["id"].(string) - got = append(got, id) - } - assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending") - } -} - -func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) { - // Configure alias - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -includeAliasesInList: true -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - name: "Model 1" - aliases: - - alias1 -`) - - proxy := New(cfg) - - // Request models list - req := httptest.NewRequest("GET", "/v1/models", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - - var response struct { - Data []map[string]interface{} `json:"data"` - } - if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { - t.Fatalf("Failed to parse JSON response: %v", err) - } - - // We expect both base id and alias - var model1Data, alias1Data map[string]any - for _, model := range response.Data { - if model["id"] == "model1" { - model1Data = model - } else if model["id"] == "alias1" { - alias1Data = model - } - } - - // Verify model1 has name - assert.NotNil(t, model1Data) - _, exists := model1Data["name"] - if !assert.True(t, exists, "model1 should have name key") { - t.FailNow() - } - name1, ok := model1Data["name"].(string) - assert.True(t, ok, "name1 should be a string") - - // Verify alias1 has name - assert.NotNil(t, alias1Data) - _, exists = alias1Data["name"] - if !assert.True(t, exists, "alias1 should have name key") { - t.FailNow() - } - name2, ok := alias1Data["name"].(string) - assert.True(t, ok, "name2 should be a string") - - // Name keys should match - assert.Equal(t, name1, name2) -} - -func TestProxyManager_Shutdown(t *testing.T) { - if testing.Short() { - t.Skip("skipping slow test") - } - - // make broken model configurations - model1Config := getTestSimpleResponderConfigPort("model1", 9991) - model1Config.Proxy = "http://localhost:10001/" - - model2Config := getTestSimpleResponderConfigPort("model2", 9992) - model2Config.Proxy = "http://localhost:10002/" - - model3Config := getTestSimpleResponderConfigPort("model3", 9993) - model3Config.Proxy = "http://localhost:10003/" - - cfg := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": model1Config, - "model2": model2Config, - "model3": model3Config, - }, - LogLevel: "error", - Groups: map[string]config.GroupConfig{ - "test": { - Swap: false, - Members: []string{"model1", "model2", "model3"}, - }, - }, - }) - - proxy := New(cfg) - - // Start all the processes - var wg sync.WaitGroup - for _, modelName := range []string{"model1", "model2", "model3"} { - wg.Add(1) - go func(modelName string) { - defer wg.Done() - reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - // send a request to trigger the proxy to load ... this should hang waiting for start up - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadGateway, w.Code) - assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") - }(modelName) - } - - go func() { - <-time.After(time.Second) - proxy.Shutdown() - }() - wg.Wait() -} - -func TestProxyManager_Unload(t *testing.T) { - conf := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - proxy := New(conf) - reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - - assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) - req = httptest.NewRequest("GET", "/unload", nil) - w = CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, w.Body.String(), "OK") - - select { - case <-proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan: - // good - case <-time.After(2 * time.Second): - t.Fatal("timeout waiting for model1 to stop") - } - assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) -} - -func TestProxyManager_UnloadSingleModel(t *testing.T) { - const testGroupId = "testGroup" - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - model2: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 -groups: - testGroup: - swap: false - members: - - model1 - - model2 -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopImmediately) - - // start both model - for _, modelName := range []string{"model1", "model2"} { - reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - } - - assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState()) - assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState()) - - req := httptest.NewRequest("POST", "/api/models/unload/model1", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - if !assert.Equal(t, w.Body.String(), "OK") { - t.FailNow() - } - - select { - case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan: - // good - case <-time.After(2 * time.Second): - t.Fatal("timeout waiting for model1 to stop") - } - - assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped) - assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady) -} - -// Test issue #61 `Listing the current list of models and the loaded model.` -func TestProxyManager_RunningEndpoint(t *testing.T) { - // Shared configuration - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: warn -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - model2: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 -`) - - // Define a helper struct to parse the JSON response. - type RunningResponse struct { - Running []struct { - Model string `json:"model"` - State string `json:"state"` - Cmd string `json:"cmd"` - Proxy string `json:"proxy"` - TTL int `json:"ttl"` - Name string `json:"name"` - Description string `json:"description"` - } `json:"running"` - } - - // Create proxy once for all tests - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - t.Run("no models loaded", func(t *testing.T) { - req := httptest.NewRequest("GET", "/running", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - - var response RunningResponse - - // Check if this is a valid JSON object. - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) - - // We should have an empty running array here. - assert.Empty(t, response.Running, "expected no running models") - }) - - t.Run("single model loaded", func(t *testing.T) { - // Load just a model. - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - - // Simulate browser call for the `/running` endpoint. - req = httptest.NewRequest("GET", "/running", nil) - w = CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - - var response RunningResponse - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) - - // Check if we have a single array element. - assert.Len(t, response.Running, 1) - - // Is this the right model? - assert.Equal(t, "model1", response.Running[0].Model) - - // Is the model loaded? - assert.Equal(t, "ready", response.Running[0].State) - - // Verify extended fields are present - assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated") - assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated") - assert.Equal(t, 0, response.Running[0].TTL, "ttl should default to globalTTL (0)") - }) -} - -func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - TheExpectedModel: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - // Create a buffer with multipart form data - var b bytes.Buffer - w := multipart.NewWriter(&b) - - // Add the model field - fw, err := w.CreateFormField("model") - assert.NoError(t, err) - _, err = fw.Write([]byte("TheExpectedModel")) - assert.NoError(t, err) - - // Add a file field - fw, err = w.CreateFormFile("file", "test.mp3") - assert.NoError(t, err) - // Generate random content length between 10 and 20 - contentLength := rand.Intn(11) + 10 // 10 to 20 - content := make([]byte, contentLength) - _, err = fw.Write(content) - assert.NoError(t, err) - w.Close() - - // Create the request with the multipart form data - req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) - req.Header.Set("Content-Type", w.FormDataContentType()) - rec := CreateTestResponseRecorder() - proxy.ServeHTTP(rec, req) - - // Verify the response - assert.Equal(t, http.StatusOK, rec.Code) - var response map[string]string - err = json.Unmarshal(rec.Body.Bytes(), &response) - assert.NoError(t, err) - assert.Equal(t, "TheExpectedModel", response["model"]) - assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder - assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"]) -} - -// Test useModelName in configuration sends overrides what is sent to upstream -func TestProxyManager_UseModelName(t *testing.T) { - upstreamModelName := "upstreamModel" - - conf := testConfigFromYAML(t, fmt.Sprintf(` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond %s - useModelName: %s -`, upstreamModelName, upstreamModelName)) - - proxy := New(conf) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - requestedModel := "model1" - - t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) { - reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), upstreamModelName) - - // make sure the content length was set correctly - // simple-responder will return the content length it got in the response - body := w.Body.Bytes() - contentLength := int(gjson.GetBytes(body, "h_content_length").Int()) - assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength) - }) - - t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) { - // Create a buffer with multipart form data - var b bytes.Buffer - w := multipart.NewWriter(&b) - - // Add the model field - fw, err := w.CreateFormField("model") - assert.NoError(t, err) - _, err = fw.Write([]byte(requestedModel)) - assert.NoError(t, err) - - // Add a file field - fw, err = w.CreateFormFile("file", "test.mp3") - assert.NoError(t, err) - _, err = fw.Write([]byte("test")) - assert.NoError(t, err) - w.Close() - - // Create the request with the multipart form data - req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) - req.Header.Set("Content-Type", w.FormDataContentType()) - rec := CreateTestResponseRecorder() - proxy.ServeHTTP(rec, req) - - // Verify the response - assert.Equal(t, http.StatusOK, rec.Code) - var response map[string]string - err = json.Unmarshal(rec.Body.Bytes(), &response) - assert.NoError(t, err) - assert.Equal(t, upstreamModelName, response["model"]) - }) -} - -func TestProxyManager_AudioVoicesGETHandler(t *testing.T) { - conf := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - proxy := New(conf) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - t.Run("successful GET with model query param", func(t *testing.T) { - req := httptest.NewRequest("GET", "/v1/audio/voices?model=model1", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "voice1") - }) - - t.Run("missing model query param returns 400", func(t *testing.T) { - req := httptest.NewRequest("GET", "/v1/audio/voices", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "missing required 'model' query parameter") - }) - - t.Run("unknown model returns 400", func(t *testing.T) { - req := httptest.NewRequest("GET", "/v1/audio/voices?model=nonexistent", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "could not find suitable handler") - }) -} - -func TestProxyManager_CORSOptionsHandler(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - tests := []struct { - name string - method string - requestHeaders map[string]string - expectedStatus int - expectedHeaders map[string]string - }{ - { - name: "OPTIONS with no headers", - method: "OPTIONS", - expectedStatus: http.StatusNoContent, - expectedHeaders: map[string]string{ - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With", - }, - }, - { - name: "OPTIONS with specific headers", - method: "OPTIONS", - requestHeaders: map[string]string{ - "Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header", - }, - expectedStatus: http.StatusNoContent, - expectedHeaders: map[string]string{ - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS", - "Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header", - }, - }, - { - name: "Non-OPTIONS request", - method: "GET", - expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil) - for k, v := range tt.requestHeaders { - req.Header.Set(k, v) - } - - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - - assert.Equal(t, tt.expectedStatus, w.Code) - - for header, expectedValue := range tt.expectedHeaders { - assert.Equal(t, expectedValue, w.Header().Get(header)) - } - }) - } -} - -func TestProxyManager_Upstream(t *testing.T) { - cfg := testConfigFromYAML(t, ` -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - aliases: [model-alias] -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - t.Run("main model name", func(t *testing.T) { - req := httptest.NewRequest("GET", "/upstream/model1/test", nil) - rec := CreateTestResponseRecorder() - proxy.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "model1", rec.Body.String()) - }) - - t.Run("model alias", func(t *testing.T) { - req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil) - rec := CreateTestResponseRecorder() - proxy.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "model1", rec.Body.String()) - }) -} - -func TestProxyManager_ChatContentLength(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1") - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - var response map[string]interface{} - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) - assert.Equal(t, "81", response["h_content_length"]) - assert.Equal(t, "model1", response["responseMessage"]) -} - -func TestProxyManager_FiltersStripParams(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - filters: - stripParams: "temperature, model, stream" -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - var response map[string]interface{} - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) - - // `temperature` and `stream` are gone but model remains - assert.Equal(t, `{"model":"model1", "x_param":"123", "y_param":"abc"}`, response["request_body"]) - - // assert.Nil(t, response["temperature"]) - // assert.Equal(t, "123", response["x_param"]) - // assert.Equal(t, "abc", response["y_param"]) - // t.Logf("%v", response) -} - -func TestProxyManager_FiltersSetParamsByID(t *testing.T) { - // no explicit aliases — setParamsByID keys are auto-registered as aliases - cfg := testConfigFromYAML(t, ` -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - proxy: "http://127.0.0.1:${PORT}" - filters: - setParams: - reasoning_effort: medium - setParamsByID: - "${MODEL_ID}:high": - reasoning_effort: high - "${MODEL_ID}:low": - reasoning_effort: low -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - tests := []struct { - requestedModel string - wantEffort string - }{ - // setParams applies, no setParamsByID match - {requestedModel: "model1", wantEffort: "medium"}, - // setParamsByID overrides setParams - {requestedModel: "model1:high", wantEffort: "high"}, - {requestedModel: "model1:low", wantEffort: "low"}, - } - - for _, tt := range tests { - t.Run(tt.requestedModel, func(t *testing.T) { - reqBody := fmt.Sprintf(`{"model":%q}`, tt.requestedModel) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - - var response map[string]interface{} - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) - - requestBody, _ := response["request_body"].(string) - gotEffort := gjson.Get(requestBody, "reasoning_effort").String() - assert.Equal(t, tt.wantEffort, gotEffort, "reasoning_effort mismatch for model %s", tt.requestedModel) - }) - } -} - -func TestProxyManager_HealthEndpoint(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - req := httptest.NewRequest("GET", "/health", nil) - rec := CreateTestResponseRecorder() - proxy.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "OK", rec.Body.String()) -} - -// Ensure the custom llama-server /completion endpoint proxies correctly -func TestProxyManager_CompletionEndpoint(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "model1") -} - -func TestProxyManager_StartupHooks(t *testing.T) { - - cfg := testConfigFromYAML(t, ` -logLevel: error -hooks: - on_startup: - preload: - - model1 - - model2 -groups: - preloadTestGroup: - swap: false - members: - - model1 - - model2 -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - model2: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2 -`) - - preloadChan := make(chan ModelPreloadedEvent, 2) // buffer for 2 expected events - - unsub := event.On(func(e ModelPreloadedEvent) { - preloadChan <- e - }) - - defer unsub() - - // Create the proxy which should trigger preloading - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - for i := 0; i < 2; i++ { - select { - case <-preloadChan: - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for models to preload") - } - } - // make sure they are both loaded - _, foundGroup := proxy.processGroups["preloadTestGroup"] - if !assert.True(t, foundGroup, "preloadTestGroup should exist") { - return - } - assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model1"].CurrentState()) - assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model2"].CurrentState()) -} - -func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 - author/model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond author/model -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - endpoints := []string{ - "/api/events", - "/logs/stream", - "/logs/stream/proxy", - "/logs/stream/upstream", - "/logs/stream/author/model", - } - - for _, endpoint := range endpoints { - t.Run(endpoint, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - req := httptest.NewRequest("GET", endpoint, nil) - req = req.WithContext(ctx) - rec := CreateTestResponseRecorder() - - // Run handler in goroutine and wait for context timeout - done := make(chan struct{}) - go func() { - defer close(done) - proxy.ServeHTTP(rec, req) - }() - - // Wait for either the handler to complete or context to timeout - <-ctx.Done() - - // At this point, the handler has either finished or been cancelled - // Wait for the goroutine to fully exit before reading - <-done - - // Now it's safe to read from rec - no more concurrent writes - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering")) - }) - } -} - -func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - streaming-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond streaming-model -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - // Make a streaming request - reqBody := `{"model":"streaming-model"}` - // simple-responder will return text/event-stream when stream=true is in the query - req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody)) - rec := CreateTestResponseRecorder() - - proxy.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering")) - assert.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream") -} - -func TestProxyManager_ApiGetVersion(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - // Version test map - versionTest := map[string]string{ - "build_date": "1970-01-01T00:00:00Z", - "commit": "cc915ddb6f04a42d9cd1f524e1d46ec6ed069fdc", - "version": "v001", - } - - proxy := New(cfg) - proxy.SetVersion(versionTest["build_date"], versionTest["commit"], versionTest["version"]) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - req := httptest.NewRequest("GET", "/api/version", nil) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - - // Ensure json response - assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type")) - - // Check for attributes - response := map[string]string{} - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) - for key, value := range versionTest { - assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key]) - } -} - -func TestProxyManager_APIKeyAuth(t *testing.T) { - testConfig := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -apiKeys: - - valid-key-1 - - valid-key-2 -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - proxy := New(testConfig) - defer proxy.StopProcesses(StopImmediately) - injectTestHandlers(proxy, nil) - - t.Run("valid key in x-api-key header", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - req.Header.Set("x-api-key", "valid-key-1") - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - }) - - t.Run("valid key in Authorization Bearer header", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - req.Header.Set("Authorization", "Bearer valid-key-2") - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - }) - - t.Run("both headers with matching keys", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - req.Header.Set("x-api-key", "valid-key-1") - req.Header.Set("Authorization", "Bearer valid-key-1") - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - }) - - t.Run("invalid key returns 401", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - req.Header.Set("x-api-key", "invalid-key") - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusUnauthorized, w.Code) - assert.Contains(t, w.Body.String(), "unauthorized") - }) - - t.Run("missing key returns 401", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusUnauthorized, w.Code) - }) - - t.Run("valid key in Basic Auth header", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - // Basic Auth: base64("anyuser:valid-key-1") - credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1")) - req.Header.Set("Authorization", "Basic "+credentials) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - }) - - t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key")) - req.Header.Set("Authorization", "Basic "+credentials) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusUnauthorized, w.Code) - assert.Contains(t, w.Body.String(), "unauthorized") - }) - - t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - req.Header.Set("x-api-key", "valid-key-1") - credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1")) - req.Header.Set("Authorization", "Basic "+credentials) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - }) - - t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusUnauthorized, w.Code) - assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate")) - }) -} - -func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) { - // Config without RequiredAPIKeys - auth should be disabled - testConfig := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - proxy := New(testConfig) - defer proxy.StopProcesses(StopImmediately) - injectTestHandlers(proxy, nil) - - t.Run("requests pass without API key when not configured", func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - }) -} - -// TestProxyManager_PeerProxy_InferenceHandler tests the peerProxy integration -// in proxyInferenceHandler for issue #433 -func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) { - t.Run("requests to peer models are proxied", func(t *testing.T) { - // Create a test server to act as the peer - peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"response":"from-peer","model":"peer-model"}`)) - })) - defer peerServer.Close() - - testConfig := testConfigFromYAML(t, fmt.Sprintf(` -logLevel: error -peers: - test-peer: - proxy: %s - models: - - peer-model -models: - local-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model -`, peerServer.URL)) - - proxy := New(testConfig) - defer proxy.StopProcesses(StopImmediately) - injectTestHandlers(proxy, nil) - - reqBody := `{"model":"peer-model"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "from-peer") - }) - - t.Run("local models take precedence over peer models", func(t *testing.T) { - // Create a test server to act as the peer - should NOT be called - peerCalled := false - peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - peerCalled = true - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"response":"from-peer"}`)) - })) - defer peerServer.Close() - - testConfig := testConfigFromYAML(t, fmt.Sprintf(` -logLevel: error -peers: - test-peer: - proxy: %s - models: - - shared-model -models: - shared-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-response -`, peerServer.URL)) - - proxy := New(testConfig) - defer proxy.StopProcesses(StopImmediately) - injectTestHandlers(proxy, map[string]string{"shared-model": "local-response"}) - - reqBody := `{"model":"shared-model"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "local-response") - assert.False(t, peerCalled, "peer should not be called when local model exists") - }) - - t.Run("unknown model returns error", func(t *testing.T) { - // Create a test server to act as the peer - peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer peerServer.Close() - - testConfig := testConfigFromYAML(t, fmt.Sprintf(` -logLevel: error -peers: - test-peer: - proxy: %s - models: - - peer-model -models: - local-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model -`, peerServer.URL)) - - proxy := New(testConfig) - defer proxy.StopProcesses(StopImmediately) - injectTestHandlers(proxy, nil) - - reqBody := `{"model":"unknown-model"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "could not find suitable inference handler") - }) - - t.Run("peer API key is injected into request", func(t *testing.T) { - var receivedAuthHeader string - peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedAuthHeader = r.Header.Get("Authorization") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"response":"ok"}`)) - })) - defer peerServer.Close() - - testConfig := testConfigFromYAML(t, fmt.Sprintf(` -logLevel: error -peers: - test-peer: - proxy: %s - apiKey: secret-peer-key - models: - - peer-model -models: - local-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model -`, peerServer.URL)) - - proxy := New(testConfig) - defer proxy.StopProcesses(StopImmediately) - injectTestHandlers(proxy, nil) - - reqBody := `{"model":"peer-model"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "Bearer secret-peer-key", receivedAuthHeader) - }) - - t.Run("no peers configured - unknown model returns error", func(t *testing.T) { - testConfig := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - local-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model -`) - - proxy := New(testConfig) - defer proxy.StopProcesses(StopImmediately) - injectTestHandlers(proxy, nil) - - // peerProxy exists but has no peer models configured - assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model")) - - reqBody := `{"model":"unknown-model"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "could not find suitable inference handler") - }) - - t.Run("peer streaming response sets X-Accel-Buffering header", func(t *testing.T) { - peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte("data: test\n\n")) - })) - defer peerServer.Close() - - testConfig := testConfigFromYAML(t, fmt.Sprintf(` -logLevel: error -peers: - test-peer: - proxy: %s - models: - - peer-model -models: - local-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model -`, peerServer.URL)) - - proxy := New(testConfig) - defer proxy.StopProcesses(StopImmediately) - injectTestHandlers(proxy, nil) - - reqBody := `{"model":"peer-model"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering")) - }) -} - -func TestProxyManager_SdApiTxt2ImgRouting(t *testing.T) { - conf := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - sd-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond sd-model -`) - - proxy := New(conf) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - t.Run("successful txt2img with model", func(t *testing.T) { - reqBody := `{"model":"sd-model","prompt":"a cat"}` - req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "sd-model") - }) - - t.Run("successful img2img with model", func(t *testing.T) { - reqBody := `{"model":"sd-model","prompt":"a cat","init_images":[]}` - req := httptest.NewRequest("POST", "/sdapi/v1/img2img", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "sd-model") - }) - - t.Run("missing model returns 400", func(t *testing.T) { - reqBody := `{"prompt":"a cat"}` - req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "missing or invalid 'model' key") - }) -} - -func TestProxyManager_SdApiGetLoras(t *testing.T) { - conf := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - sd-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond sd-model -`) - - proxy := New(conf) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - t.Run("successful GET loras with model query param", func(t *testing.T) { - req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=sd-model", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - }) - - t.Run("missing model query param returns 400", func(t *testing.T) { - req := httptest.NewRequest("GET", "/sdapi/v1/loras", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "missing required 'model' query parameter") - }) - - t.Run("unknown model returns 400", func(t *testing.T) { - req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=nonexistent", nil) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "could not find suitable handler") - }) -} - -func TestProxyManager_AudioTranscriptionCapture(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -captureBuffer: 5 -models: - TheExpectedModel: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - var b bytes.Buffer - w := multipart.NewWriter(&b) - - fw, err := w.CreateFormField("model") - assert.NoError(t, err) - _, err = fw.Write([]byte("TheExpectedModel")) - assert.NoError(t, err) - - fw, err = w.CreateFormFile("file", "test.mp3") - assert.NoError(t, err) - _, err = fw.Write([]byte("test audio content")) - assert.NoError(t, err) - w.Close() - - req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) - req.Header.Set("Content-Type", w.FormDataContentType()) - req.Header.Set("Authorization", "Bearer mysecret") - req.Header.Set("X-Custom-Req", "req-value") - rec := CreateTestResponseRecorder() - proxy.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - - // Verify capture exists - metrics := proxy.metricsMonitor.getMetrics() - assert.Equal(t, 1, len(metrics)) - assert.True(t, metrics[0].HasCapture) - - capture := proxy.metricsMonitor.getCaptureByID(metrics[0].ID) - assert.NotNil(t, capture) - - // Should capture request headers (sensitive ones redacted) - assert.NotEmpty(t, capture.ReqHeaders) - assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"]) - assert.Equal(t, "req-value", capture.ReqHeaders["X-Custom-Req"]) - - // Should capture response headers - assert.NotNil(t, capture.RespHeaders) - - // Should NOT capture request bodies but get response bodies (text - assert.Nil(t, capture.ReqBody) - assert.NotNil(t, capture.RespBody) -} - -func TestProxyManager_VersionlessEndpoints_LocalModel(t *testing.T) { - cfg := testConfigFromYAML(t, ` -healthCheckTimeout: 15 -logLevel: error -models: - model1: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1 -`) - - proxy := New(cfg) - defer proxy.StopProcesses(StopWaitForInflightRequest) - injectTestHandlers(proxy, nil) - - endpoints := []string{ - "/v/chat/completions", - "/v/responses", - "/v/completions", - "/v/embeddings", - "/v/rerank", - "/v/reranking", - } - - for _, endpoint := range endpoints { - t.Run(endpoint, func(t *testing.T) { - reqBody := `{"model":"model1"}` - req := httptest.NewRequest("POST", endpoint, bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "model1") - }) - } - - t.Run("/v/messages", func(t *testing.T) { - reqBody := `{"model":"model1","messages":[{"role":"user","content":"hi"}]}` - req := httptest.NewRequest("POST", "/v/messages", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "model1") - }) -} - -func TestProxyManager_VersionlessEndpoints_PeerModel(t *testing.T) { - peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"endpoint":"%s","model":"peer-model"}`, r.URL.Path) - })) - defer peerServer.Close() - - cfg := testConfigFromYAML(t, fmt.Sprintf(` -healthCheckTimeout: 15 -logLevel: error -peers: - test-peer: - proxy: %s - models: - - peer-model -models: - local-model: - cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model -`, peerServer.URL)) - - proxy := New(cfg) - defer proxy.StopProcesses(StopImmediately) - - endpoints := []struct { - path string - wantSuffix string - }{ - {"/v/chat/completions", "/chat/completions"}, - {"/v/responses", "/responses"}, - {"/v/completions", "/completions"}, - {"/v/embeddings", "/embeddings"}, - {"/v/rerank", "/rerank"}, - {"/v/reranking", "/reranking"}, - } - - for _, ep := range endpoints { - t.Run(ep.path, func(t *testing.T) { - reqBody := `{"model":"peer-model"}` - req := httptest.NewRequest("POST", ep.path, bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), ep.wantSuffix) - }) - } - - t.Run("/v/messages", func(t *testing.T) { - reqBody := `{"model":"peer-model","messages":[{"role":"user","content":"hi"}]}` - req := httptest.NewRequest("POST", "/v/messages", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "/messages") - }) -} diff --git a/proxy/sanitize_cors.go b/proxy/sanitize_cors.go deleted file mode 100644 index 70873fa..0000000 --- a/proxy/sanitize_cors.go +++ /dev/null @@ -1,43 +0,0 @@ -package proxy - -import ( - "strings" -) - -func isTokenChar(r rune) bool { - switch { - case r >= 'a' && r <= 'z': - case r >= 'A' && r <= 'Z': - case r >= '0' && r <= '9': - case strings.ContainsRune("!#$%&'*+-.^_`|~", r): - default: - return false - } - return true -} - -func SanitizeAccessControlRequestHeaderValues(headerValues string) string { - parts := strings.Split(headerValues, ",") - valid := make([]string, 0, len(parts)) - - for _, p := range parts { - v := strings.TrimSpace(p) - if v == "" { - continue - } - - validPart := true - for _, c := range v { - if !isTokenChar(c) { - validPart = false - break - } - } - - if validPart { - valid = append(valid, v) - } - } - - return strings.Join(valid, ", ") -} diff --git a/proxy/sanitize_cors_test.go b/proxy/sanitize_cors_test.go deleted file mode 100644 index ad11fcc..0000000 --- a/proxy/sanitize_cors_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package proxy - -import "testing" - -func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) { - tests := []struct { - name string - input string - expected string - }{ - { - name: "empty string", - input: "", - expected: "", - }, - { - name: "whitespace only", - input: " ", - expected: "", - }, - { - name: "single valid value", - input: "content-type", - expected: "content-type", - }, - { - name: "multiple valid values", - input: "content-type, authorization, x-requested-with", - expected: "content-type, authorization, x-requested-with", - }, - { - name: "values with extra spaces", - input: " content-type , authorization ", - expected: "content-type, authorization", - }, - { - name: "values with tabs", - input: "content-type,\tauthorization", - expected: "content-type, authorization", - }, - { - name: "values with invalid characters", - input: "content-type, auth\n, x-requested-with\r", - expected: "content-type, auth, x-requested-with", - }, - { - name: "empty values in list", - input: "content-type,,authorization", - expected: "content-type, authorization", - }, - { - name: "leading and trailing commas", - input: ",content-type,authorization,", - expected: "content-type, authorization", - }, - { - name: "mixed valid and invalid values", - input: "content-type, \x00invalid, x-requested-with", - expected: "content-type, x-requested-with", - }, - { - name: "mixed case values", - input: "Content-Type, my-Valid-Header, Another-hEader", - expected: "Content-Type, my-Valid-Header, Another-hEader", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := SanitizeAccessControlRequestHeaderValues(tt.input) - if got != tt.expected { - t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q", - tt.input, got, tt.expected) - } - }) - } -} diff --git a/proxy/ui_compress.go b/proxy/ui_compress.go deleted file mode 100644 index 43b3687..0000000 --- a/proxy/ui_compress.go +++ /dev/null @@ -1,81 +0,0 @@ -package proxy - -import ( - "net/http" - "strings" -) - -// selectEncoding chooses the best encoding based on Accept-Encoding header -// Returns the encoding ("br", "gzip", or "") and the corresponding file extension -func selectEncoding(acceptEncoding string) (encoding, ext string) { - if acceptEncoding == "" { - return "", "" - } - - for _, part := range strings.Split(acceptEncoding, ",") { - enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) - if enc == "br" { - return "br", ".br" - } - } - - for _, part := range strings.Split(acceptEncoding, ",") { - enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) - if enc == "gzip" { - return "gzip", ".gz" - } - } - - return "", "" -} - -// ServeCompressedFile serves a file with compression support. -// It checks for pre-compressed versions and serves them with proper headers. -func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) { - encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding")) - - // Try to serve compressed version if client supports it - if encoding != "" { - if cf, err := fs.Open(name + ext); err == nil { - defer cf.Close() - - // Verify it's a regular file (not a directory) - if stat, err := cf.Stat(); err == nil && !stat.IsDir() { - // Set the content encoding header - w.Header().Set("Content-Encoding", encoding) - w.Header().Add("Vary", "Accept-Encoding") - - // Get original file info for content type detection - origFile, err := fs.Open(name) - if err == nil { - origFile.Close() - } - - // Serve the compressed file - http.ServeContent(w, r, name, stat.ModTime(), cf) - return - } - } - } - - // Fall back to serving the uncompressed file - file, err := fs.Open(name) - if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - defer file.Close() - - stat, err := file.Stat() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - if stat.IsDir() { - http.Error(w, "is a directory", http.StatusForbidden) - return - } - - http.ServeContent(w, r, name, stat.ModTime(), file) -} diff --git a/proxy/ui_compress_test.go b/proxy/ui_compress_test.go deleted file mode 100644 index 2744540..0000000 --- a/proxy/ui_compress_test.go +++ /dev/null @@ -1,283 +0,0 @@ -package proxy - -import ( - "bytes" - "compress/gzip" - "io" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "strings" - "testing" - "testing/fstest" - "time" -) - -func TestServeCompressedFile_Brotli(t *testing.T) { - // Create test content - content := []byte("This is test content that should be compressed with brotli") - brContent := []byte("fake-brotli-compressed-data") - - // Create a test filesystem - mapFS := fstest.MapFS{ - "test.js": {Data: content, ModTime: time.Now()}, - "test.js.br": {Data: brContent, ModTime: time.Now()}, - "test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()}, - } - fs := http.FS(mapFS) - - req := httptest.NewRequest(http.MethodGet, "/test.js", nil) - req.Header.Set("Accept-Encoding", "br, gzip") - w := httptest.NewRecorder() - - ServeCompressedFile(fs, w, req, "test.js") - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("Expected status 200, got %d", resp.StatusCode) - } - - // Check that brotli is used (preferred over gzip) - if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" { - t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding) - } - - if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" { - t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary) - } - - if !bytes.Equal(body, brContent) { - t.Errorf("Expected brotli content, got %s", string(body)) - } -} - -func TestServeCompressedFile_Gzip(t *testing.T) { - // Create test content - content := []byte("This is test content that should be compressed with gzip") - gzContent := []byte("fake-gzip-compressed-data") - - // Create a test filesystem without brotli - mapFS := fstest.MapFS{ - "test.js": {Data: content, ModTime: time.Now()}, - "test.js.gz": {Data: gzContent, ModTime: time.Now()}, - } - fs := http.FS(mapFS) - - req := httptest.NewRequest(http.MethodGet, "/test.js", nil) - req.Header.Set("Accept-Encoding", "gzip") - w := httptest.NewRecorder() - - ServeCompressedFile(fs, w, req, "test.js") - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("Expected status 200, got %d", resp.StatusCode) - } - - if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" { - t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding) - } - - if !bytes.Equal(body, gzContent) { - t.Errorf("Expected gzip content, got %s", string(body)) - } -} - -func TestServeCompressedFile_UncompressedFallback(t *testing.T) { - // Create test content - content := []byte("This is uncompressed test content") - - // Create a test filesystem without compressed versions - mapFS := fstest.MapFS{ - "test.js": {Data: content, ModTime: time.Now()}, - } - fs := http.FS(mapFS) - - req := httptest.NewRequest(http.MethodGet, "/test.js", nil) - req.Header.Set("Accept-Encoding", "br, gzip") - w := httptest.NewRecorder() - - ServeCompressedFile(fs, w, req, "test.js") - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("Expected status 200, got %d", resp.StatusCode) - } - - // Should not have Content-Encoding header since we're serving uncompressed - if encoding := resp.Header.Get("Content-Encoding"); encoding != "" { - t.Errorf("Expected no Content-Encoding, got '%s'", encoding) - } - - if !bytes.Equal(body, content) { - t.Errorf("Expected original content, got %s", string(body)) - } -} - -func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) { - // Create test content - content := []byte("This is test content") - - // Create a test filesystem with compressed versions - mapFS := fstest.MapFS{ - "test.js": {Data: content, ModTime: time.Now()}, - "test.js.br": {Data: []byte("brotli"), ModTime: time.Now()}, - "test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()}, - } - fs := http.FS(mapFS) - - req := httptest.NewRequest(http.MethodGet, "/test.js", nil) - // No Accept-Encoding header - w := httptest.NewRecorder() - - ServeCompressedFile(fs, w, req, "test.js") - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("Expected status 200, got %d", resp.StatusCode) - } - - // Should serve uncompressed content - if encoding := resp.Header.Get("Content-Encoding"); encoding != "" { - t.Errorf("Expected no Content-Encoding, got '%s'", encoding) - } - - if !bytes.Equal(body, content) { - t.Errorf("Expected original content, got %s", string(body)) - } -} - -func TestServeCompressedFile_NotFound(t *testing.T) { - mapFS := fstest.MapFS{} - fs := http.FS(mapFS) - - req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil) - w := httptest.NewRecorder() - - ServeCompressedFile(fs, w, req, "nonexistent.js") - - resp := w.Result() - - if resp.StatusCode != http.StatusNotFound { - t.Errorf("Expected status 404, got %d", resp.StatusCode) - } -} - -func TestSelectEncoding(t *testing.T) { - tests := []struct { - acceptEncoding string - wantEncoding string - wantExt string - }{ - {"br, gzip", "br", ".br"}, - {"gzip, deflate", "gzip", ".gz"}, - {"gzip", "gzip", ".gz"}, - {"br", "br", ".br"}, - {"", "", ""}, - {"deflate", "", ""}, - {"br;q=1.0, gzip;q=0.5", "br", ".br"}, - {"gzip;q=1.0, br;q=0.5", "br", ".br"}, - {"browser", "", ""}, - {"compress, deflate", "", ""}, - } - - for _, tt := range tests { - gotEncoding, gotExt := selectEncoding(tt.acceptEncoding) - if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt { - t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)", - tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt) - } - } -} - -// Test with actual pre-compressed files from ui_dist -func TestServeCompressedFile_RealFiles(t *testing.T) { - // Check if ui_dist exists - if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) { - t.Skip("ui_dist not found, skipping real file test") - } - - // Find a .js or .css file that has compressed versions - entries, err := os.ReadDir("./ui_dist/assets") - if err != nil { - t.Skipf("Could not read ui_dist/assets: %v", err) - } - - var testFile string - for _, entry := range entries { - name := entry.Name() - if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") { - // Check if compressed versions exist - base := strings.TrimSuffix(name, ".js") - if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil { - testFile = "assets/" + name - break - } - } - } - - if testFile == "" { - t.Skip("No suitable test file found with compressed versions") - } - - fs := http.FS(os.DirFS("./ui_dist")) - - // Test brotli - t.Run("brotli", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil) - req.Header.Set("Accept-Encoding", "br") - w := httptest.NewRecorder() - - ServeCompressedFile(fs, w, req, testFile) - - resp := w.Result() - if resp.StatusCode != http.StatusOK { - t.Fatalf("Expected status 200, got %d", resp.StatusCode) - } - - if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" { - t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding) - } - }) - - // Test gzip - t.Run("gzip", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil) - req.Header.Set("Accept-Encoding", "gzip") - w := httptest.NewRecorder() - - ServeCompressedFile(fs, w, req, testFile) - - resp := w.Result() - if resp.StatusCode != http.StatusOK { - t.Fatalf("Expected status 200, got %d", resp.StatusCode) - } - - if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" { - t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding) - } - - // Verify it's valid gzip - reader, err := gzip.NewReader(resp.Body) - if err != nil { - t.Errorf("Expected valid gzip content: %v", err) - return - } - defer reader.Close() - - // Just read to verify it's valid - _, err = io.Copy(io.Discard, reader) - if err != nil { - t.Errorf("Failed to decompress gzip: %v", err) - } - }) -} diff --git a/proxy/ui_embed.go b/proxy/ui_embed.go deleted file mode 100644 index 6d2d755..0000000 --- a/proxy/ui_embed.go +++ /dev/null @@ -1,24 +0,0 @@ -package proxy - -import ( - "embed" - "io/fs" - "net/http" -) - -//go:embed ui_dist -var reactStaticFS embed.FS - -// GetReactFS returns the embedded React filesystem -func GetReactFS() (http.FileSystem, error) { - subFS, err := fs.Sub(reactStaticFS, "ui_dist") - if err != nil { - return nil, err - } - return http.FS(subFS), nil -} - -// GetReactIndexHTML returns the main index.html for the React app -func GetReactIndexHTML() ([]byte, error) { - return reactStaticFS.ReadFile("ui_dist/index.html") -}